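A minimal FastAPI wrapper around a quantized chat model: a startup hook loads the tokenizer and weights once, `/health` reports whether the model is ready, `/chat` runs a single completion, and `/batch` loops over a list of prompts. The route paths used below are one reasonable choice rather than anything mandated by FastAPI, and the 8-bit loading path assumes `bitsandbytes`, `accelerate`, and a CUDA GPU are available.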
```python
# api_fastapi.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import uvicorn

app = FastAPI(title="Mistral API")

class ChatRequest(BaseModel):
    prompt: str
    max_tokens: int = 500
    temperature: float = 0.7

# Global model instance
MODEL = None
TOKENIZER = None

@app.on_event("startup")
async def load_model():
    """Load the tokenizer and weights once when the server boots."""
    global MODEL, TOKENIZER
    try:
        TOKENIZER = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        MODEL = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16,
            device_map="auto",
            load_in_8bit=True,  # requires bitsandbytes + a CUDA GPU
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")

@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": MODEL is not None}

@app.post("/chat")
async def chat_completion(request: ChatRequest):
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Format prompt (Mistral-style instruction template)
        formatted_prompt = f"[INST] {request.prompt} [/INST]"

        # Tokenize and move inputs onto the model's device
        inputs = TOKENIZER(formatted_prompt, return_tensors="pt").to(MODEL.device)

        # Generate
        with torch.no_grad():
            outputs = MODEL.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                top_p=0.95,
                pad_token_id=TOKENIZER.eos_token_id,
            )

        # Decode and strip the echoed prompt
        response = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        response = response.split("[/INST]")[-1].strip()

        return {
            "response": response,
            "tokens_generated": len(outputs[0]) - len(inputs.input_ids[0]),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/batch")
async def batch_chat(requests: list[ChatRequest]):
    """Process multiple prompts at once."""
    responses = []
    for req in requests:
        result = await chat_completion(req)
        responses.append(result)
    return {"responses": responses}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
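With the server running (`python api_fastapi.py`, or `uvicorn api_fastapi:app --host 0.0.0.0 --port 8000`), it can be exercised from any HTTP client. The sketch below uses the `requests` library against the `/chat` and `/batch` routes from the listing above; the base URL, prompts, and timeouts are illustrative assumptions, while the payload fields mirror the `ChatRequest` model.

```python
# client_example.py - minimal sketch; assumes the API above is running on localhost:8000
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port, matching uvicorn.run() above

# Single completion against the /chat route
resp = requests.post(
    f"{BASE_URL}/chat",
    json={"prompt": "Explain FastAPI in one sentence.", "max_tokens": 100, "temperature": 0.7},
    timeout=120,  # generation can be slow on CPU or a small GPU
)
resp.raise_for_status()
print(resp.json()["response"])

# Multiple prompts against the /batch route (body is a JSON array of ChatRequest objects)
batch = [
    {"prompt": "What is 8-bit quantization?", "max_tokens": 80},
    {"prompt": "What does top_p sampling do?", "max_tokens": 80},
]
resp = requests.post(f"{BASE_URL}/batch", json=batch, timeout=300)
resp.raise_for_status()
for item in resp.json()["responses"]:
    print(item["response"])
```

Note that `/batch` awaits each completion sequentially on the single global model, so a batch of N prompts costs roughly N times the single-request latency; true concurrent batching would require request queuing or padding several prompts into one `generate` call.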