arjunkmoorthy committed on
Commit
97bceda
·
verified ·
1 Parent(s): c10b947

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -22
app.py CHANGED
@@ -1,43 +1,36 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import torch
5
  import os
6
 
7
- # ✅ Set safe cache directory
8
  HF_CACHE = "/tmp/hf_cache"
9
  os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
10
  os.environ["HF_HOME"] = HF_CACHE
11
- os.environ["HF_DATASETS_CACHE"] = HF_CACHE
12
- os.environ["HF_MODULES_CACHE"] = HF_CACHE
13
  os.makedirs(HF_CACHE, exist_ok=True)
14
 
15
  app = FastAPI()
16
 
17
  model_id = "srikar-v05/phi3-Mini-Medical-Chat"
18
-
19
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, cache_dir=HF_CACHE)
20
- model = AutoModelForCausalLM.from_pretrained(
21
- model_id,
22
- trust_remote_code=True,
23
- cache_dir=HF_CACHE,
24
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
25
  )
26
 
27
- # User input schema
 
 
28
  class ChatInput(BaseModel):
29
  message: str
30
 
31
- # ✅ POST endpoint for symptom chat
32
  @app.post("/chat")
33
  async def chat_handler(input: ChatInput):
34
- # Add provider-like system prompt
35
- prompt = f"""You are a kind, attentive oncology provider speaking to a patient. Ask one follow-up question at a time to triage their symptoms.
36
-
37
- Patient: {input.message}
38
- Provider:"""
39
-
40
- inputs = tokenizer(prompt, return_tensors="pt")
41
  outputs = model.generate(**inputs, max_new_tokens=300)
42
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
43
  return {"response": response}
 
# NOTE: the cache env vars must be exported BEFORE importing
# transformers/unsloth — both read TRANSFORMERS_CACHE / HF_HOME at import
# time, so setting them after the import (as the previous revision did)
# silently leaves the default cache location in use.
import os

HF_CACHE = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
os.environ["HF_HOME"] = HF_CACHE
os.makedirs(HF_CACHE, exist_ok=True)

from fastapi import FastAPI  # noqa: E402 (env vars must precede these imports)
from pydantic import BaseModel  # noqa: E402
from unsloth import FastModel  # noqa: E402

app = FastAPI()

model_id = "srikar-v05/phi3-Mini-Medical-Chat"

# Load once at process startup; 4-bit quantization keeps the memory
# footprint small enough for a Space container.
model, tokenizer = FastModel.from_pretrained(
    model_name=model_id,
    load_in_4bit=True,
    max_seq_length=2048,
)

# Optional: optimise for inference
FastModel.for_inference(model)
class ChatInput(BaseModel):
    """Request body for the /chat endpoint."""

    # Free-text message from the patient to be triaged.
    message: str
25
 
 
@app.post("/chat")
async def chat_handler(input: ChatInput):
    """Generate a provider-style triage reply for the patient's message.

    Returns ``{"response": <reply>}`` containing only the newly generated
    text. The previous version decoded the whole output sequence, so the
    API echoed the system prompt and the patient's message back to the
    caller; slicing off the prompt tokens fixes that.
    """
    prompt = (
        "You are a kind, attentive oncology provider speaking to a patient.\n"
        "Ask one follow-up question at a time to triage their symptoms.\n\n"
        f"Patient: {input.message}\nProvider:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # NOTE(review): model.generate is blocking; in an async handler it will
    # stall the event loop for the duration of generation — consider
    # run_in_executor if concurrency matters.
    outputs = model.generate(**inputs, max_new_tokens=300)
    # Decode only the tokens produced after the prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return {"response": response}