| from fastapi import FastAPI, Depends, HTTPException, status |
| from fastapi.security import APIKeyHeader |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from huggingface_hub import login |
| import torch |
| import os |
| from dotenv import load_dotenv |
|
|
| |
# Redirect every Hugging Face cache location (and HOME, which several
# libraries fall back to for dotfiles) into /tmp so the app can run in a
# container whose working directory is read-only (e.g. a HF Space).
_CACHE_DIR = "/tmp/.cache/huggingface"
for _cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE"):
    os.environ[_cache_var] = _CACHE_DIR
os.environ["HOME"] = "/tmp"

# Ensure the cache directory exists and is writable by any uid; chmod is
# best-effort — if the dir is owned by another user we proceed anyway.
os.makedirs(_CACHE_DIR, exist_ok=True)
try:
    os.chmod(_CACHE_DIR, 0o777)
except PermissionError:
    pass
|
|
# Load variables from a local .env file, if present.  By default this does
# NOT override values already in the environment, so the cache settings
# configured above are preserved.
load_dotenv()


# FastAPI application instance; served by uvicorn in the __main__ block.
app = FastAPI()
|
|
| |
# --- Configuration from the environment ----------------------------------
# HF_TOKEN: Hugging Face access token used to download (possibly gated) weights.
# API_KEYS: comma-separated list of keys clients may present to this API.
HF_TOKEN = os.getenv("HF_TOKEN")

# Split on commas and drop empty / whitespace-only entries.  The original
# `os.getenv("API_KEYS", "").split(",")` returned [""] when the variable
# was unset — a truthy list, so the guard below never fired, and the empty
# string silently became a valid API key.
API_KEYS = [key.strip() for key in os.getenv("API_KEYS", "").split(",") if key.strip()]

if not HF_TOKEN:
    raise ValueError("请设置环境变量 HF_TOKEN")
if not API_KEYS:
    raise ValueError("请设置环境变量 API_KEYS")

# Authenticate against the Hugging Face Hub so from_pretrained() below can
# fetch private or gated repositories.
login(token=HF_TOKEN, add_to_git_credential=False)
|
|
| |
# Model to serve; inference runs on CPU only (no GPU assumed here).
# NOTE(review): the published repos in this family are
# "deepseek-ai/deepseek-llm-7b-base" / "deepseek-llm-7b-chat" — confirm
# that the bare "deepseek-llm-7b" id resolves, otherwise
# from_pretrained() will fail at startup with a repo-not-found error.
model_name = "deepseek-ai/deepseek-llm-7b"
device = "cpu"


# Download (or load from the /tmp cache) the tokenizer and weights once at
# startup; both objects are shared by every request handler below.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
|
|
| |
# auto_error=False so a missing header reaches our handler (and gets a 401)
# instead of FastAPI's generic 403.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)


async def get_api_key(api_key: str = Depends(api_key_header)) -> str:
    """FastAPI dependency validating the ``Authorization`` header.

    Accepts either the bare key or the OpenAI-style ``Bearer <key>`` form.
    The original compared the raw header against API_KEYS, so standard
    chat-completions clients — which always send ``Bearer <key>`` — were
    rejected with 401 even when their key was valid.

    Returns:
        The validated key (scheme prefix stripped).

    Raises:
        HTTPException: 401 when the header is absent or the key is unknown.
    """
    if api_key:
        token = api_key.strip()
        if token.startswith("Bearer "):
            token = token[len("Bearer "):].strip()
        if token in API_KEYS:
            return token
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API Key",
    )
|
|
@app.post("/v1/chat/completions")
async def chat_completions(request: dict, api_key: str = Depends(get_api_key)):
    """Minimal OpenAI-compatible chat completion endpoint.

    Expects ``{"messages": [{"role": ..., "content": ...}, ...]}`` and
    returns the model's continuation of the LAST message's content.

    Raises:
        HTTPException: 400 when the payload has no usable message
            (previously a bare KeyError surfaced as a 500).
    """
    try:
        # Use the most recent message: the original read messages[0],
        # which ignores every turn after the first in a conversation.
        prompt = request["messages"][-1]["content"]
    except (KeyError, IndexError, TypeError):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request must contain a non-empty 'messages' list",
        )

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # do_sample=True is required for temperature to take effect; without it
    # generate() uses greedy decoding and silently ignores temperature.
    outputs = model.generate(
        **inputs, max_new_tokens=100, temperature=0.7, do_sample=True
    )
    # Decode only the newly generated tokens — outputs[0] also contains the
    # prompt, and an OpenAI-style response must not echo it back.
    completion = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return {"choices": [{"message": {"role": "assistant", "content": completion}}]}
|
|
# Run the development server directly: `python <this file>`.
# Port 7860 is the conventional Hugging Face Spaces port.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)