arjunkmoorthy committed on
Commit
8c048e4
·
verified ·
1 Parent(s): 3f41316

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -4,34 +4,40 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
  import os
6
 
7
- # Set Hugging Face cache to a safe location in Hugging Face Space
8
- os.environ["HF_HOME"] = "/tmp"
9
- os.environ["HF_DATASETS_CACHE"] = "/tmp"
10
- os.environ["TRANSFORMERS_CACHE"] = "/tmp"
11
- os.environ["HF_MODULES_CACHE"] = "/tmp"
12
-
13
-
14
 
15
  app = FastAPI()
16
 
17
- # Load ChatDoctor model
18
  model_id = "Dashanka/medical-chatbot-Llama3.1-8B-instruct-4bit"
19
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
20
  model = AutoModelForCausalLM.from_pretrained(
21
  model_id,
22
  trust_remote_code=True,
 
23
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
24
  )
25
 
26
-
27
-
28
-
29
  class ChatInput(BaseModel):
30
  message: str
31
 
 
32
  @app.post("/chat")
33
  async def chat_handler(input: ChatInput):
34
- inputs = tokenizer(input.message, return_tensors="pt")
35
- outputs = model.generate(**inputs, max_new_tokens=200)
 
 
 
 
 
 
36
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
37
  return {"response": response}
 
4
  import torch
5
  import os
6
 
7
# Redirect every Hugging Face cache to a writable location: Spaces
# containers only guarantee write access under /tmp.
HF_CACHE = "/tmp/hf_cache"
for _cache_var in ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE", "HF_MODULES_CACHE"):
    os.environ[_cache_var] = HF_CACHE
os.makedirs(HF_CACHE, exist_ok=True)
 
app = FastAPI()

model_id = "Dashanka/medical-chatbot-Llama3.1-8B-instruct-4bit"

# Half precision only when a GPU is available; CPU stays at float32.
_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# NOTE(review): trust_remote_code=True executes code shipped with the
# model repo — acceptable only because the repo is explicitly chosen above.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    cache_dir=HF_CACHE,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    cache_dir=HF_CACHE,
    torch_dtype=_dtype,
)
26
 
27
# ✅ User input schema
class ChatInput(BaseModel):
    """Request body for the /chat endpoint: the patient's free-text message."""
    # message: the raw symptom description typed by the patient
    message: str
30
 
31
# ✅ POST endpoint for symptom chat
@app.post("/chat")
async def chat_handler(payload: ChatInput):
    """Triage a patient message through the medical chatbot model.

    Wraps the message in a provider-style prompt, generates up to 300 new
    tokens, and returns only the model's reply (the prompt itself is
    stripped) as ``{"response": ...}``.
    """
    # Add provider-like system prompt
    prompt = f"""You are a kind, attentive oncology provider speaking to a patient. Ask one follow-up question at a time to triage their symptoms.

Patient: {payload.message}
Provider:"""

    inputs = tokenizer(prompt, return_tensors="pt")
    # inference_mode: no gradient tracking needed for generation.
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=300)
    # Decode only the newly generated tokens — decoding outputs[0] in full
    # would echo the entire system prompt back to the client.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return {"response": response}