Dacthex commited on
Commit
e7c3898
·
verified ·
1 Parent(s): 7e4e44c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -42
app.py CHANGED
@@ -1,50 +1,88 @@
1
- import time
2
- from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
3
  import gradio as gr
 
 
4
 
5
- # Model
6
- model_name = "emilyalsentzer/Bio_ClinicalBERT"
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForMaskedLM.from_pretrained(model_name)
9
-
10
- # Pipeline
11
- nlp_pipeline = pipeline("fill-mask", model=model, tokenizer=tokenizer)
12
-
13
- def analyze_malaria(text):
14
- yield "Running analysis..." # Status message
15
- time.sleep(0.5)
16
-
17
- # Diagnosis agent
18
- try:
19
- diagnosis_results = nlp_pipeline(f"{text} The patient may have [MASK].")
20
- diagnosis = ", ".join([res['sequence'] for res in diagnosis_results[:3]])
21
- except:
22
- diagnosis = "Diagnosis analysis failed."
23
 
24
- # Treatment agent
25
- try:
26
- treatment_results = nlp_pipeline(f"For malaria, recommended treatment is [MASK].")
27
- treatment = ", ".join([res['sequence'] for res in treatment_results[:3]])
28
- except:
29
- treatment = "Treatment analysis failed."
30
 
31
- # Prognosis agent
32
- try:
33
- prognosis_results = nlp_pipeline(f"Prognosis for malaria patient is [MASK].")
34
- prognosis = ", ".join([res['sequence'] for res in prognosis_results[:3]])
35
- except:
36
- prognosis = "Prognosis analysis failed."
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- yield f"Diagnosis:\n{diagnosis}\n\nTreatment:\n{treatment}\n\nPrognosis:\n{prognosis}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # Gradio interface
41
- iface = gr.Interface(
42
- fn=analyze_malaria,
43
- inputs=gr.Textbox(lines=4, placeholder="Enter symptoms here..."),
44
- outputs=gr.Textbox(label="Malaria Analysis"),
45
- title="Malaria Multi-Agent AI",
46
- description="Enter clinical symptoms. The AI provides diagnosis, treatment, and prognosis using Bio_ClinicalBERT.",
47
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- if __name__ == "__main__":
50
- iface.launch()
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ import time
4
 
5
+ # --- MODEL SETUP ---
6
+ # CPU-friendly LLM
7
+ model_name = "google/flan-t5-small" # small enough for CPU
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
 
 
 
 
 
12
 
13
+ # --- MALARIA RULES ---
14
+ malaria_rules = {
15
+ "fever": {"diagnosis": "Malaria likely", "treatment": "ACT (artemisinin-based combination therapy)", "prognosis": "Good if treated early"},
16
+ "chills": {"diagnosis": "Malaria likely", "treatment": "ACT", "prognosis": "Good if treated early"},
17
+ "headache": {"diagnosis": "Malaria possible", "treatment": "Supportive care and ACT if confirmed", "prognosis": "Good with treatment"},
18
+ "sweating": {"diagnosis": "Malaria possible", "treatment": "Monitor temperature, ACT if confirmed", "prognosis": "Good with treatment"},
19
+ "nausea": {"diagnosis": "Malaria possible", "treatment": "Hydration, ACT if confirmed", "prognosis": "Good with treatment"},
20
+ "vomiting": {"diagnosis": "Malaria possible", "treatment": "Hydration, ACT if confirmed", "prognosis": "Good with treatment"},
21
+ "fatigue": {"diagnosis": "Malaria possible", "treatment": "Rest and ACT if confirmed", "prognosis": "Good with treatment"},
22
+ "anemia": {"diagnosis": "Severe malaria possible", "treatment": "Hospitalization, blood transfusion, ACT", "prognosis": "Guarded, monitor closely"},
23
+ "jaundice": {"diagnosis": "Severe malaria possible", "treatment": "Hospitalization, supportive care, ACT", "prognosis": "Guarded"},
24
+ "convulsions": {"diagnosis": "Cerebral malaria possible", "treatment": "Emergency care, IV antimalarials", "prognosis": "Poor if untreated"},
25
+ "respiratory distress": {"diagnosis": "Severe malaria possible", "treatment": "Oxygen therapy, IV antimalarials", "prognosis": "Guarded"},
26
+ "abdominal pain": {"diagnosis": "Malaria possible", "treatment": "Supportive care, ACT if confirmed", "prognosis": "Good with treatment"},
27
+ "diarrhea": {"diagnosis": "Malaria possible", "treatment": "Hydration, ACT if confirmed", "prognosis": "Good with treatment"},
28
+ "muscle pain": {"diagnosis": "Malaria possible", "treatment": "Rest, analgesics, ACT if confirmed", "prognosis": "Good with treatment"},
29
+ }
30
 
31
+ # --- ANALYSIS FUNCTION ---
32
+ def analyze_symptoms(symptoms_input):
33
+ # show temporary running feedback
34
+ output = "Running analysis...\n"
35
+
36
+ symptoms = [s.strip().lower() for s in symptoms_input.split(",")]
37
+
38
+ diagnosis_list = []
39
+ treatment_list = []
40
+ prognosis_list = []
41
+
42
+ for symptom in symptoms:
43
+ if symptom in malaria_rules:
44
+ rule = malaria_rules[symptom]
45
+ diagnosis_list.append(rule["diagnosis"])
46
+ treatment_list.append(rule["treatment"])
47
+ prognosis_list.append(rule["prognosis"])
48
+ else:
49
+ diagnosis_list.append(f"No rule for '{symptom}'")
50
+ treatment_list.append(f"No rule for '{symptom}'")
51
+ prognosis_list.append(f"No rule for '{symptom}'")
52
+
53
+ # Convert lists to readable text
54
+ diagnosis_text = "\n".join(diagnosis_list)
55
+ treatment_text = "\n".join(treatment_list)
56
+ prognosis_text = "\n".join(prognosis_list)
57
+
58
+ # Enhance outputs with LLM
59
+ enhanced_diagnosis = generator(f"Summarize and explain clearly: {diagnosis_text}", max_length=150)[0]["generated_text"]
60
+ enhanced_treatment = generator(f"Summarize and explain clearly: {treatment_text}", max_length=150)[0]["generated_text"]
61
+ enhanced_prognosis = generator(f"Summarize and explain clearly: {prognosis_text}", max_length=150)[0]["generated_text"]
62
+
63
+ return enhanced_diagnosis, enhanced_treatment, enhanced_prognosis
64
 
65
+ # --- GRADIO UI ---
66
+ with gr.Blocks() as demo:
67
+ gr.Markdown("## Malaria Multi-Agent AI")
68
+
69
+ with gr.Row():
70
+ symptoms_input = gr.Textbox(label="Enter symptoms (comma-separated)", placeholder="fever, chills, headache")
71
+ analyze_button = gr.Button("Run Analysis")
72
+
73
+ with gr.Row():
74
+ diagnosis_output = gr.Textbox(label="Diagnosis", interactive=False)
75
+ treatment_output = gr.Textbox(label="Treatment", interactive=False)
76
+ prognosis_output = gr.Textbox(label="Prognosis", interactive=False)
77
+
78
+ def on_click(symptoms):
79
+ # Show temporary running message first
80
+ diagnosis_output.value = "Running analysis..."
81
+ treatment_output.value = "Running analysis..."
82
+ prognosis_output.value = "Running analysis..."
83
+ time.sleep(0.5) # short pause to simulate processing
84
+ return analyze_symptoms(symptoms)
85
+
86
+ analyze_button.click(on_click, inputs=[symptoms_input], outputs=[diagnosis_output, treatment_output, prognosis_output])
87
 
88
+ demo.launch()