File size: 3,231 Bytes
702fae5
 
 
 
 
 
 
 
 
 
 
 
 
e142333
 
 
 
702fae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e142333
 
702fae5
 
 
8dd24a3
702fae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521b755
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from fastapi import FastAPI, HTTPException
import os
from dotenv import load_dotenv

# 导入 utils 模块
from utils.chat_request import ChatRequest
from utils.chat_response import create_chat_response, ChatResponse
from utils.model import check_model, initialize_pipeline, download_model, DownloadRequest

# Global mutable state shared by every endpoint: the name of the
# currently loaded model, its inference pipeline, and its tokenizer.
# Populated by the startup hook and swapped in-place by the endpoints.
model_name = None
pipe = None
tokenizer = None

# FastAPI application instance (served routes are registered below).
app = FastAPI(title="HF-Model-Runner API", version="0.0.1")

@app.on_event("startup")
async def startup_event():
    """
    Load the default model pipeline when the application starts.

    Reads DEFAULT_MODEL_NAME from the environment (via a local .env file)
    and tries to initialize it.  A failure is logged but never prevents
    the app from coming up — endpoints can still load models on demand.
    """
    global pipe, tokenizer, model_name

    # Merge any .env overrides into the process environment first.
    load_dotenv()

    # Fall back to a small built-in model when nothing is configured.
    default_model = os.getenv("DEFAULT_MODEL_NAME", "unsloth/functiongemma-270m-it")
    print(f"应用启动,正在初始化模型: {default_model}")

    try:
        loaded_pipe, loaded_tokenizer, ok = initialize_pipeline(default_model)
    except Exception as e:
        # Best-effort startup: log and continue serving without a model.
        print(f"✗ 启动时模型初始化失败: {e}")
        return

    pipe, tokenizer = loaded_pipe, loaded_tokenizer
    if ok:
        model_name = default_model
        print(f"✓ 模型 {default_model} 初始化成功")
    else:
        print(f"✗ 模型 {default_model} 初始化失败")

@app.get("/")
async def read_root():
    """Landing endpoint that points clients at the interactive API docs."""
    greeting = "Welcome to HF-Model-Runner API! Visit /docs for API documentation."
    return {"message": greeting}

@app.post("/v1/download")
async def download_model_endpoint(request: DownloadRequest):
    """
    Download the requested HuggingFace model and, on success, make it the
    active inference pipeline.

    Returns a JSON status payload describing whether the model was both
    downloaded and loaded.  Raises HTTP 500 when the download itself (or
    any unexpected error) fails.
    """
    global pipe, tokenizer, model_name

    try:
        success, message = download_model(request.model)
        if not success:
            raise HTTPException(status_code=500, detail=message)

        # Initialize into locals first: a failed load must not clobber
        # the previously working global pipeline/tokenizer.
        new_pipe, new_tokenizer, init_success = initialize_pipeline(request.model)
        if not init_success:
            return {
                "status": "success",
                "message": message,
                "loaded": False,
                "error": "模型下载成功但初始化失败"
            }

        # Commit the new model atomically with its name.
        pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model
        return {
            "status": "success",
            "message": message,
            "loaded": True,
            "current_model": model_name
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors unchanged; the broad handler
        # below would otherwise re-wrap them and mangle the detail text.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """
    OpenAI-compatible chat completions endpoint.

    Lazily swaps in the requested model when it differs from the one
    currently loaded, then delegates generation to create_chat_response.
    Raises HTTP 500 on model-load failure or generation error.
    """
    global pipe, tokenizer, model_name

    # Swap models on demand.  Load into locals first so a failed
    # initialization cannot clobber the currently working pipeline
    # while model_name still points at the old model.
    if request.model != model_name:
        new_pipe, new_tokenizer, success = initialize_pipeline(request.model)
        if not success:
            raise HTTPException(status_code=500, detail="模型初始化失败")
        pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model

    try:
        return create_chat_response(request, pipe, tokenizer)
    except HTTPException:
        # Let deliberate HTTP errors propagate with their original detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))