jeevanrushi07 committed on
Commit 022eb1b · verified · 1 Parent(s): a31c4bb

Update app.py

Files changed (1)
  1. app.py +17 -12
app.py CHANGED
@@ -1,15 +1,22 @@
-import gradio as gr
-import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 import os
+import gradio as gr
 
-# Load Hugging Face token from secrets
-hf_token = os.environ.get("HF_TOKEN", "").strip()
+
+hf_token = os.environ.get("HF_TOKEN").strip()
 model_name = "jeevanrushi07/gemma-medical-assistant"
 
-# Load model and tokenizer
+# Directory to offload layers to CPU
+offload_dir = "/tmp/model_offload"
+
 tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", token=hf_token)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    offload_folder=offload_dir,  # This is required for large models
+    token=hf_token
+)
 
 def generate_report(prompt):
     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
@@ -21,16 +28,14 @@ def generate_report(prompt):
         top_p=0.9,
         eos_token_id=tokenizer.eos_token_id
     )
-    text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
-    return text
+    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
 
-# Gradio UI
 iface = gr.Interface(
     fn=generate_report,
-    inputs=gr.Textbox(lines=10, placeholder="Enter patient info and symptoms..."),
+    inputs=gr.Textbox(lines=10, placeholder="Enter patient info..."),
     outputs="textbox",
-    title="Medical Chatbot",
-    description="Enter patient data and receive AI-generated medical report."
+    title="Medical Chatbot"
 )
 
 iface.launch()
+
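For reference, the loading pattern this commit introduces (device_map="auto" combined with an offload_folder) can be exercised on its own. The sketch below is illustrative, not a copy of app.py: it assumes HF_TOKEN is set in the environment and /tmp is writable, routes inputs via model.device instead of the hard-coded "cuda", and the max_new_tokens value is an assumption because the full generate() call is not shown in this diff.

import os
from transformers import AutoTokenizer, AutoModelForCausalLM

# Illustrative sketch of the offloaded loading pattern used in app.py.
# Assumptions: HF_TOKEN is set in the environment and /tmp is writable.
hf_token = os.environ.get("HF_TOKEN", "").strip()
model_name = "jeevanrushi07/gemma-medical-assistant"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",                    # let accelerate place layers across GPU/CPU
    offload_folder="/tmp/model_offload",  # disk spill for layers that do not fit in memory
    token=hf_token,
)

def generate_report(prompt: str) -> str:
    # Send inputs to the device of the model's first shard; a hard-coded "cuda"
    # only works when a GPU is actually available.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,  # assumed value; the real setting is outside this diff
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Strip the prompt tokens and decode only the newly generated text.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)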