File size: 3,537 Bytes
702fae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import sys
from pathlib import Path
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from fastapi import HTTPException
from pydantic import BaseModel


class DownloadRequest(BaseModel):
    """Request body for the model-download endpoint."""

    # Hugging Face model identifier, e.g. "org/repo".
    model: str


def check_model(model_name, cache_dir="./my_model_cache"):
    """
    Check whether a model is already present in the local cache.

    Args:
        model_name: Model name passed in from the request (e.g. "org/repo").
        cache_dir: Local Hugging Face cache directory to look in.

    Returns:
        (model_name, cache_dir, success) — success is True when the cached
        tokenizer loads cleanly, False when loading the cached files fails.

    Raises:
        HTTPException: 404 when the model is not present in the cache.
    """
    # HF hub cache layout: models--{org}--{repo}/snapshots/<revision>/
    model_path = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
    snapshot_path = model_path / "snapshots"

    # A non-empty snapshots directory means at least one revision was downloaded.
    if snapshot_path.exists() and any(snapshot_path.iterdir()):
        print(f"✓ 模型 {model_name} 已存在于缓存中")
        try:
            # Loading the tokenizer validates that the cached files are usable;
            # the instance itself is intentionally discarded.
            AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            return model_name, cache_dir, True
        except Exception as e:
            print(f"⚠ 加载现有模型失败: {e}")
            return model_name, cache_dir, False

    raise HTTPException(status_code=404, detail=f"模型 `{model_name}` 不存在,请先下载")


def download_model(model_name, cache_dir="./my_model_cache"):
    """
    Download a model and its tokenizer into the local cache.

    Args:
        model_name: Name of the model to download (e.g. "org/repo").
        cache_dir: Local Hugging Face cache directory to download into.

    Returns:
        (success, message) — success is a bool, message a human-readable
        status string.
    """
    print(f"开始下载模型: {model_name}")
    print(f"缓存目录: {cache_dir}")

    # Optional Hugging Face login (only needed for gated/private models).
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            print("登录 Hugging Face...")
            login(token=token)
            print("✓ HuggingFace 登录成功!")
        except Exception as e:
            # Login failure is non-fatal: public models remain accessible.
            print(f"⚠ 登录失败: {e}")
            print("继续使用公开模型")
    else:
        print("ℹ 未设置 HUGGINGFACE_TOKEN - 仅使用公开模型")

    try:
        # The from_pretrained calls are invoked purely for their download
        # side effect; the returned objects are not kept.
        print("正在下载 tokenizer...")
        AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Tokenizer 下载成功!")

        print("正在下载模型...")
        AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ 模型下载成功!")

        print(f"✓ 模型和 tokenizer 已成功下载到 {cache_dir}")
        return True, f"模型 {model_name} 下载成功"

    except Exception as e:
        print(f"✗ 下载模型时出错: {e}")
        return False, f"下载失败: {str(e)}"


def initialize_pipeline(model_name):
    """
    Initialize a text-generation pipeline from a locally cached model.

    Args:
        model_name: Model name passed in from the request.

    Returns:
        (pipe, tokenizer, success) — (None, None, False) when the cached
        model cannot be loaded or pipeline construction fails.

    Raises:
        HTTPException: propagated from check_model when the model is not
        cached (404).
    """
    model_name, cache_dir, success = check_model(model_name)

    if not success:
        return None, None, False

    try:
        # Load both tokenizer and model explicitly from the local cache.
        # pipeline() does not accept cache_dir, so passing the model *name*
        # would make it resolve against the default HF cache and potentially
        # re-download; passing the loaded objects avoids that.
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

        print(f"使用 {model_name} 初始化 pipeline...")
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        print("✓ Pipeline 初始化成功!")

        return pipe, tokenizer, True

    except Exception as e:
        print(f"✗ Pipeline 初始化失败: {e}")
        return None, None, False