from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from transformers import AutoTokenizer
import uvicorn
import os
from dotenv import load_dotenv
# Populate os.environ from a .env file before any configuration is read below.
load_dotenv()

# 1. Initialize FastAPI
app = FastAPI(title="Qwen3 Tokenizer API", version="1.0")

# 2. Load the tokenizer once at import time so requests do not reload it.
# Tokenizer_MODEL_PATH may be a local directory (e.g. "./qwen3_tokenizer")
# or a model id understood by AutoTokenizer.from_pretrained.
MODEL_PATH = os.getenv("Tokenizer_MODEL_PATH", "")
print("MODEL_PATH :", MODEL_PATH)
if not MODEL_PATH:
    # Fail fast with an actionable message rather than letting
    # from_pretrained fail on an empty path.
    raise RuntimeError(
        "Environment variable Tokenizer_MODEL_PATH is not set; point it at a "
        "tokenizer directory or model id (it can also be set in a .env file)."
    )
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=False)
# 3. Request body schemas
class EncodeRequest(BaseModel):
    """Body of POST /encode: a single text to tokenize."""

    text: str
    # Whether the tokenizer should insert its configured special tokens.
    # Which tokens (if any) get added depends on the loaded tokenizer — verify
    # against its config rather than assuming chat markers like <|im_start|>.
    add_special_tokens: bool = True
class DecodeRequest(BaseModel):
    """Body of POST /decode: token IDs to turn back into text."""

    token_ids: List[int]
    skip_special_tokens: bool = True  # drop special tokens from the decoded text
class BatchEncodeRequest(BaseModel):
    """Body of POST /batch_encode: several texts tokenized in one call."""

    texts: List[str]  # one entry per sequence to tokenize
    padding: bool = False  # pad every sequence to the longest one in the batch
    truncation: bool = True  # truncate sequences exceeding the length limit
    max_length: Optional[int] = None  # explicit length limit; None passes no limit
    add_special_tokens: bool = True
# 4. API routes
@app.get("/health")
async def health_check():
    """Health check: report that the service is up and which tokenizer is loaded."""
    loaded_from = tokenizer.name_or_path
    return {"status": "running", "model": loaded_from}
@app.post("/encode")
async def encode_text(request: EncodeRequest):
    """Convert a single text into token IDs.

    Returns:
        A dict with the original ``text``, the ``token_ids`` list, their
        ``count`` (handy for front-end length budgeting), and
        ``special_tokens_added`` echoing the request flag.

    Raises:
        HTTPException: 500 carrying the underlying error message if
            tokenization fails.
    """
    try:
        # Leave return_tensors unset so the tokenizer returns plain Python
        # lists: the result is JSON-serialized anyway, and requesting "pt"
        # tensors would make this endpoint depend on torch being installed.
        encoding = tokenizer(
            request.text,
            add_special_tokens=request.add_special_tokens,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # For a single (non-batched) string, input_ids is a flat list of ints.
    token_ids = encoding["input_ids"]
    return {
        "text": request.text,
        "token_ids": token_ids,
        "count": len(token_ids),
        "special_tokens_added": request.add_special_tokens,
    }
@app.post("/batch_encode")
async def batch_encode(request: BatchEncodeRequest):
    """Convert a list of texts into token IDs in one tokenizer call.

    Returns:
        A dict with ``batch_size`` (number of input texts), ``token_ids_batch``
        (one list of IDs per text; rows may differ in length when
        ``padding`` is False), and ``lengths`` (the number of real,
        non-padding tokens per text).

    Raises:
        HTTPException: 500 carrying the underlying error message if
            tokenization fails.
    """
    if not request.texts:
        # Nothing to tokenize; answer directly instead of handing the
        # tokenizer an empty batch, which it may not handle gracefully.
        return {"batch_size": 0, "token_ids_batch": [], "lengths": []}

    try:
        # Request plain Python lists (no return_tensors). Tensors must be
        # rectangular, so with padding=False a batch of texts with different
        # token counts could not be converted and the call would fail.
        encoding = tokenizer(
            request.texts,
            padding=request.padding,
            truncation=request.truncation,
            max_length=request.max_length,
            add_special_tokens=request.add_special_tokens,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    batch_token_ids = encoding["input_ids"]
    attention_mask = encoding.get("attention_mask")
    if attention_mask is not None:
        # Count only real tokens: with padding=True every row has the padded
        # length, so len(row) would not reflect each text's own length.
        lengths = [sum(mask) for mask in attention_mask]
    else:
        lengths = [len(ids) for ids in batch_token_ids]

    return {
        "batch_size": len(request.texts),
        "token_ids_batch": batch_token_ids,
        "lengths": lengths,
    }
@app.post("/decode")
async def decode_tokens(request: DecodeRequest):
    """Turn a list of token IDs back into text.

    Returns:
        A dict echoing ``token_ids``, the decoded ``text``, and ``count``
        (the number of IDs supplied).

    Raises:
        HTTPException: 500 carrying the underlying error message if
            decoding fails.
    """
    ids = request.token_ids
    try:
        decoded = tokenizer.decode(ids, skip_special_tokens=request.skip_special_tokens)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"token_ids": ids, "text": decoded, "count": len(ids)}
# 5. Start the service
if __name__ == "__main__":
    # Equivalent CLI launch (assuming this file is tokenizer_api.py):
    #   uvicorn tokenizer_api:app --host 0.0.0.0 --port 8007
    # 8007 is the default below; override it with Tokenizer_API_PORT.
    # uvicorn is already imported at module level, so no re-import here.
    port = int(os.getenv("Tokenizer_API_PORT", 8007))
    uvicorn.run(app, host="0.0.0.0", port=port)