Amrender commited on
Commit
4403fc1
·
verified ·
1 Parent(s): a579b36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
5
+ from peft import PeftModel
6
+
7
+ # Initialize FastAPI
8
+ app = FastAPI(title="Medical Chatbot API")
9
+
10
+ # Global variables for the model and tokenizer
11
+ model = None
12
+ tokenizer = None
13
+
14
+ # Define the structure of the incoming request
15
+ class QueryRequest(BaseModel):
16
+ prompt: str
17
+ max_tokens: int = 150
18
+
19
+ @app.on_event("startup")
20
+ def load_model():
21
+ global model, tokenizer
22
+ print("Loading model onto GPU...")
23
+
24
+ # 1. 4-bit config to fit the GPU
25
+ bnb_config = BitsAndBytesConfig(
26
+ load_in_4bit=True,
27
+ bnb_4bit_use_double_quant=True,
28
+ bnb_4bit_quant_type="nf4",
29
+ bnb_4bit_compute_dtype=torch.float16
30
+ )
31
+
32
+ base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
33
+
34
+ # 2. Load Base Model
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ base_model_id,
37
+ quantization_config=bnb_config,
38
+ device_map="auto",
39
+ torch_dtype=torch.float16,
40
+ low_cpu_mem_usage=True
41
+ )
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id)
44
+
45
+ # 3. Attach Medical Adapters
46
+ adapter_id = "Amrender/Medical_Chatbot"
47
+ model = PeftModel.from_pretrained(model, adapter_id)
48
+ print("Model loaded successfully!")
49
+
50
+ @app.post("/generate")
51
+ async def generate_response(request: QueryRequest):
52
+ if model is None or tokenizer is None:
53
+ raise HTTPException(status_code=503, detail="Model is still loading.")
54
+
55
+ try:
56
+ # Format the input
57
+ inputs = tokenizer(request.prompt, return_tensors="pt").to("cuda")
58
+
59
+ # Generate the output
60
+ outputs = model.generate(**inputs, max_new_tokens=request.max_tokens)
61
+ response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
62
+
63
+ # Strip the prompt from the response if necessary
64
+ final_answer = response_text.replace(request.prompt, "").strip()
65
+
66
+ return {"response": final_answer}
67
+
68
+ except Exception as e:
69
+ raise HTTPException(status_code=500, detail=str(e))
70
+
71
+ @app.get("/health")
72
+ async def health_check():
73
+ return {"status": "active", "model_loaded": model is not None}