Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request | |
| from fastapi.responses import JSONResponse, Response | |
| from pydantic import BaseModel | |
| from transformers import pipeline | |
| import asyncio | |
| import logging | |
| from contextlib import asynccontextmanager | |
# Application-wide logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared mutable state: the translation pipeline plus load-status flags,
# written by load_model() and read by the request handlers.
state = dict(
    translator=None,
    model_loaded=False,
    model_loading=False,
)
async def load_model():
    """Load the translation pipeline in the background, exactly once.

    Idempotent: returns immediately if the model is already loaded or a
    load is in progress. On failure the error is logged and the model
    stays unloaded (handlers keep answering 503).
    """
    if state["model_loaded"] or state["model_loading"]:
        return
    state["model_loading"] = True
    logger.info("开始加载模型...")
    try:
        # pipeline() downloads/loads weights synchronously; run it in a
        # worker thread so the event loop stays responsive during startup.
        state["translator"] = await asyncio.to_thread(
            pipeline, "translation_en_to_zh", model="Helsinki-NLP/opus-mt-en-zh"
        )
        state["model_loaded"] = True
        logger.info("模型加载成功。")
    except Exception as e:
        logger.error(f"模型加载失败: {e}")
    finally:
        state["model_loading"] = False
@asynccontextmanager  # FastAPI's lifespan= expects an async context manager, not a bare generator function
async def lifespan(app: FastAPI):
    """App lifespan: kick off the model load on startup, unload on shutdown."""
    # Schedule the load in the background so startup is not blocked and
    # the server can answer health checks while the model downloads.
    asyncio.create_task(load_model())
    yield
    # Clean up the model and release the resources.
    state["translator"] = None
    state["model_loaded"] = False
    logger.info("模型已卸载。")
# Application instance; `lifespan` starts the background model load on startup.
app = FastAPI(lifespan=lifespan)

# Request body: the English text to translate.
class TranslationRequest(BaseModel):
    text: str

# Response body: the Chinese translation.
class TranslationResponse(BaseModel):
    translated_text: str
@app.post("/translate", response_model=TranslationResponse)  # was never registered on the app
async def translate_text(request: TranslationRequest):
    """Translate English text to Chinese, chunking long inputs.

    Returns 503 while the model is still loading.
    """
    if not state["model_loaded"]:
        return JSONResponse(content={"message": "Model is not loaded yet"}, status_code=503)
    # Split the text so each piece fits the model's input limit.
    text_chunks = split_text(request.text)
    translated_chunks = []
    for chunk in text_chunks:
        # The pipeline call is CPU-bound; run it in a worker thread so the
        # event loop is not blocked. It returns a list of dictionaries.
        result = await asyncio.to_thread(state["translator"], chunk, max_length=512)
        translated_chunks.append(result[0]['translation_text'])
    # Join the translated chunks back into one string.
    translated_text = "".join(translated_chunks)
    return {"translated_text": translated_text}
def split_text(text: str, max_length: int = 512) -> list[str]:
    """Split *text* into chunks of at most *max_length* characters.

    Prefers splitting at the last space before the limit so words are not
    cut in half; hard-cuts at *max_length* when a chunk contains no usable
    space. Never emits empty chunks; an empty input yields an empty list.
    """
    text_chunks = []
    while len(text) > max_length:
        # Find the last space to avoid splitting words.
        split_at = text.rfind(' ', 0, max_length)
        if split_at <= 0:
            # No space found (or only a leading one, which would produce an
            # empty chunk): hard-cut at max_length to guarantee progress.
            split_at = max_length
        text_chunks.append(text[:split_at])
        text = text[split_at:].lstrip()
    if text:
        text_chunks.append(text)
    return text_chunks
@app.get("/health")  # was never registered on the app
async def health_check():
    """Report readiness: 200 once the model is loaded, 503 before that."""
    status_code = 200 if state["model_loaded"] else 503
    # Plain Response() requires str/bytes content and would crash on a dict;
    # JSONResponse serializes it properly.
    return JSONResponse(content={"model_loaded": state["model_loaded"]}, status_code=status_code)
@app.get("/")  # was never registered on the app
async def read_root():
    """Root endpoint: simple welcome message."""
    return {"message": "Welcome to the translation API"}
# Local development entry point (kept commented out: Hugging Face Spaces
# launches the app with its own server on port 7860):
# if __name__ == '__main__':
#     import uvicorn
#     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)