Dacthex commited on
Commit
b755401
·
verified ·
1 Parent(s): df31400

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -94
app.py CHANGED
@@ -1,105 +1,113 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
 
4
- # --- MODEL SETUP ---
5
- MODEL_ID = "TheBloke/guanaco-3B-GPTQ" # public, CPU-friendly
6
- print("Loading model (this may take a minute on CPU)...")
 
 
 
7
 
8
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
9
- model = AutoModelForCausalLM.from_pretrained(
10
- MODEL_ID,
11
- device_map="auto", # CPU only
12
- torch_dtype="auto",
13
- low_cpu_mem_usage=True
14
- )
 
 
 
 
 
 
 
 
 
15
 
16
- generator = pipeline(
17
- task="text-generation",
18
- model=model,
19
- tokenizer=tokenizer,
20
- max_length=512
21
- )
 
 
 
 
 
 
22
 
23
- # --- HELPER FUNCTIONS ---
24
- def format_patient_data(data):
25
- symptoms = ", ".join(data["symptoms"]) if data["symptoms"] else "No symptoms reported"
26
- text = f"""
27
- Patient Data:
28
-
29
- - Age: {data.get('age', 'Not specified')}
30
- - Gender: {data.get('gender', 'Not specified')}
31
- - Location: {data.get('location', 'Not specified')}
32
- - Travel to endemic areas: {"Yes" if data.get("travel") else "No"}
33
- - Travel details: {data.get('travel_details', '')}
34
-
35
- Symptoms: {symptoms}
36
-
37
- Vital Signs:
38
- - Temperature: {data.get('temperature', 'Not recorded')}
39
- - Blood Pressure: {data.get('blood_pressure', 'Not recorded')}
40
- - Heart Rate: {data.get('heart_rate', 'Not recorded')}
41
-
42
- Medical History:
43
- - Previous malaria: {"Yes" if data.get("previous_malaria") else "No"}
44
- - Medications/allergies: {data.get('medications', '')}
45
-
46
- Additional notes: {data.get('additional_notes', '')}
47
  """
48
- return text.strip()
49
 
50
- def run_analysis(agent, patient_data):
51
- prompt = format_patient_data(patient_data)
52
- prompt += f"\n\n{agent.upper()} ANALYSIS REQUEST:\nProvide evidence-based medical information for educational purposes."
 
 
 
 
 
 
 
53
 
54
- response = generator(prompt, max_length=512, do_sample=True, temperature=0.7)[0]['generated_text']
55
- return response
56
 
57
- # --- GRADIO UI ---
58
- def build_ui():
59
- with gr.Blocks() as demo:
60
- gr.Markdown("# 🦟 Malaria AI Assistant (CPU-friendly, Open-Source)")
61
- with gr.Row():
62
- with gr.Column():
63
- age = gr.Number(label="Age", value=None)
64
- gender = gr.Dropdown(["Male", "Female", "Other"], label="Gender")
65
- location = gr.Textbox(label="Location")
66
- travel = gr.Checkbox(label="Traveled to endemic area?")
67
- travel_details = gr.Textbox(label="Travel Details")
68
- symptoms = gr.CheckboxGroup(
69
- ["Fever", "Chills", "Headache", "Nausea/Vomiting", "Muscle aches", "Fatigue"],
70
- label="Symptoms"
71
- )
72
- temperature = gr.Number(label="Temperature (°C)")
73
- blood_pressure = gr.Textbox(label="Blood Pressure")
74
- heart_rate = gr.Number(label="Heart Rate")
75
- previous_malaria = gr.Checkbox(label="Previous malaria episodes")
76
- medications = gr.Textbox(label="Medications / Allergies")
77
- additional_notes = gr.Textbox(label="Additional notes")
78
- with gr.Column():
79
- agent = gr.Radio(["diagnostic", "treatment", "prognosis"], label="Analysis Type", value="diagnostic")
80
- output = gr.Textbox(label="Analysis Result", interactive=False)
81
 
82
- btn = gr.Button("Run Analysis")
83
- btn.click(
84
- run_analysis,
85
- inputs=[agent, {
86
- "age": age,
87
- "gender": gender,
88
- "location": location,
89
- "travel": travel,
90
- "travel_details": travel_details,
91
- "symptoms": symptoms,
92
- "temperature": temperature,
93
- "blood_pressure": blood_pressure,
94
- "heart_rate": heart_rate,
95
- "previous_malaria": previous_malaria,
96
- "medications": medications,
97
- "additional_notes": additional_notes
98
- }],
99
- outputs=output
100
- )
101
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- if __name__ == "__main__":
104
- demo = build_ui()
105
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
+ # -----------------------------
5
+ # Load stronger CPU-friendly AI model
6
+ # -----------------------------
7
+ model_name = "google/flan-t5-base" # upgraded from small
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
 
11
+ # -----------------------------
12
+ # AI response function
13
+ # -----------------------------
14
+ def malaria_ai(age, gender, location, travel_endemic, travel_details,
15
+ symptoms, temperature, blood_pressure, heart_rate,
16
+ previous_malaria, medications, additional_notes, agent):
17
+
18
+ # Basic input check
19
+ if not age and not symptoms:
20
+ return "<p style='color:red;'>Please provide at least age or symptoms for analysis.</p>"
21
+
22
+ symptoms_list = ", ".join(symptoms) if symptoms else "No symptoms reported"
23
+
24
+ # Structured prompt with clear instructions
25
+ prompt = f"""
26
+ Patient Information:
27
 
28
+ - Age: {age}
29
+ - Gender: {gender or 'Not specified'}
30
+ - Location: {location or 'Not specified'}
31
+ - Recent travel to malaria-endemic areas: {"Yes" if travel_endemic else "No"}
32
+ - Travel details: {travel_details or 'None'}
33
+ - Symptoms: {symptoms_list}
34
+ - Temperature: {temperature}°C
35
+ - Blood Pressure: {blood_pressure or 'Not recorded'}
36
+ - Heart Rate: {heart_rate or 'Not recorded'}
37
+ - Previous malaria episodes: {"Yes" if previous_malaria else "No"}
38
+ - Medications/Allergies: {medications or 'None'}
39
+ - Additional Notes: {additional_notes or 'None'}
40
 
41
+ Task:
42
+ You are a medical assistant AI. Based on the patient information, provide a detailed and realistic {agent.lower()} analysis of malaria. Use evidence-based reasoning. Include any necessary precautions or warnings. Format your response clearly with bullet points and paragraphs. Keep it informative but concise.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
 
44
 
45
+ # Generate AI response
46
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
47
+ outputs = model.generate(
48
+ **inputs,
49
+ max_new_tokens=600,
50
+ do_sample=True,
51
+ top_p=0.9,
52
+ temperature=0.7
53
+ )
54
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
 
56
+ # Replace line breaks safely outside f-string
57
+ response_html = response.replace("\n", "<br>")
58
 
59
+ # Scrollable styled HTML card
60
+ formatted_response = f"""
61
+ <div style="
62
+ background-color:#dbeafe;
63
+ padding:20px;
64
+ border-radius:10px;
65
+ box-shadow: 0 5px 15px rgba(0,0,0,0.1);
66
+ max-height:400px;
67
+ overflow-y:auto;
68
+ font-family: sans-serif;
69
+ ">
70
+ <p style="color:#111;">{response_html}</p>
71
+ </div>
72
+ """
73
+ return formatted_response
 
 
 
 
 
 
 
 
 
74
 
75
+ # -----------------------------
76
+ # Gradio interface with default demo data
77
+ # -----------------------------
78
+ with gr.Blocks() as demo:
79
+ gr.Markdown("## 🦟 Malaria AI Assistant\nDiagnostic, treatment, and prognostic analysis")
80
+
81
+ with gr.Row():
82
+ with gr.Column():
83
+ age = gr.Number(label="Age", value=25)
84
+ gender = gr.Dropdown(["", "Male", "Female", "Other"], label="Gender", value="Male")
85
+ location = gr.Textbox(label="Location (City, Country)", value="Lagos, Nigeria")
86
+ travel_endemic = gr.Checkbox(label="Recent travel to malaria-endemic areas", value=True)
87
+ travel_details = gr.Textbox(label="Travel Details", value="Visited rural Northern Nigeria for 2 weeks")
88
+ symptoms = gr.CheckboxGroup(
89
+ ["Fever","Chills","Headache","Nausea/Vomiting","Muscle aches","Fatigue"],
90
+ label="Symptoms",
91
+ value=["Fever","Chills","Headache"]
92
+ )
93
+ temperature = gr.Number(label="Temperature (°C)", value=38.5)
94
+ blood_pressure = gr.Textbox(label="Blood Pressure", value="120/80")
95
+ heart_rate = gr.Number(label="Heart Rate (bpm)", value=88)
96
+ previous_malaria = gr.Checkbox(label="Previous malaria episodes", value=True)
97
+ medications = gr.Textbox(label="Current medications/allergies", value="None")
98
+ additional_notes = gr.Textbox(label="Additional information", value="Patient shows early signs of fatigue.")
99
+ agent = gr.Radio(["Diagnostic", "Treatment", "Prognostic"], label="AI Analysis Type", value="Diagnostic")
100
+ submit_btn = gr.Button("Run Analysis")
101
+
102
+ with gr.Column():
103
+ output = gr.HTML(label="AI Analysis Result")
104
+
105
+ submit_btn.click(
106
+ fn=malaria_ai,
107
+ inputs=[age, gender, location, travel_endemic, travel_details,
108
+ symptoms, temperature, blood_pressure, heart_rate,
109
+ previous_malaria, medications, additional_notes, agent],
110
+ outputs=output
111
+ )
112
 
113
+ demo.launch()