# Uploaded by Amrender -- commit 4403fc1 ("Create app.py"), verified.
import logging
import threading

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from peft import PeftModel
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# The FastAPI application; every route in this module registers on it.
app = FastAPI(title="Medical Chatbot API")

# Shared inference objects. Both start as None and are assigned by the
# startup handler once loading has run.
model = tokenizer = None
# Define the structure of the incoming request
class QueryRequest(BaseModel):
prompt: str
max_tokens: int = 150
@app.on_event("startup")
def load_model():
    """Load the quantized base model, its tokenizer and the medical adapter.

    Runs once at application startup. The base model is loaded in 4-bit,
    the LoRA/PEFT adapter is attached on top, and only then are the
    finished model and tokenizer published to the module-level ``model``
    and ``tokenizer`` globals. Request handlers therefore never observe
    the bare base model without its adapter, even if a later step fails.
    """
    global model, tokenizer
    print("Loading model onto GPU...")

    base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    adapter_id = "Amrender/Medical_Chatbot"

    # 1. 4-bit NF4 quantization (plus double quantization) so the 7B model
    #    fits in GPU memory; matmuls are computed in float16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # 2. Base model and its tokenizer. device_map="auto" delegates device
    #    placement to transformers/accelerate.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    # 3. Attach the medical adapter weights on top of the base model.
    peft_model = PeftModel.from_pretrained(base_model, adapter_id)

    # Publish only the fully assembled pipeline, both globals together.
    model, tokenizer = peft_model, loaded_tokenizer
    print("Model loaded successfully!")
# Serializes generation on the single shared model. The handler below moves
# the blocking generate() call onto a worker thread; without this lock,
# several requests could run on the GPU at the same time. A handler that
# blocks the event loop would serialize them implicitly, and this lock
# keeps that one-at-a-time behavior explicit.
_generate_lock = threading.Lock()


def _generate_text(prompt, max_new_tokens):
    """Run blocking generation for *prompt* and return only the new text.

    Meant to be called from a worker thread. Holds ``_generate_lock`` for
    the duration of ``model.generate``.
    """
    # Send the inputs to the model's own device instead of assuming "cuda",
    # so this follows whatever placement device_map="auto" chose.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    with _generate_lock:
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # For a causal LM the returned sequence begins with the prompt tokens.
    # Slicing them off is exact. Calling str.replace on the decoded text is
    # not: decode() does not always reproduce the prompt string exactly, and
    # replace() would also delete repeats of the prompt inside the answer.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


@app.post("/generate")
async def generate_response(request: QueryRequest):
    """Generate a completion for ``request.prompt``.

    Returns:
        ``{"response": <generated text, prompt excluded, whitespace-stripped>}``.

    Raises:
        HTTPException(503): the model or tokenizer has not been loaded yet.
        HTTPException(500): tokenization or generation failed. The detail is
            the underlying exception message.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading.")
    try:
        # generate() blocks for the whole decode, so run it off the event
        # loop to keep other endpoints (e.g. /health) responsive meanwhile.
        final_answer = await run_in_threadpool(
            _generate_text, request.prompt, request.max_tokens
        )
    except Exception as e:
        logging.getLogger(__name__).exception("Generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": final_answer}
@app.get("/health")
async def health_check():
    """Report liveness and whether the generation endpoint is ready.

    ``model_loaded`` uses the same readiness condition as the /generate
    handler: both ``model`` and ``tokenizer`` must be set. It therefore
    never reports ready while /generate would still answer 503.
    """
    ready = model is not None and tokenizer is not None
    return {"status": "active", "model_loaded": ready}