"""FastAPI service exposing an English-to-Chinese translation endpoint.

The Helsinki-NLP/opus-mt-en-zh model is loaded in the background at startup
so the server can answer health checks while the (slow) model load runs.
Long inputs are split into <=512-character chunks before translation.
"""

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from transformers import pipeline
import asyncio
import logging
from contextlib import asynccontextmanager

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared state: the translator pipeline and its loading status flags.
state = {
    "translator": None,
    "model_loaded": False,
    "model_loading": False,
}


async def load_model():
    """Load the translation model without blocking the event loop.

    Idempotent: returns immediately if the model is already loaded or a
    load is currently in progress. On failure the error is logged and the
    service keeps running (endpoints answer 503 until a retry succeeds).
    """
    if state["model_loaded"] or state["model_loading"]:
        return
    state["model_loading"] = True
    logger.info("开始加载模型...")
    try:
        # pipeline() downloads and initializes the model synchronously;
        # run it in a worker thread so the event loop stays responsive
        # and /health can be served during the load.
        state["translator"] = await asyncio.to_thread(
            pipeline,
            "translation_en_to_zh",
            model="Helsinki-NLP/opus-mt-en-zh",
        )
        state["model_loaded"] = True
        logger.info("模型加载成功。")
    except Exception as e:
        logger.error(f"模型加载失败: {e}")
    finally:
        state["model_loading"] = False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: start the model load at startup, release it on shutdown."""
    # Fire-and-forget so startup is not blocked by the model download.
    asyncio.create_task(load_model())
    yield
    # Clean up the model and release the resources.
    state["translator"] = None
    state["model_loaded"] = False
    logger.info("模型已卸载。")


app = FastAPI(lifespan=lifespan)


class TranslationRequest(BaseModel):
    # English source text to translate.
    text: str


class TranslationResponse(BaseModel):
    # Chinese translation of the submitted text.
    translated_text: str


@app.post("/translate", response_model=TranslationResponse)
async def translate_text(request: TranslationRequest):
    """Translate English text to Chinese.

    Returns 503 while the model is still loading. Long inputs are chunked
    with split_text() and the translated chunks are concatenated (no
    separator — appropriate for Chinese output).
    """
    if not state["model_loaded"]:
        return JSONResponse(
            content={"message": "Model is not loaded yet"}, status_code=503
        )

    text_chunks = split_text(request.text)

    translated_chunks = []
    for chunk in text_chunks:
        # Inference is CPU/GPU-bound and synchronous; run it in a thread so
        # concurrent requests (e.g. /health) are not starved meanwhile.
        # The translator returns a list of dictionaries.
        result = await asyncio.to_thread(
            state["translator"], chunk, max_length=512
        )
        translated_chunks.append(result[0]["translation_text"])

    return {"translated_text": "".join(translated_chunks)}


def split_text(text: str, max_length: int = 512) -> list[str]:
    """Split *text* into chunks of at most *max_length* characters.

    Breaks at the last space before the limit when one exists, to avoid
    splitting words; a more sophisticated approach could split by
    sentences. Always returns at least one (possibly empty) chunk.
    """
    text_chunks = []
    while len(text) > max_length:
        # Find the last space to avoid splitting words.
        split_at = text.rfind(' ', 0, max_length)
        if split_at <= 0:
            # No usable space (none found, or only at index 0, which would
            # yield an empty chunk): hard-split at max_length.
            split_at = max_length
        text_chunks.append(text[:split_at])
        text = text[split_at:].lstrip()
    text_chunks.append(text)
    return text_chunks


@app.get("/health")
async def health_check():
    """Report whether the model is loaded (200) or not yet available (503)."""
    status_code = 200 if state["model_loaded"] else 503
    # BUG FIX: the original used plain Response(content={...}); Response
    # requires str/bytes and raised at runtime. JSONResponse serializes
    # the dict correctly.
    return JSONResponse(
        content={"model_loaded": state["model_loaded"]}, status_code=status_code
    )


@app.get("/")
async def read_root():
    """Landing endpoint with a welcome message."""
    return {"message": "Welcome to the translation API"}


# if __name__ == '__main__':
#     import uvicorn
#     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)