from fastapi import FastAPI, HTTPException
import os
from dotenv import load_dotenv

# Project-local helpers: request/response schemas and model management.
from utils.chat_request import ChatRequest
from utils.chat_response import create_chat_response, ChatResponse
from utils.model import check_model, initialize_pipeline, download_model, DownloadRequest

# Global state: name of the currently loaded model plus its pipeline/tokenizer.
# NOTE(review): these globals are unsynchronized; concurrent requests that
# switch models can race — confirm a single-worker deployment is assumed.
model_name = None
pipe = None
tokenizer = None

# Initialize the FastAPI application.
app = FastAPI(title="HF-Model-Runner API", version="0.0.1")


@app.on_event("startup")
async def startup_event():
    """Initialize the default model pipeline when the application starts.

    Reads DEFAULT_MODEL_NAME from a .env file / the environment, falling
    back to "unsloth/functiongemma-270m-it". A failed load is reported but
    does not abort startup; endpoints retry initialization per request.
    """
    global pipe, tokenizer, model_name

    # Load variables from a .env file, if one is present.
    load_dotenv()

    # Default model name from the environment, with a hard-coded fallback.
    default_model = os.getenv("DEFAULT_MODEL_NAME", "unsloth/functiongemma-270m-it")

    print(f"应用启动,正在初始化模型: {default_model}")
    try:
        new_pipe, new_tokenizer, success = initialize_pipeline(default_model)
        if success:
            # Commit globals only on success so a failed load leaves them None.
            pipe, tokenizer, model_name = new_pipe, new_tokenizer, default_model
            print(f"✓ 模型 {default_model} 初始化成功")
        else:
            print(f"✗ 模型 {default_model} 初始化失败")
    except Exception as e:
        # Startup must not crash on a bad default model; report and continue.
        print(f"✗ 启动时模型初始化失败: {e}")


@app.get("/")
async def read_root():
    """Welcome/health endpoint pointing users at the interactive docs."""
    # BUG FIX: the original literal contained a raw line break inside the
    # string (a Python syntax error); the newline is now an explicit escape.
    return {"message": "Welcome to HF-Model-Runner API! \nVisit /docs for API documentation."}


@app.post("/v1/download")
async def download_model_endpoint(request: DownloadRequest):
    """Download the requested HuggingFace model and load it on success.

    Returns a JSON status payload describing download/load state; raises
    HTTP 500 when the download fails or an unexpected error occurs.
    """
    global pipe, tokenizer, model_name
    try:
        success, message = download_model(request.model)
        if not success:
            raise HTTPException(status_code=500, detail=message)

        # Download succeeded: initialize the model immediately.
        new_pipe, new_tokenizer, init_success = initialize_pipeline(request.model)
        if init_success:
            # Commit globals only after a successful initialization so a
            # failed load cannot clobber a previously working model.
            pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model
            return {
                "status": "success",
                "message": message,
                "loaded": True,
                "current_model": model_name,
            }
        return {
            "status": "success",
            "message": message,
            "loaded": False,
            "error": "模型下载成功但初始化失败",
        }
    except HTTPException:
        # BUG FIX: let deliberate HTTP errors propagate unchanged; previously
        # they were caught below and re-wrapped with a mangled detail string.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """OpenAI-compatible chat-completions endpoint.

    If the requested model differs from the currently loaded one, the
    pipeline is re-initialized first (HTTP 500 when that fails).
    """
    global pipe, tokenizer, model_name

    # Switch models on demand when the request names a different model.
    if request.model != model_name:
        new_pipe, new_tokenizer, success = initialize_pipeline(request.model)
        if not success:
            raise HTTPException(status_code=500, detail="模型初始化失败")
        # BUG FIX: commit globals only on success — the original overwrote
        # pipe/tokenizer even when initialization failed, breaking the
        # previously loaded model for subsequent requests.
        pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model

    try:
        return create_chat_response(request, pipe, tokenizer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))