# airsmodel/utils/model.py
import os
import sys
from pathlib import Path
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from fastapi import HTTPException
from pydantic import BaseModel
class DownloadRequest(BaseModel):
    model: str
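# Illustrative sketch only: a FastAPI route defined elsewhere in the project
# (the `app` object and the route path below are assumptions, not part of this
# file) could accept DownloadRequest and call download_model, e.g.:
#
#   @app.post("/download")
#   def download(req: DownloadRequest):
#       ok, msg = download_model(req.model)
#       if not ok:
#           raise HTTPException(status_code=500, detail=msg)
#       return {"message": msg}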
def check_model(model_name):
    """
    Check whether a model already exists in the local cache.
    Args: model_name - the model name passed in from the request
    Returns: (model_name, cache_dir, success)
    """
    cache_dir = "./my_model_cache"
    # Check whether the model is already present in the cache
    model_path = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
    snapshot_path = model_path / "snapshots"
    if snapshot_path.exists() and any(snapshot_path.iterdir()):
        print(f"✓ Model {model_name} already exists in the cache")
        try:
            # Load the tokenizer as a sanity check that the cached files are usable
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            return model_name, cache_dir, True
        except Exception as e:
            print(f"⚠ Failed to load the existing model: {e}")
            return model_name, cache_dir, False
    else:
        raise HTTPException(status_code=404, detail=f"Model `{model_name}` does not exist; please download it first")
def download_model(model_name):
    """
    Download the specified model.
    Args: model_name - the name of the model to download
    Returns: (success, message)
    """
    cache_dir = "./my_model_cache"
    print(f"Starting download of model: {model_name}")
    print(f"Cache directory: {cache_dir}")
    # Log in to Hugging Face (optional; needed for gated or private models)
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            print("Logging in to Hugging Face...")
            login(token=token)
            print("✓ Hugging Face login succeeded!")
        except Exception as e:
            print(f"⚠ Login failed: {e}")
            print("Continuing with public models only")
    else:
        print("ℹ HUGGINGFACE_TOKEN is not set - only public models are available")
    try:
        # Download the tokenizer
        print("Downloading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Tokenizer downloaded successfully!")
        # Download the model weights
        print("Downloading model...")
        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Model downloaded successfully!")
        print(f"✓ Model and tokenizer were downloaded to {cache_dir}")
        return True, f"Model {model_name} downloaded successfully"
    except Exception as e:
        print(f"✗ Error while downloading the model: {e}")
        return False, f"Download failed: {str(e)}"
def initialize_pipeline(model_name):
    """
    Initialize a text-generation pipeline with the given model.
    Args: model_name - the model name passed in from the request
    Returns: (pipe, tokenizer, success)
    """
    model_name, cache_dir, success = check_model(model_name)
    if not success:
        return None, None, False
    try:
        # Load the tokenizer and model from the local cache directory
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
        print(f"Initializing pipeline with {model_name}...")
        # cache_dir is not passed to pipeline(); it is only given to AutoTokenizer
        # and AutoModelForCausalLM above, so hand the already-loaded objects to the
        # pipeline to keep everything inside ./my_model_cache
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        print("✓ Pipeline initialized successfully!")
        return pipe, tokenizer, True
    except Exception as e:
        print(f"✗ Pipeline initialization failed: {e}")
        return None, None, False
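
# Minimal command-line smoke test, a sketch rather than part of the service API.
# The default model name "gpt2" is an assumption (small and public); pass a
# different model id as the first argument to try another one.
if __name__ == "__main__":
    demo_model = sys.argv[1] if len(sys.argv) > 1 else "gpt2"
    ok, message = download_model(demo_model)
    print(message)
    if ok:
        pipe, tokenizer, success = initialize_pipeline(demo_model)
        if success:
            # Generate a short completion to confirm the pipeline is usable
            result = pipe("Hello, world", max_new_tokens=20)
            print(result[0]["generated_text"])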