# Mistral_Test / api_fastapi.py
# eesfeg's picture
# Add application file
# 1e639fb
# api_fastapi.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import uvicorn
# FastAPI application object; the startup hook and the route decorators in
# this file register their handlers on it.
# NOTE(review): the title says "Mistral", but the startup hook in this file
# loads a TinyLlama checkpoint — confirm which name is intended.
app = FastAPI(title="Mistral API")
class ChatRequest(BaseModel):
    """Request body accepted by the chat endpoints.

    Only ``prompt`` is required; the remaining fields fall back to the
    defaults declared below.
    """

    # Raw user text; the chat handler wraps it in the prompt format itself.
    prompt: str
    # Forwarded to ``generate`` as ``max_new_tokens``.
    max_tokens: int = 500
    # Sampling temperature forwarded to ``generate``.
    temperature: float = 0.7
# Process-wide model and tokenizer handles. Both start as None and are
# assigned by the startup hook; the /health and /chat handlers treat
# ``MODEL is None`` as "model not loaded".
MODEL = None
TOKENIZER = None
@app.on_event("startup")
async def load_model():
    """Load the tokenizer and model into the module-level globals at startup.

    With CUDA available the model is loaded in float16 with 8-bit
    bitsandbytes quantization and automatic device placement. Without CUDA
    the model is loaded with library defaults on the CPU, because 8-bit
    bitsandbytes loading needs a GPU and would otherwise fail on every
    start-up of a CPU-only host.

    Loading is best-effort: on failure the error is printed, the globals
    stay ``None``, and the API keeps answering 503 on ``/chat`` while
    ``/health`` reports ``model_loaded: false``.
    """
    global MODEL, TOKENIZER
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if torch.cuda.is_available():
            # Imported here so that CPU-only deployments never touch the
            # quantization machinery.
            from transformers import BitsAndBytesConfig

            # The bare ``load_in_8bit=True`` keyword is deprecated;
            # quantization options go through ``quantization_config``.
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(model_id)
    except Exception as e:
        print(f"Error loading model: {e}")
        return
    # Publish both handles together so request handlers never observe a
    # tokenizer without its model.
    TOKENIZER, MODEL = tokenizer, model
    print("Model loaded successfully!")
@app.get("/health")
async def health():
    """Report service liveness and whether the model global has been set."""
    is_loaded = MODEL is not None
    payload = {"status": "healthy"}
    payload["model_loaded"] = is_loaded
    return payload
@app.post("/chat")
async def chat_completion(request: ChatRequest):
    """Generate a single completion for ``request.prompt``.

    Args:
        request: Prompt text plus ``max_tokens`` (passed to ``generate`` as
            ``max_new_tokens``) and ``temperature``. A temperature of
            exactly 0 selects greedy decoding; any other value is used for
            sampling with ``top_p=0.95``.

    Returns:
        A dict with ``response`` (the decoded, stripped text of the newly
        generated tokens) and ``tokens_generated`` (how many tokens were
        produced beyond the prompt).

    Raises:
        HTTPException: 503 while the model or tokenizer is not loaded;
            500 carrying the error text when tokenization, generation or
            decoding fails.
    """
    if MODEL is None or TOKENIZER is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        if getattr(TOKENIZER, "chat_template", None):
            # Use the checkpoint's own chat format instead of hard-coded
            # Mistral-style tags, which only match Mistral-family models.
            prompt_text = TOKENIZER.apply_chat_template(
                [{"role": "user", "content": request.prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            # The rendered template already carries the special tokens it
            # needs, so the tokenizer must not add its own on top.
            inputs = TOKENIZER(
                prompt_text, return_tensors="pt", add_special_tokens=False
            )
        else:
            # No template shipped with the tokenizer: keep the legacy format.
            prompt_text = f"[INST] {request.prompt} [/INST]"
            inputs = TOKENIZER(prompt_text, return_tensors="pt")
        inputs = inputs.to(MODEL.device)
        prompt_len = inputs["input_ids"].shape[1]

        generation_kwargs = {"max_new_tokens": request.max_tokens}
        if request.temperature == 0:
            # Sampling rejects a zero temperature; greedy decoding is the
            # deterministic equivalent.
            generation_kwargs["do_sample"] = False
        else:
            generation_kwargs.update(
                do_sample=True,
                temperature=request.temperature,
                top_p=0.95,
            )

        with torch.no_grad():
            outputs = MODEL.generate(**inputs, **generation_kwargs)

        # Decode only what was generated after the prompt. Splitting the
        # full text on a marker string breaks as soon as the model emits
        # that marker itself.
        new_tokens = outputs[0][prompt_len:]
        response = TOKENIZER.decode(new_tokens, skip_special_tokens=True).strip()
        return {
            "response": response,
            "tokens_generated": len(new_tokens),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/batch_chat")
async def batch_chat(requests: list[ChatRequest]):
    """Run every prompt in ``requests`` through the single-prompt handler.

    Prompts are processed one after another, in order. An exception raised
    while handling any prompt propagates and aborts the whole batch.
    """
    return {"responses": [await chat_completion(item) for item in requests]}
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn. Host 0.0.0.0 binds all
    # network interfaces, so the API is reachable from other machines.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)