# api_fastapi.py
# (Listing metadata from the hosting page: Spaces status "Sleeping"; file size 2,404 bytes; ref 1e639fb.)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import uvicorn
# ASGI application object; the route decorators in this file register their
# handlers on it.
# NOTE(review): the title says "Mistral" although the startup hook loads a
# TinyLlama checkpoint -- confirm which name is intended.
app = FastAPI(
    title="Mistral API",
)
class ChatRequest(BaseModel):
    """Request body accepted by the chat endpoints.

    Attributes:
        prompt: User text that the handler wraps in its instruction template.
        max_tokens: Upper bound on newly generated tokens; must be >= 1.
        temperature: Sampling temperature; must be > 0 because the handler
            always samples (``do_sample=True``).
    """

    prompt: str
    # Non-positive values are unusable for sampling generation, so reject them
    # during request validation (HTTP 422) instead of letting them surface as
    # an HTTP 500 from inside the generate call.
    max_tokens: int = Field(default=500, ge=1)
    temperature: float = Field(default=0.7, gt=0)
# Process-wide model state. Both names start as None; the startup hook rebinds
# them once loading succeeds, and the request handlers read them directly.
MODEL = TOKENIZER = None
@app.on_event("startup")
async def load_model():
    """Load the tokenizer and the causal LM into the module globals.

    Registered as the application's startup hook. Any exception raised while
    loading is printed and swallowed, so the server still starts; the
    ``/chat`` handler answers 503 for as long as ``MODEL`` is ``None``.
    """
    global MODEL, TOKENIZER
    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    # NOTE(review): 8-bit loading presumably requires a GPU and the
    # bitsandbytes package -- confirm for the deployment target.
    model_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "load_in_8bit": True,
    }
    try:
        # Tokenizer first, then the model; both are fetched by checkpoint id.
        TOKENIZER = AutoTokenizer.from_pretrained(checkpoint)
        MODEL = AutoModelForCausalLM.from_pretrained(checkpoint, **model_options)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
@app.get("/health")
async def health():
    """Report that the service is up and whether the model has been loaded."""
    model_ready = MODEL is not None
    return {"status": "healthy", "model_loaded": model_ready}
@app.post("/chat")
async def chat_completion(request: ChatRequest):
    """Generate one sampled completion for ``request.prompt``.

    Args:
        request: Prompt text plus the sampling settings (``max_tokens``,
            ``temperature``).

    Returns:
        A dict with ``"response"`` (the generated text, stripped) and
        ``"tokens_generated"`` (number of tokens produced after the prompt).

    Raises:
        HTTPException: 503 when the model or tokenizer is not loaded; 500
            when tokenization, generation or decoding fails.
    """
    if MODEL is None or TOKENIZER is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # NOTE(review): "[INST] ... [/INST]" is presumably the Mistral/Llama-2
        # instruction format -- confirm it matches the checkpoint being served.
        formatted_prompt = f"[INST] {request.prompt} [/INST]"

        inputs = TOKENIZER(formatted_prompt, return_tensors="pt").to(MODEL.device)
        prompt_length = inputs.input_ids.shape[-1]

        # NOTE(review): generate() is called synchronously inside this
        # coroutine, so other requests presumably wait until it returns.
        with torch.no_grad():
            outputs = MODEL.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                top_p=0.95,
            )

        # Decode only the tokens that follow the prompt. Splitting the full
        # decoded text on "[/INST]" would cut the reply short whenever the
        # marker also appears in the generated text, and would leak the prompt
        # if decoding did not reproduce the marker verbatim.
        new_tokens = outputs[0][prompt_length:]
        response = TOKENIZER.decode(new_tokens, skip_special_tokens=True).strip()

        return {
            "response": response,
            "tokens_generated": len(new_tokens),
        }
    except Exception as e:
        # Keep the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/batch_chat")
async def batch_chat(requests: list[ChatRequest]):
    """Run each prompt through ``chat_completion``, one after another.

    Results are returned in the same order as ``requests``. An exception
    raised for any item propagates and aborts the whole batch.
    """
    return {"responses": [await chat_completion(item) for item in requests]}
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)