Muhammadidrees committed on
Commit
c08ace3
·
verified ·
1 Parent(s): ca551bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -69
app.py CHANGED
@@ -1,100 +1,76 @@
1
- # app.py
2
- import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
4
  import os
5
- import torch
6
 
7
- # -----------------------
8
- # Recommended Biomedical Model (Swap here if needed)
9
- # -----------------------
10
- # Options:
11
- # - "stanford-crfm/BioMedLM" (stable, PubMed-trained)
12
- # - "BioMistral/BioMistral-7B" (newer, PubMed + PMC heavy)
13
- # - "epfl-llm/ClinicalCamel" (clinical reporting style)
14
- MODEL_ID = "Muhammadidrees/my-medgamma"
15
 
16
- # -----------------------
17
-
18
- # Load tokenizer + model safely
19
- # -----------------------
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
21
 
22
- try:
23
- model = AutoModelForCausalLM.from_pretrained(
24
- MODEL_ID,
25
- #device_map="auto", # auto GPU/CPU placement
26
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
- low_cpu_mem_usage=True
28
- )
29
- except Exception as e:
30
- print(f"⚠️ GPU load failed, using CPU. Error: {e}")
31
- model = AutoModelForCausalLM.from_pretrained(
32
- MODEL_ID,
33
- torch_dtype=torch.float32,
34
- low_cpu_mem_usage=True
35
- )
36
 
37
  pipe = pipeline(
38
  "text-generation",
39
  model=model,
40
- tokenizer=tokenizer,
41
- device=0 if torch.cuda.is_available() else -1
42
  )
43
 
44
- # -----------------------
45
  # Helper: split report into panels
46
- # -----------------------
47
  def split_report(text: str):
48
- text = text.strip()
49
  markers = ["5. Tabular", "📊 Tabular", "## 5"]
50
  idx = None
51
  for m in markers:
52
  pos = text.find(m)
53
  if pos != -1:
54
- if idx is None or pos < idx:
55
- idx = pos
56
  if idx is None:
57
- return text, ""
58
  return text[:idx].strip(), text[idx:].strip()
59
 
60
- # -----------------------
61
- # Main analysis function
62
- # -----------------------
63
- def analyze(
64
- albumin, creatinine, glucose, crp, mcv, rdw, alp,
65
- wbc, lymph, age, gender, height, weight
66
- ):
67
- try:
68
- age = int(age)
69
- except Exception:
70
- age = age
71
  try:
72
- height = float(height)
73
- weight = float(weight)
74
- bmi = round(weight / ((height / 100) ** 2), 2) if height > 0 else "N/A"
75
  except Exception:
76
  bmi = "N/A"
77
 
78
- # -----------------------
79
  # Strict System Prompt
80
- # -----------------------
81
  system_prompt = (
82
  "You are a professional AI Medical Assistant.\n"
83
  "You must ONLY analyze: 9 Levine biomarkers + Age + Height + Weight.\n"
84
- "Forbidden: Any extra labs (cholesterol, vitamin D, ferritin, ALT, AST, urine, hormones, genetics).\n"
85
  "If information is not derivable, state clearly: 'Not available from current biomarkers.'\n\n"
 
86
  "Biomarkers allowed:\n"
87
  "- Albumin\n- Creatinine\n- Glucose\n- C-reactive protein (CRP)\n"
88
  "- Mean Cell Volume (MCV)\n- Red Cell Distribution Width (RDW)\n"
89
  "- Alkaline Phosphatase (ALP)\n- White Blood Cell count (WBC)\n"
90
  "- Lymphocyte percentage\n\n"
 
91
  "Output format:\n"
92
  "1. Executive Summary\n"
93
  "2. System-Specific Analysis\n"
94
  "3. Personalized Action Plan\n"
95
  "4. Interaction Alerts\n"
96
- "5. Tabular Mapping (Markdown table with Biomarker | Value | Range | Status | AI-Inferred Insight)\n"
97
  "6. Enhanced AI Insights & Longitudinal Risk\n\n"
 
98
  "Style: Professional, concise, structured, client-friendly. "
99
  "No hallucinations. No extra biomarkers. No absolute longevity claims.\n"
100
  )
@@ -111,26 +87,30 @@ def analyze(
111
 
112
  prompt = system_prompt + "\n" + patient_input
113
 
114
- gen = pipe(
115
- prompt,
116
- max_new_tokens=1500,
117
- do_sample=True,
118
- temperature=0.25, # lower temp = more factual
119
- top_p=0.9,
120
- repetition_penalty=1.05,
121
- return_full_text=False
122
- )
 
 
 
 
 
123
 
124
- generated = gen[0].get("generated_text", "").strip()
125
  if not generated:
126
  return "⚠️ No valid response. Please try again.", ""
127
 
128
  left_md, right_md = split_report(generated)
129
  return left_md, right_md
130
 
131
- # -----------------------
132
- # Gradio App
133
- # -----------------------
134
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
135
  gr.Markdown("# 🏥 AI Medical Biomarker Dashboard")
136
 
@@ -173,4 +153,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
173
  )
174
 
175
  if __name__ == "__main__":
176
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
1
# ----------------------------
# Imports (stdlib first, then third-party)
# ----------------------------
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# ----------------------------
# Model Config
# ----------------------------
# Alternative checkpoint: "BioMistral/BioMistral-7B"
MODEL_ID = "Muhammadidrees/my-biomed"

# Tokenizer and weights are downloaded/cached on first run.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# device_map="auto" places layers on the available accelerator (e.g. HF L4 GPU).
# NOTE(review): torch.float16 assumes a GPU is present — consider guarding with
# torch.cuda.is_available() for CPU-only runs; confirm against deployment target.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Text-generation pipeline shared by the Gradio handlers below.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
24
 
25
+ # ----------------------------
26
  # Helper: split report into panels
27
+ # ----------------------------
28
def split_report(text: str):
    """Split a generated report into (narrative, table) panels.

    The cut point is the earliest occurrence of any known marker that
    introduces the tabular section. If no marker is found, the whole
    (stripped) text goes to the left panel and the right panel is empty.
    """
    markers = ["5. Tabular", "📊 Tabular", "## 5"]
    # Collect the positions of every marker that actually occurs.
    hits = [pos for pos in (text.find(m) for m in markers) if pos != -1]
    if not hits:
        return text.strip(), ""
    cut = min(hits)  # earliest marker wins, matching a first-match scan
    return text[:cut].strip(), text[cut:].strip()
38
 
39
+ # ----------------------------
40
+ # Main Analysis Function
41
+ # ----------------------------
42
+ def analyze(albumin, creatinine, glucose, crp, mcv, rdw, alp,
43
+ wbc, lymph, age, gender, height, weight):
44
+
45
+ # Compute BMI
 
 
 
 
46
  try:
47
+ bmi = round(float(weight) / ((float(height) / 100) ** 2), 2)
 
 
48
  except Exception:
49
  bmi = "N/A"
50
 
51
+ # ----------------------------
52
  # Strict System Prompt
53
+ # ----------------------------
54
  system_prompt = (
55
  "You are a professional AI Medical Assistant.\n"
56
  "You must ONLY analyze: 9 Levine biomarkers + Age + Height + Weight.\n"
57
+ "Forbidden: Any extra labs (cholesterol, vitamin D, hormones, etc.).\n"
58
  "If information is not derivable, state clearly: 'Not available from current biomarkers.'\n\n"
59
+
60
  "Biomarkers allowed:\n"
61
  "- Albumin\n- Creatinine\n- Glucose\n- C-reactive protein (CRP)\n"
62
  "- Mean Cell Volume (MCV)\n- Red Cell Distribution Width (RDW)\n"
63
  "- Alkaline Phosphatase (ALP)\n- White Blood Cell count (WBC)\n"
64
  "- Lymphocyte percentage\n\n"
65
+
66
  "Output format:\n"
67
  "1. Executive Summary\n"
68
  "2. System-Specific Analysis\n"
69
  "3. Personalized Action Plan\n"
70
  "4. Interaction Alerts\n"
71
+ "5. Tabular Mapping (Markdown table with Biomarker | Value | Range | Status | Insight)\n"
72
  "6. Enhanced AI Insights & Longitudinal Risk\n\n"
73
+
74
  "Style: Professional, concise, structured, client-friendly. "
75
  "No hallucinations. No extra biomarkers. No absolute longevity claims.\n"
76
  )
 
87
 
88
  prompt = system_prompt + "\n" + patient_input
89
 
90
+ # ----------------------------
91
+ # Generate
92
+ # ----------------------------
93
+ try:
94
+ gen = pipe(
95
+ prompt,
96
+ max_new_tokens=1200,
97
+ temperature=0.25,
98
+ top_p=0.9,
99
+ return_full_text=False
100
+ )
101
+ generated = gen[0]["generated_text"].strip()
102
+ except Exception as e:
103
+ return f"❌ Error: {str(e)}", ""
104
 
 
105
  if not generated:
106
  return "⚠️ No valid response. Please try again.", ""
107
 
108
  left_md, right_md = split_report(generated)
109
  return left_md, right_md
110
 
111
+ # ----------------------------
112
+ # Gradio UI
113
+ # ----------------------------
114
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
115
  gr.Markdown("# 🏥 AI Medical Biomarker Dashboard")
116
 
 
153
  )
154
 
155
if __name__ == "__main__":
    # Bind on all interfaces; honour the platform-assigned PORT (default 7860).
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)