import logging
import threading

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from peft import PeftModel
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

logger = logging.getLogger(__name__)

BASE_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER_ID = "Amrender/Medical_Chatbot"

# Initialize FastAPI
app = FastAPI(title="Medical Chatbot API")

# Global variables for the model and tokenizer. Both stay None until
# load_model() has finished, which is what the endpoints check for readiness.
model = None
tokenizer = None

# Generation is moved off the event loop (see generate_response); this lock
# keeps generate() calls running one at a time, as they did before the move.
_generation_lock = threading.Lock()


class QueryRequest(BaseModel):
    """Body of a POST /generate request."""

    # Text that is tokenized as-is and passed to the model.
    prompt: str
    # Forwarded to generate() as max_new_tokens; must be at least 1 so an
    # invalid value is rejected at validation time instead of failing later.
    max_tokens: int = Field(default=150, ge=1)


@app.on_event("startup")
def load_model():
    """Load the 4-bit base model, its tokenizer and the adapter at startup.

    The module-level ``model`` and ``tokenizer`` are assigned together, and
    only after every step succeeded, so they are never left pointing at a
    base model without its adapter.
    """
    global model, tokenizer
    print("Loading model onto GPU...")

    # 1. 4-bit config to fit the GPU
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # 2. Load Base Model
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

    # 3. Attach Medical Adapters
    adapted_model = PeftModel.from_pretrained(base_model, ADAPTER_ID)

    # Publish both globals in one step, once everything is ready.
    model, tokenizer = adapted_model, loaded_tokenizer
    print("Model loaded successfully!")


def _generate_text(prompt: str, max_new_tokens: int) -> str:
    """Run one blocking generation and return only the newly generated text."""
    with _generation_lock:
        # Follow the model's own device rather than assuming "cuda":
        # device_map="auto" decides the placement at load time.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # The returned sequence starts with the prompt tokens; slice them off by
    # length. String replacement would also delete any repetition of the
    # prompt text inside the answer, and would remove nothing when the decoded
    # prompt differs from the original string.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()


@app.post("/generate")
async def generate_response(request: QueryRequest):
    """Generate a completion for ``request.prompt``.

    Returns:
        ``{"response": <generated text without the prompt>}``.

    Raises:
        HTTPException: 503 while the model or tokenizer is not loaded yet,
            500 (with the error message as detail) if generation fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading.")

    try:
        # generate() blocks for the whole generation; running it in the
        # threadpool keeps the event loop free for other requests.
        final_answer = await run_in_threadpool(
            _generate_text, request.prompt, request.max_tokens
        )
    except Exception as exc:
        logger.exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"response": final_answer}


@app.get("/health")
async def health_check():
    """Report that the service is up and whether the model has been loaded."""
    return {"status": "active", "model_loaded": model is not None}