File size: 3,231 Bytes
702fae5 e142333 702fae5 e142333 702fae5 8dd24a3 702fae5 521b755 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from fastapi import FastAPI, HTTPException
import os
from dotenv import load_dotenv
# 导入 utils 模块
from utils.chat_request import ChatRequest
from utils.chat_response import create_chat_response, ChatResponse
from utils.model import check_model, initialize_pipeline, download_model, DownloadRequest
# Global module-level state shared by all request handlers.
model_name = None  # name of the currently loaded model, or None until one is initialized
pipe = None        # pipeline object returned by initialize_pipeline (presumably a transformers pipeline — confirm in utils.model)
tokenizer = None   # tokenizer paired with the loaded pipeline
# Initialize the FastAPI application
app = FastAPI(title="HF-Model-Runner API", version="0.0.1")
@app.on_event("startup")
async def startup_event():
    """Initialize the default model pipeline when the application starts.

    Reads DEFAULT_MODEL_NAME from the environment (populated from a .env
    file if present); falls back to a built-in default model name.
    Startup is best-effort: on any failure the app still comes up, just
    with no model loaded.
    """
    global pipe, tokenizer, model_name

    # Load environment variables from a local .env file, if one exists.
    load_dotenv()
    default_model = os.getenv("DEFAULT_MODEL_NAME", "unsloth/functiongemma-270m-it")
    print(f"应用启动,正在初始化模型: {default_model}")
    try:
        pipe, tokenizer, success = initialize_pipeline(default_model)
    except Exception as e:
        # Log and keep the server running without a loaded model.
        print(f"✗ 启动时模型初始化失败: {e}")
        return
    if success:
        model_name = default_model
        print(f"✓ 模型 {default_model} 初始化成功")
    else:
        print(f"✗ 模型 {default_model} 初始化失败")
@app.get("/")
async def read_root():
    """Root endpoint: greet the caller and point them to the API docs."""
    welcome = {"message": "Welcome to HF-Model-Runner API! Visit /docs for API documentation."}
    return welcome
@app.post("/v1/download")
async def download_model_endpoint(request: DownloadRequest):
    """
    Download the requested HuggingFace model and, on success, load it
    as the active model.

    Returns a status dict describing whether the model was downloaded
    and whether it was successfully loaded.

    Raises:
        HTTPException: 500 if the download fails or an unexpected
            error occurs.
    """
    global pipe, tokenizer, model_name
    try:
        success, message = download_model(request.model)
        if not success:
            raise HTTPException(status_code=500, detail=message)
        # Initialize into locals first: the original assigned the globals
        # unconditionally, so a failed init clobbered the currently loaded
        # pipeline/tokenizer. Only publish to globals on success.
        new_pipe, new_tokenizer, init_success = initialize_pipeline(request.model)
        if init_success:
            pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model
            return {
                "status": "success",
                "message": message,
                "loaded": True,
                "current_model": model_name,
            }
        # Download succeeded but the model could not be initialized.
        return {
            "status": "success",
            "message": message,
            "loaded": False,
            "error": "模型下载成功但初始化失败",
        }
    except HTTPException:
        # Bug fix: the original's blanket `except Exception` caught its own
        # HTTPException and re-wrapped it, mangling the detail message.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """
    OpenAI-compatible chat completions endpoint.

    If the requested model differs from the currently loaded one, the
    pipeline is re-initialized for the requested model first.

    Raises:
        HTTPException: 500 if model initialization or response
            generation fails.
    """
    global pipe, tokenizer, model_name
    # Swap models only when the request asks for a different one.
    if request.model != model_name:
        # Initialize into locals first: the original assigned the globals
        # before checking `success`, so a failed switch destroyed the
        # previously working pipeline/tokenizer. Publish only on success.
        new_pipe, new_tokenizer, success = initialize_pipeline(request.model)
        if not success:
            raise HTTPException(status_code=500, detail="模型初始化失败")
        pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model
    try:
        return create_chat_response(request, pipe, tokenizer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|