# openai / app.py
# tianruci's picture
# Update app.py
# 27b57e7 verified
# The standard library is imported first because the cache environment below
# must be in place BEFORE huggingface_hub / transformers are imported.
import os

# Point every Hugging Face cache location at a writable path to avoid
# permission problems. huggingface_hub and transformers resolve these
# variables into module-level constants when they are first imported, so
# assigning them after those imports would have no effect.
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HOME"] = "/tmp"  # redirect the home directory to a writable path

# Create the cache directory up front and try to make it writable.
cache_dir = os.environ["HF_HOME"]
os.makedirs(cache_dir, exist_ok=True)
try:
    os.chmod(cache_dir, 0o777)  # best effort; only succeeds if the container allows it
except PermissionError:
    pass  # not permitted: carry on, the libraries will use the directory as it is

# Third-party imports deliberately follow the environment setup above.
import torch  # noqa: E402
from dotenv import load_dotenv  # noqa: E402
from fastapi import Depends, FastAPI, HTTPException, status  # noqa: E402
from fastapi.security import APIKeyHeader  # noqa: E402
from huggingface_hub import login  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402
load_dotenv()  # load settings from a .env file

app = FastAPI()

# Read configuration from environment variables.
HF_TOKEN = os.getenv("HF_TOKEN")
# API_KEYS is a comma-separated list. "".split(",") yields [""], which is
# truthy and would defeat the emptiness check below, so trim every entry and
# drop the blank ones.
API_KEYS = [key.strip() for key in os.getenv("API_KEYS", "").split(",") if key.strip()]

if not HF_TOKEN:
    raise ValueError("请设置环境变量 HF_TOKEN")
if not API_KEYS:
    raise ValueError("请设置环境变量 API_KEYS")

# Log in to Hugging Face; skip the git credential helper to reduce
# unnecessary writes.
login(token=HF_TOKEN, add_to_git_credential=False)
# Use DeepSeek 7B (open model).
# NOTE(review): confirm this repo id resolves on the Hub; the published
# DeepSeek LLM 7B checkpoints appear to carry a "-base" / "-chat" suffix.
model_name = "deepseek-ai/deepseek-llm-7b"
device = "cpu"  # or "auto" to detect CPU/GPU automatically

# Load the tokenizer and the model. cache_dir is passed explicitly so the
# download location does not depend on cache environment variables having
# been set before the Hugging Face libraries were imported.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    cache_dir=cache_dir,
)
# Authentication dependency.
# auto_error=False makes a missing header arrive in get_api_key as None, so
# the 401 response below is produced by this module.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)


async def get_api_key(api_key: str = Depends(api_key_header)) -> str:
    """Validate the Authorization header against the configured API keys.

    The header may carry the key on its own or as ``Bearer <key>``.

    Args:
        api_key: Raw value of the ``Authorization`` header, or None when the
            header is absent.

    Returns:
        The credential that matched an entry of ``API_KEYS``.

    Raises:
        HTTPException: 401 when the header is missing or matches no key.
    """
    import hmac  # local import: only this function needs it

    candidates = []
    if api_key:
        candidates.append(api_key)
        scheme, _, token = api_key.partition(" ")
        token = token.strip()
        if scheme.lower() == "bearer" and token:
            candidates.append(token)

    # Candidates are never empty strings, so a blank entry in API_KEYS can
    # not be matched. compare_digest keeps the comparison constant-time.
    for candidate in candidates:
        for known in API_KEYS:
            if hmac.compare_digest(candidate.encode("utf-8"), known.encode("utf-8")):
                return candidate

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API Key",
    )
@app.post("/v1/chat/completions")
async def chat_completions(request: dict, api_key: str = Depends(get_api_key)):
    """Generate a completion for the first message of the request.

    Args:
        request: JSON body; ``request["messages"][0]["content"]`` is used as
            the prompt.
        api_key: Credential validated by ``get_api_key`` (unused here; the
            dependency only enforces authentication).

    Returns:
        ``{"choices": [{"message": {"role": "assistant", "content": <text>}}]}``
        where ``<text>`` is the newly generated text only.

    Raises:
        HTTPException: 400 when the body has no usable first message.
    """
    try:
        prompt = request["messages"][0]["content"]
    except (KeyError, IndexError, TypeError):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must contain messages[0].content",
        ) from None
    if not isinstance(prompt, str) or not prompt:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="messages[0].content must be a non-empty string",
        )

    # Follow the model's actual device instead of the configured string, which
    # may be "auto" and is not a valid target for .to().
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # NOTE(review): generate() is a synchronous call inside an async handler,
    # so other requests wait while it runs.
    # temperature only takes effect when sampling is enabled.
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
    )

    # The generated sequence starts with the prompt tokens; decode only what
    # follows them so the reply does not echo the prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return {"choices": [{"message": {"role": "assistant", "content": response}}]}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)