Amrender's picture
Create app.py
4403fc1 verified
raw
history blame
2.25 kB
import logging

import torch
from fastapi import FastAPI, HTTPException
from peft import PeftModel
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# ASGI application object; the route and startup decorators in this module
# register their handlers on it.
app = FastAPI(title="Medical Chatbot API")

# Module-level handles for the language model and its tokenizer. Both stay
# None until load_model() assigns them during the startup event.
model = tokenizer = None
# Define the structure of the incoming request
class QueryRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    Attributes:
        prompt: Text that is tokenized and fed to the model.
        max_tokens: Cap on the number of newly generated tokens (forwarded
            to ``generate`` as ``max_new_tokens``). Must be at least 1;
            defaults to 150.
    """

    prompt: str
    # Reject zero/negative budgets during request validation (HTTP 422)
    # instead of letting a meaningless value reach model.generate(), where
    # any resulting error would be reported to the client as a 500.
    max_tokens: int = Field(default=150, ge=1)
# NOTE(review): FastAPI documents `on_event` as deprecated in favour of
# lifespan handlers. Migrating means changing how `app` is constructed, so
# the hook is left as-is here.
@app.on_event("startup")
def load_model():
    """Load the quantized base model, its tokenizer and the PEFT adapter.

    Runs once on application startup and fills the module-level ``model``
    and ``tokenizer`` globals. Everything is loaded into local variables
    first and the globals are assigned together at the very end, so a
    failure part-way through (e.g. the adapter cannot be fetched) never
    leaves a base model without its adapter, or a model without a
    tokenizer, visible to the request handlers.

    Raises:
        Any exception raised by the underlying ``from_pretrained`` calls is
        propagated unchanged.
    """
    global model, tokenizer
    print("Loading model onto GPU...")

    base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    adapter_id = "Amrender/Medical_Chatbot"

    # 1. 4-bit config to fit the GPU
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # 2. Load base model and the tokenizer that belongs to it
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    # 3. Attach medical adapters on top of the base model
    adapted_model = PeftModel.from_pretrained(base_model, adapter_id)

    # Publish both handles only once every step above has succeeded.
    model, tokenizer = adapted_model, loaded_tokenizer
    print("Model loaded successfully!")
@app.post("/generate")
async def generate_response(request: QueryRequest):
    """Generate a completion for ``request.prompt`` with the loaded model.

    Args:
        request: Validated request body carrying the prompt and the maximum
            number of new tokens to generate.

    Returns:
        ``{"response": <text>}`` where ``<text>`` is only the newly
        generated continuation, stripped of surrounding whitespace.

    Raises:
        HTTPException: 503 when the model or tokenizer is not loaded yet;
            500 when tokenization, generation or decoding fails (the
            original error message is passed through as ``detail``).
    """
    # Kept outside the try block so the 503 is not re-wrapped as a 500.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading.")

    # NOTE(review): model.generate() is a blocking call inside an async
    # handler, so other requests (including /health) wait while it runs.
    # Offloading it to a worker thread would change how concurrent requests
    # share the model, so that decision is left to the maintainers.
    try:
        # Move the inputs to wherever the model actually lives instead of a
        # hard-coded "cuda"; the model is placed with device_map="auto".
        inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        outputs = model.generate(**inputs, max_new_tokens=request.max_tokens)

        # The returned sequence is prompt + continuation. Drop the prompt by
        # token count rather than by string replacement: the decoded prompt
        # does not always round-trip to the exact input text, and a textual
        # replace would also delete any repeat of the prompt in the answer.
        new_tokens = outputs[0][prompt_length:]
        final_answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return {"response": final_answer}
    except Exception as exc:
        # Record the traceback server-side; without this it is lost once the
        # error is converted into an HTTP response.
        logging.getLogger(__name__).exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.get("/health")
async def health_check():
    """Report that the service is up and whether it can serve generations.

    Returns:
        A dict with a constant ``"status": "active"`` and a boolean
        ``"model_loaded"``. The flag uses the same readiness condition as
        the ``/generate`` endpoint (both model and tokenizer present), so a
        ``True`` here means ``/generate`` will not answer 503.
    """
    is_ready = model is not None and tokenizer is not None
    return {"status": "active", "model_loaded": is_ready}