# gratest / app.py
# Copied from the Hugging Face Space file view: last commit 91ff0d2
# ("Update app.py", verified) by lainlives.
import asyncio
import os
import secrets
import threading
import time
import uuid
from typing import List, Optional

import torch
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Runtime configuration, read from the environment.
# HF_TOKEN: Hugging Face Hub token used for the model/tokenizer downloads below
#   (None lets the Hub client fall back to its own default credential lookup).
# API_KEY: shared secret clients must send as "Authorization: Bearer <key>".
HF_TOKEN = os.getenv('HF_TOKEN')
API_KEY = os.getenv('API_KEY')

app = FastAPI()
# Extracts the bearer token from the Authorization header for validate_api_key.
security = HTTPBearer()

# Load model once, at import time; this blocks until download/load completes.
# Weights are quantized to 4-bit NF4 (with double quantization) and compute
# runs in float16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
model_name = "unsloth/Qwen3.5-35B-A3B"
# Pass HF_TOKEN explicitly so the credential actually used for the download
# is the one this module reads (it was previously read but never used).
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    # 'auto' lets accelerate place layers across the available devices.
    # NOTE(review): if layers spill to CPU, 4-bit bitsandbytes loading may
    # require llm_int8_enable_fp32_cpu_offload=True in quant_config — confirm
    # against the target hardware.
    device_map="auto",
    token=HF_TOKEN,
)
def validate_api_key(auth: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Reject requests whose bearer token does not match ``API_KEY``.

    Used as a FastAPI dependency; ``security`` (an ``HTTPBearer``) has already
    parsed the ``Authorization: Bearer <token>`` header into ``auth``.

    Args:
        auth: Bearer credentials parsed from the request.

    Returns:
        The accepted bearer token.

    Raises:
        HTTPException: 401 if ``API_KEY`` is unset/empty (fail closed) or the
            supplied token does not match it.
    """
    expected = API_KEY
    # compare_digest's running time does not depend on where the inputs first
    # differ, so response timing cannot be used to guess the key byte by byte.
    # Compare as bytes: compare_digest raises TypeError for non-ASCII str.
    if not expected or not secrets.compare_digest(
        auth.credentials.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API Key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return auth.credentials
class Message(BaseModel):
    """A single chat turn in an OpenAI-style ``messages`` array."""

    # Who produced the turn, e.g. "user" or "assistant"; any string is accepted.
    role: str
    # The turn's text.
    content: str
class ChatRequest(BaseModel):
    """Body of ``POST /v1/chat/completions`` (a subset of OpenAI's schema).

    Attributes:
        model: Name echoed back in the response; the weights actually used are
            fixed when the module is loaded.
        messages: Conversation turns, in order.
        temperature: Sampling temperature; must be >= 0.
    """

    model: str = "qwen-3.5"
    messages: List[Message]
    # Negative temperatures are invalid; reject them at validation time (422)
    # rather than letting them fail inside generation (500).
    temperature: float = Field(0.7, ge=0.0)
# Upper bound on tokens generated per request.
_MAX_NEW_TOKENS = 200
# model.generate is run in a worker thread (see chat_endpoint); this lock keeps
# generation one-request-at-a-time, as it was when generate ran on the event
# loop, so concurrent requests don't stack up GPU memory use.
_GENERATION_LOCK = threading.Lock()


def _generate_reply(messages, temperature, max_new_tokens):
    """Run blocking generation for a conversation.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts, in order.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        ``(text, prompt_token_count, completion_token_count)`` where ``text``
        contains only the newly generated tokens (the prompt is not echoed).
    """
    if getattr(tokenizer, "chat_template", None):
        # Render every turn with its role using the model's own chat format and
        # append the assistant-turn prefix so the model answers next.
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The template already contains any special tokens; don't add them twice.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    else:
        # No chat template available: fall back to the last message's raw text.
        inputs = tokenizer(messages[-1]["content"], return_tensors="pt")
    inputs = inputs.to(model.device)

    gen_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        # Set do_sample explicitly: temperature only affects output when
        # sampling, and we shouldn't depend on the model's generation_config.
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # temperature == 0 means deterministic output; sampling with
        # temperature 0 is rejected by generate.
        gen_kwargs["do_sample"] = False

    with _GENERATION_LOCK:
        outputs = model.generate(**inputs, **gen_kwargs)

    # generate returns prompt + continuation; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return text, prompt_len, len(new_tokens)


@app.post("/v1/chat/completions")
async def chat_endpoint(request: ChatRequest, token: str = Depends(validate_api_key)):
    """Answer an OpenAI-style chat completion request.

    Args:
        request: Parsed request body (conversation and temperature).
        token: The accepted API key, injected by ``validate_api_key``; not used
            here, but declaring it is what enforces authentication.

    Returns:
        A dict shaped like an OpenAI ``chat.completion`` response, including
        token ``usage``.

    Raises:
        HTTPException: 400 if ``messages`` is empty.
    """
    if not request.messages:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="'messages' must contain at least one message",
        )

    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    # model.generate blocks for the whole generation; run it off the event
    # loop so other requests (and health checks) are still served meanwhile.
    response_text, prompt_tokens, completion_tokens = await asyncio.to_thread(
        _generate_reply, messages, request.temperature, _MAX_NEW_TOKENS
    )

    # Approximation: exhausting the token budget is reported as truncation.
    finish_reason = "length" if completion_tokens >= _MAX_NEW_TOKENS else "stop"

    # Return OpenAI-compatible JSON structure
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": response_text},
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
if __name__ == "__main__":
    import uvicorn

    # Defaults to 0.0.0.0:7860; the HOST and PORT environment variables
    # override the bind address for other deployments or local runs.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
    )