# api_fastapi.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import uvicorn

app = FastAPI(title="Mistral API")


class ChatRequest(BaseModel):
    prompt: str
    max_tokens: int = 500
    temperature: float = 0.7


# Global model instance
MODEL = None
TOKENIZER = None


# Note: on_event is deprecated in newer FastAPI releases in favor of lifespan handlers,
# but it still works here.
@app.on_event("startup")
async def load_model():
    global MODEL, TOKENIZER
    try:
        TOKENIZER = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        MODEL = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16,
            device_map="auto",
            load_in_8bit=True,  # 8-bit loading requires the bitsandbytes package
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")


@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": MODEL is not None}


@app.post("/chat")
async def chat_completion(request: ChatRequest):
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Format prompt (Mistral-style [INST] tags; if the loaded model expects a
        # different chat template, TOKENIZER.apply_chat_template may give better results)
        formatted_prompt = f"[INST] {request.prompt} [/INST]"

        # Tokenize
        inputs = TOKENIZER(formatted_prompt, return_tensors="pt").to(MODEL.device)

        # Generate
        with torch.no_grad():
            outputs = MODEL.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                top_p=0.95,
            )

        # Decode and strip the echoed prompt
        response = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        response = response.split("[/INST]")[-1].strip()

        return {
            "response": response,
            "tokens_generated": len(outputs[0]) - len(inputs.input_ids[0]),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/batch_chat")
async def batch_chat(requests: list[ChatRequest]):
    """Process multiple prompts at once"""
    responses = []
    for req in requests:
        result = await chat_completion(req)
        responses.append(result)
    return {"responses": responses}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
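

# --- Example client call (sketch) ---
# Assumes the server above is running locally on port 8000 and that the `requests`
# package is installed; the endpoint path and JSON fields match the /chat route
# defined in this file.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/chat",
#       json={"prompt": "Explain attention in one sentence.", "max_tokens": 128},
#   )
#   resp.raise_for_status()
#   print(resp.json()["response"])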