VideoAgent-AX650N / VideoAgent /_server /tokenizer_server.py
H022329's picture
Upload folder using huggingface_hub
9cf08e9 verified
Raw
History Blame Contribute Delete
3.93 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from transformers import AutoTokenizer
import uvicorn
import os
from dotenv import load_dotenv
load_dotenv()

# 1. Initialise the FastAPI application.
app = FastAPI(title="Qwen3 Tokenizer API", version="1.0")

# 2. Load the tokenizer once at import time so requests do not reload it.
# Tokenizer_MODEL_PATH is read from the environment (or a .env file via
# load_dotenv above); for local files point it at the tokenizer directory,
# e.g. "./qwen3_tokenizer".
MODEL_PATH = os.getenv("Tokenizer_MODEL_PATH", "")
print("MODEL_PATH :", MODEL_PATH)
if not MODEL_PATH:
    # An empty path would only fail later inside from_pretrained with a much
    # less obvious error, so stop here with an actionable message.
    raise RuntimeError(
        "Tokenizer_MODEL_PATH is not set; point it at a tokenizer directory "
        "or model identifier before starting the server."
    )
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=False)
# 3. Request data structures
class EncodeRequest(BaseModel):
    """Request body accepted by the /encode endpoint."""

    text: str
    add_special_tokens: bool = True  # whether to add special tokens such as <|im_start|>
class DecodeRequest(BaseModel):
    """Request body accepted by the /decode endpoint."""

    token_ids: List[int]
    skip_special_tokens: bool = True  # whether to skip special tokens when decoding
class BatchEncodeRequest(BaseModel):
    """Request body accepted by the /batch_encode endpoint."""

    texts: List[str]  # list of texts to tokenize
    padding: bool = False  # whether to pad to the longest sequence
    truncation: bool = True  # whether to truncate over-long texts
    max_length: Optional[int] = None  # maximum length limit
    add_special_tokens: bool = True
# 4. API routes
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Returns:
        A dict with a fixed ``status`` of ``"running"`` and ``model`` set to
        the loaded tokenizer's ``name_or_path``.
    """
    loaded_model = tokenizer.name_or_path
    payload = {"status": "running", "model": loaded_model}
    return payload
@app.post("/encode")
async def encode_text(request: EncodeRequest):
    """Convert a single text into token IDs.

    Args:
        request: The text to tokenize and whether special tokens are added.

    Returns:
        A dict with the original ``text``, its ``token_ids``, their ``count``
        (handy for callers that only need the length), and the
        ``special_tokens_added`` flag echoed from the request.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            tokenization fails.
    """
    try:
        # No return_tensors here: for a single string the tokenizer already
        # yields a flat Python list of ints, so the service does not need
        # PyTorch installed just to convert a tensor back into a list.
        encoding = tokenizer(
            request.text,
            add_special_tokens=request.add_special_tokens,
        )
        token_ids = list(encoding["input_ids"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "text": request.text,
        "token_ids": token_ids,
        "count": len(token_ids),
        "special_tokens_added": request.add_special_tokens,
    }
@app.post("/batch_encode")
async def batch_encode(request: BatchEncodeRequest):
    """Convert a list of texts into token IDs in one tokenizer call.

    Args:
        request: The texts plus the padding / truncation / max_length /
            add_special_tokens options forwarded to the tokenizer.

    Returns:
        A dict with ``batch_size`` (number of input texts),
        ``token_ids_batch`` (one list of token IDs per text) and ``lengths``
        (the length of each returned row; when padding is enabled the rows
        are returned padded, so the length includes the pad tokens).

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            tokenization fails.
    """
    # Nothing to tokenize: answer directly instead of handing the tokenizer
    # an empty batch.
    if not request.texts:
        return {"batch_size": 0, "token_ids_batch": [], "lengths": []}

    try:
        # Passing the list performs batch tokenization. return_tensors is
        # deliberately omitted: with padding disabled (the default) rows have
        # different lengths and cannot be packed into a rectangular tensor,
        # whereas plain Python lists may be ragged and are JSON-serialisable
        # as-is.
        encoding = tokenizer(
            request.texts,
            padding=request.padding,
            truncation=request.truncation,
            max_length=request.max_length,
            add_special_tokens=request.add_special_tokens,
        )
        batch_token_ids = [list(ids) for ids in encoding["input_ids"]]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "batch_size": len(request.texts),
        "token_ids_batch": batch_token_ids,
        "lengths": [len(ids) for ids in batch_token_ids],
    }
@app.post("/decode")
async def decode_tokens(request: DecodeRequest):
    """Turn a list of token IDs back into text.

    Returns a dict echoing ``token_ids``, the decoded ``text`` and the
    ``count`` of IDs received. Any failure while decoding is reported as an
    HTTP 500 whose detail is the error message.
    """
    ids = request.token_ids
    drop_special = request.skip_special_tokens
    try:
        decoded = tokenizer.decode(ids, skip_special_tokens=drop_special)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    response = {
        "token_ids": ids,
        "text": decoded,
        "count": len(ids),
    }
    return response
# 5. Start the service
if __name__ == "__main__":
    # Equivalent command-line launch (the module name must match this file's
    # name; the port has to be passed explicitly on the command line):
    #   uvicorn tokenizer_server:app --host 0.0.0.0 --port 8007
    # uvicorn is already imported at the top of the file, so it is not
    # re-imported here.
    # The port comes from Tokenizer_API_PORT and falls back to 8007.
    port = int(os.getenv("Tokenizer_API_PORT", "8007"))
    uvicorn.run(app, host="0.0.0.0", port=port)