# TransPlugin / app.py
# fix: handle long text translation by chunking (commit 79ea48e)
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from transformers import pipeline
import asyncio
import logging
from contextlib import asynccontextmanager
# Configure logging for the whole module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Shared mutable state: holds the translation pipeline and its load status.
# Written by load_model(), read by the /translate and /health endpoints.
state = {
    "translator": None,      # transformers pipeline once loaded, else None
    "model_loaded": False,   # True only after a successful load
    "model_loading": False   # guards against starting a second concurrent load
}
async def load_model():
    """Load the EN->ZH translation pipeline into the shared ``state``.

    Idempotent: returns immediately if the model is already loaded or a
    load is in progress. On failure the error is logged and
    ``state["model_loaded"]`` stays False, so /translate keeps
    answering 503 rather than crashing the app.
    """
    if state["model_loaded"] or state["model_loading"]:
        return
    state["model_loading"] = True
    logger.info("开始加载模型...")
    try:
        # pipeline() downloads/loads the weights synchronously; run it in a
        # worker thread so the event loop stays responsive during startup
        # (calling it inline would block every request until loading ends).
        state["translator"] = await asyncio.to_thread(
            pipeline, "translation_en_to_zh", model="Helsinki-NLP/opus-mt-en-zh"
        )
        state["model_loaded"] = True
        logger.info("模型加载成功。")
    except Exception as e:
        # Best-effort: log and leave the service up in "not loaded" mode.
        logger.error(f"模型加载失败: {e}")
    finally:
        state["model_loading"] = False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: start model loading in the background on startup,
    release the model on shutdown."""
    # Keep a reference to the task for the app's lifetime: asyncio holds
    # only weak references to running tasks, so an unreferenced task can
    # be garbage-collected before it finishes.
    load_task = asyncio.create_task(load_model())
    yield
    # Clean up the model and release the resources
    state["translator"] = None
    state["model_loaded"] = False
    logger.info("模型已卸载。")
app = FastAPI(lifespan=lifespan)
class TranslationRequest(BaseModel):
    """Request body for /translate."""
    # English source text to translate; may be arbitrarily long (chunked server-side).
    text: str
class TranslationResponse(BaseModel):
    """Response body for /translate."""
    # Chinese translation of the full input, chunks concatenated in order.
    translated_text: str
@app.post("/translate", response_model=TranslationResponse)
async def translate_text(request: TranslationRequest):
    """Translate English text to Chinese, chunking long input to fit the model.

    Returns 503 while the model is still loading. Chunks are translated
    in order and concatenated without a separator (Chinese text needs none).
    """
    if not state["model_loaded"]:
        return JSONResponse(content={"message": "Model is not loaded yet"}, status_code=503)
    # Split the text into chunks small enough for the model.
    text_chunks = split_text(request.text)
    translated_chunks = []
    for chunk in text_chunks:
        # The pipeline call is synchronous and CPU-bound; run it in a worker
        # thread so other requests are not blocked while translating.
        # The translator returns a list of dicts: [{'translation_text': ...}]
        result = await asyncio.to_thread(state["translator"], chunk, max_length=512)
        translated_chunks.append(result[0]['translation_text'])
    # Join the translated chunks
    translated_text = "".join(translated_chunks)
    return {"translated_text": translated_text}
def split_text(text: str, max_length: int = 512):
    """Split *text* into chunks of at most *max_length* characters.

    Splits at the last space before the limit when possible, so words
    are not cut in half; falls back to a hard cut when a chunk contains
    no usable space. Never emits empty chunks, so empty input yields []
    (the original appended "" chunks, which wasted translator calls).

    NOTE(review): the limit is in characters while the model's
    max_length is in tokens -- 512 chars is a heuristic, not a
    guarantee; confirm against the tokenizer for dense inputs.
    """
    chunks = []
    while len(text) > max_length:
        # Find the last space before the limit to avoid splitting words.
        split_at = text.rfind(' ', 0, max_length)
        if split_at <= 0:
            # No space found (or only a leading one): hard split at the limit.
            split_at = max_length
        chunks.append(text[:split_at])
        # Drop the separating whitespace from the start of the remainder.
        text = text[split_at:].lstrip()
    if text:
        chunks.append(text)
    return chunks
@app.get("/health")
async def health_check():
    """Report model readiness: 200 with the flag when loaded, 503 otherwise."""
    status_code = 200 if state["model_loaded"] else 503
    # Plain starlette Response requires str/bytes content -- passing a dict
    # raises at render time. JSONResponse serializes the dict correctly.
    return JSONResponse(content={"model_loaded": state["model_loaded"]}, status_code=status_code)
@app.get("/")
async def read_root():
    """Root endpoint: return a static welcome payload."""
    greeting = {"message": "Welcome to the translation API"}
    return greeting
# if __name__ == '__main__':
# import uvicorn
# uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)