|
|
from fastapi import FastAPI, HTTPException |
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
from utils.chat_request import ChatRequest |
|
|
from utils.chat_response import create_chat_response, ChatResponse |
|
|
from utils.model import check_model, initialize_pipeline, download_model, DownloadRequest |
|
|
|
|
|
|
|
|
# Module-level serving state. These are populated by startup_event() and may
# be swapped at runtime by the /v1/download and /v1/chat/completions handlers.
model_name = None  # HF model id currently loaded, or None if no model is active


pipe = None  # inference pipeline returned by initialize_pipeline()


tokenizer = None  # tokenizer paired with `pipe`








app = FastAPI(title="HF-Model-Runner API", version="0.0.1")
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Initialize the default model pipeline when the application boots.

    Reads DEFAULT_MODEL_NAME from the environment (after loading .env) and
    eagerly loads that model. Startup is best-effort: any failure is reported
    to stdout and the server still comes up with no model loaded.
    """
    global pipe, tokenizer, model_name

    load_dotenv()

    # Fall back to a small built-in default when no env override is present.
    default_model = os.getenv("DEFAULT_MODEL_NAME", "unsloth/functiongemma-270m-it")
    print(f"应用启动,正在初始化模型: {default_model}")

    try:
        pipe, tokenizer, success = initialize_pipeline(default_model)
    except Exception as e:
        # Best-effort: report the failure and continue without a model.
        print(f"✗ 启动时模型初始化失败: {e}")
        return

    if not success:
        print(f"✗ 模型 {default_model} 初始化失败")
        return

    model_name = default_model
    print(f"✓ 模型 {default_model} 初始化成功")
|
|
|
|
|
@app.get("/")
async def read_root():
    """Landing endpoint that points callers at the interactive API docs."""
    welcome = {"message": "Welcome to HF-Model-Runner API! Visit /docs for API documentation."}
    return welcome
|
|
|
|
|
@app.post("/v1/download")
async def download_model_endpoint(request: DownloadRequest):
    """Download the requested HuggingFace model and try to load it for serving.

    On a successful download the handler attempts to swap the active pipeline
    to the new model. The response reports both outcomes:
      - loaded=True  -> downloaded and now serving `current_model`
      - loaded=False -> downloaded, but pipeline initialization failed

    Raises:
        HTTPException(500): when the download itself fails, or on any
        unexpected error (wrapped with its message as the detail).
    """
    global pipe, tokenizer, model_name

    try:
        success, message = download_model(request.model)
        if not success:
            raise HTTPException(status_code=500, detail=message)

        # Download succeeded: attempt to make the new model the active one.
        pipe, tokenizer, init_success = initialize_pipeline(request.model)
        if init_success:
            model_name = request.model
            return {
                "status": "success",
                "message": message,
                "loaded": True,
                "current_model": model_name
            }

        # Downloaded but not loadable; report without failing the request.
        return {
            "status": "success",
            "message": message,
            "loaded": False,
            "error": "模型下载成功但初始化失败"
        }
    except HTTPException:
        # Bug fix: the HTTPException raised above was previously caught by the
        # generic handler below and re-wrapped with detail=str(e), mangling
        # the detail into "500: <message>". Re-raise it untouched instead.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """OpenAI-compatible chat completions endpoint.

    If the caller requests a model other than the one currently loaded, the
    pipeline is hot-swapped before generation. Generation is delegated to
    create_chat_response; any failure there surfaces as HTTP 500.
    """
    global pipe, tokenizer, model_name

    # Hot-swap the pipeline when a different model is requested.
    if request.model != model_name:
        pipe, tokenizer, ok = initialize_pipeline(request.model)
        if not ok:
            raise HTTPException(status_code=500, detail="模型初始化失败")
        model_name = request.model

    try:
        response = create_chat_response(request, pipe, tokenizer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return response
|
|
|