Spaces:
Runtime error
Runtime error
import torch
from fastapi import FastAPI, HTTPException
from peft import PeftModel
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# FastAPI application instance for the medical chatbot service.
app = FastAPI(title="Medical Chatbot API")

# Process-wide handles assigned by load_model(); a value of None means the
# corresponding object has not been loaded yet.
model, tokenizer = None, None
# Request body accepted by the generation handler.
class QueryRequest(BaseModel):
    """Incoming generation request.

    Attributes:
        prompt: Text passed to the tokenizer as-is.
        max_tokens: Maximum number of new tokens to generate (forwarded to
            ``generate()`` as ``max_new_tokens``). Constrained to >= 1 so a
            zero/negative value is rejected at request validation (HTTP 422)
            instead of being passed on to ``generate()`` and surfacing as a
            generic 500 from the handler's catch-all.
    """

    prompt: str
    # NOTE(review): consider an upper bound as well (le=...) to cap the GPU
    # time a single request can consume.
    max_tokens: int = Field(default=150, ge=1)
def load_model(
    base_model_id: str = "mistralai/Mistral-7B-Instruct-v0.2",
    adapter_id: str = "Amrender/Medical_Chatbot",
):
    """Load the quantized base model and tokenizer, attach the PEFT adapter.

    The base model is loaded with bitsandbytes 4-bit NF4 quantization
    (double quantization, float16 compute dtype). The tokenizer comes from the
    same repo and the PEFT adapter is attached on top. The module-level
    ``model`` and ``tokenizer`` globals are assigned together only after every
    step has succeeded. If a step fails part-way (e.g. the adapter download),
    the globals keep their previous values instead of exposing a base model
    without the adapter, or a model without a tokenizer.

    Args:
        base_model_id: Hugging Face Hub id of the base causal LM.
        adapter_id: Hugging Face Hub id of the PEFT adapter to attach.
    """
    global model, tokenizer
    print("Loading model onto GPU...")

    # 1. 4-bit config so the model fits in GPU memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # 2. Load the base model; device_map="auto" delegates device placement
    #    to transformers/accelerate.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    # 3. Attach the medical adapter.
    adapted_model = PeftModel.from_pretrained(base_model, adapter_id)

    # Publish both handles only now, so readers never see a half-loaded state.
    model, tokenizer = adapted_model, loaded_tokenizer
    print("Model loaded successfully!")
async def generate_response(request: QueryRequest):
    """Generate a completion for ``request.prompt`` with the loaded model.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible for
    this handler here; confirm it is registered with ``app`` elsewhere.

    Args:
        request: Validated body carrying the prompt and ``max_tokens`` (the
            maximum number of new tokens to generate).

    Returns:
        ``{"response": text}``, where ``text`` is only the newly generated
        continuation, decoded with special tokens skipped, then stripped.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded yet;
            500 (with the error message as detail) if tokenization,
            generation or decoding fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading.")
    try:
        # The prompt is tokenized as raw text. NOTE(review): no chat template
        # ([INST] ... [/INST]) is applied; confirm this matches how the adapter
        # was trained.
        # Move inputs to the model's own device rather than hard-coding
        # "cuda", since placement is decided by device_map="auto".
        inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)

        # NOTE(review): generate() is synchronous, so this blocks the event
        # loop (including health checks) for the whole generation. Consider a
        # threadpool offload once concurrent GPU use has been thought through.
        outputs = model.generate(**inputs, max_new_tokens=request.max_tokens)

        # For a causal LM the returned sequence is prompt tokens + new tokens.
        # Drop the prompt by token count rather than by string replacement.
        # String replacement also deleted repeats of the prompt inside the
        # answer, and missed prompts that don't round-trip through decode.
        prompt_length = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_length:]
        final_answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return {"response": final_answer}
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        # NOTE(review): str(e) exposes internal error text to clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def health_check():
    """Report liveness and whether the service can answer generation requests.

    ``model_loaded`` is True only when both ``model`` and ``tokenizer`` are
    set, matching the readiness check in ``generate_response``. A client
    therefore never sees "loaded" while generation would still return 503.

    NOTE(review): no route decorator is visible for this handler here;
    confirm it is registered with ``app`` elsewhere.
    """
    ready = model is not None and tokenizer is not None
    return {"status": "active", "model_loaded": ready}