| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from typing import List, Optional |
| from transformers import AutoTokenizer |
| import uvicorn |
| import os |
| from dotenv import load_dotenv |
|
|
# Load environment variables via python-dotenv before the os.getenv() lookups
# below.
load_dotenv()

app = FastAPI(title="Qwen3 Tokenizer API", version="1.0")

# Path or identifier handed to AutoTokenizer.from_pretrained(); taken from the
# Tokenizer_MODEL_PATH environment variable.
MODEL_PATH = os.getenv("Tokenizer_MODEL_PATH", "")
print("MODEL_PATH :", MODEL_PATH)

# Fail fast with an actionable message instead of passing an empty string to
# the tokenizer loader, which would only surface an opaque loading error.
if not MODEL_PATH:
    raise RuntimeError(
        "Tokenizer_MODEL_PATH is not set; point it at the tokenizer to load."
    )

# Shared tokenizer instance used by every endpoint. Remote code execution is
# explicitly disabled.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=False)
|
|
| |
class EncodeRequest(BaseModel):
    """Request body for the single-text ``/encode`` endpoint."""

    # Text to tokenize.
    text: str
    # Forwarded to the tokenizer's ``add_special_tokens`` argument.
    add_special_tokens: bool = True
|
|
class DecodeRequest(BaseModel):
    """Request body for the ``/decode`` endpoint."""

    # Token IDs to convert back into text.
    token_ids: List[int]
    # Forwarded to ``tokenizer.decode`` as ``skip_special_tokens``.
    skip_special_tokens: bool = True
|
|
class BatchEncodeRequest(BaseModel):
    """Request body for the ``/batch_encode`` endpoint.

    Every field except ``texts`` is passed straight through to the tokenizer
    call under the same keyword name.
    """

    # Texts to tokenize together as one batch.
    texts: List[str]
    # Tokenizer ``padding`` argument; off unless requested.
    padding: bool = False
    # Tokenizer ``truncation`` argument; on unless disabled.
    truncation: bool = True
    # Tokenizer ``max_length`` argument; ``None`` leaves it unspecified.
    max_length: Optional[int] = None
    # Tokenizer ``add_special_tokens`` argument.
    add_special_tokens: bool = True
|
|
| |
|
|
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Returns a fixed ``"running"`` status together with the loaded tokenizer's
    ``name_or_path``.
    """
    payload = {"status": "running"}
    payload["model"] = tokenizer.name_or_path
    return payload
|
|
@app.post("/encode")
async def encode_text(request: EncodeRequest):
    """Convert a single text into token IDs.

    Args:
        request: The text to tokenize and whether special tokens are added.

    Returns:
        A dict with the original ``text``, the resulting ``token_ids``, their
        ``count``, and the ``special_tokens_added`` flag echoed from the
        request.

    Raises:
        HTTPException: With status 500 and the error message as detail when
            tokenization fails.
    """
    try:
        # No ``return_tensors``: for a single string the tokenizer already
        # yields a flat list of ints, so there is no need to build a tensor
        # (and require PyTorch) just to convert it back to a list.
        encoding = tokenizer(
            request.text,
            add_special_tokens=request.add_special_tokens,
        )
        token_ids = list(encoding["input_ids"])

        return {
            "text": request.text,
            "token_ids": token_ids,
            "count": len(token_ids),
            "special_tokens_added": request.add_special_tokens,
        }
    except Exception as e:
        # Endpoint boundary: report any tokenizer failure as a 500 and keep
        # the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/batch_encode")
async def batch_encode(request: BatchEncodeRequest):
    """Convert a list of texts into token IDs in one tokenizer call.

    Args:
        request: The texts plus the padding, truncation, max_length and
            add_special_tokens options forwarded to the tokenizer.

    Returns:
        A dict with ``batch_size`` (number of input texts),
        ``token_ids_batch`` (one list of token IDs per text) and ``lengths``
        (the length of each of those lists, so padded positions are counted
        when padding is enabled).

    Raises:
        HTTPException: With status 500 and the error message as detail when
            tokenization fails.
    """
    # Nothing to encode: answer with an empty batch rather than calling the
    # tokenizer with an empty list.
    if not request.texts:
        return {"batch_size": 0, "token_ids_batch": [], "lengths": []}

    try:
        # No ``return_tensors``: a tensor needs every row to be the same
        # length, which fails for unpadded batches (the default here). Plain
        # Python lists handle ragged rows and are what the response needs.
        encoding = tokenizer(
            request.texts,
            padding=request.padding,
            truncation=request.truncation,
            max_length=request.max_length,
            add_special_tokens=request.add_special_tokens,
        )
        batch_token_ids = [list(ids) for ids in encoding["input_ids"]]
        lengths = [len(ids) for ids in batch_token_ids]

        return {
            "batch_size": len(request.texts),
            "token_ids_batch": batch_token_ids,
            "lengths": lengths,
        }
    except Exception as e:
        # Endpoint boundary: report any tokenizer failure as a 500 and keep
        # the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/decode")
async def decode_tokens(request: DecodeRequest):
    """Turn a list of token IDs back into text.

    Returns a dict echoing ``token_ids``, the decoded ``text``, and the
    ``count`` of IDs received. Any decoding failure is reported as an
    HTTP 500 whose detail is the error message.
    """
    ids = request.token_ids
    try:
        decoded = tokenizer.decode(
            ids,
            skip_special_tokens=request.skip_special_tokens,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {"token_ids": ids, "text": decoded, "count": len(ids)}
|
|
| |
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn (imported at the top of
    # the file, so no local re-import is needed).
    # The port comes from Tokenizer_API_PORT and falls back to 8007; the
    # default is given as a string because os.getenv deals in strings.
    port = int(os.getenv("Tokenizer_API_PORT", "8007"))
    # Host 0.0.0.0 listens on all network interfaces.
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|