# airsmodel / app.py
# Uploaded by airsltd — "Update app.py", commit 521b755 (verified).
# NOTE(review): these lines were Hugging Face Space page residue, not valid
# Python; converted to comments so the module parses.
from fastapi import FastAPI, HTTPException
import os
from dotenv import load_dotenv
# Import project-local utils modules.
# NOTE(review): `check_model` is imported but never used in this file — confirm
# whether it is needed before removing.
from utils.chat_request import ChatRequest
from utils.chat_response import create_chat_response, ChatResponse
from utils.model import check_model, initialize_pipeline, download_model, DownloadRequest
# Global state for the currently loaded model.
model_name = None  # HF model id of the active model, or None before startup
pipe = None        # inference pipeline returned by initialize_pipeline
tokenizer = None   # tokenizer returned by initialize_pipeline
# Initialize the FastAPI application.
app = FastAPI(title="HF-Model-Runner API", version="0.0.1")
@app.on_event("startup")
async def startup_event():
    """Warm up the default model pipeline when the application boots.

    Reads DEFAULT_MODEL_NAME from the environment (optionally loaded from a
    local .env file) and initializes the global pipeline/tokenizer pair.
    Failures are logged but never abort startup.
    """
    global pipe, tokenizer, model_name
    load_dotenv()  # pick up optional .env overrides before reading env vars
    # Fall back to a small built-in default when no model id is configured.
    target = os.getenv("DEFAULT_MODEL_NAME", "unsloth/functiongemma-270m-it")
    print(f"应用启动,正在初始化模型: {target}")
    try:
        pipe, tokenizer, ok = initialize_pipeline(target)
        if not ok:
            print(f"✗ 模型 {target} 初始化失败")
        else:
            model_name = target
            print(f"✓ 模型 {target} 初始化成功")
    except Exception as exc:
        print(f"✗ 启动时模型初始化失败: {exc}")
@app.get("/")
async def read_root():
    """Root landing endpoint that points users at the interactive API docs."""
    greeting = "Welcome to HF-Model-Runner API! Visit /docs for API documentation."
    return {"message": greeting}
@app.post("/v1/download")
async def download_model_endpoint(request: DownloadRequest):
    """
    Download the requested HuggingFace model and, on success, load it as the
    active pipeline.

    Returns a JSON status payload describing download and load state.
    Raises HTTPException(500) when the download fails or an unexpected error
    occurs.
    """
    global pipe, tokenizer, model_name
    try:
        success, message = download_model(request.model)
        if not success:
            raise HTTPException(status_code=500, detail=message)
        # Initialize into locals first so a failed init cannot clobber a
        # previously working global pipeline/tokenizer.
        new_pipe, new_tokenizer, init_success = initialize_pipeline(request.model)
        if not init_success:
            return {
                "status": "success",
                "message": message,
                "loaded": False,
                "error": "模型下载成功但初始化失败"
            }
        # Commit the new pipeline atomically with the model name.
        pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model
        return {
            "status": "success",
            "message": message,
            "loaded": True,
            "current_model": model_name
        }
    except HTTPException:
        # Bug fix: re-raise intentional HTTPExceptions unchanged. Previously
        # the generic handler below caught them and re-wrapped the detail as
        # str(e) (e.g. "500: <message>"), mangling the client-facing error.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """
    OpenAI-compatible chat completions endpoint.

    If the requested model differs from the currently loaded one, it is
    (re)initialized first. Raises HTTPException(500) when initialization or
    response generation fails.
    """
    global pipe, tokenizer, model_name
    # Hot-swap the model when the request asks for a different one.
    if request.model != model_name:
        # Bug fix: stage into locals so a failed init leaves the previously
        # working global pipeline intact. The old code overwrote the globals
        # before checking `success`, destroying a usable pipeline while
        # model_name still pointed at the old model.
        new_pipe, new_tokenizer, success = initialize_pipeline(request.model)
        if not success:
            raise HTTPException(status_code=500, detail="模型初始化失败")
        pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model
    try:
        return create_chat_response(request, pipe, tokenizer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))