# Malaria-ai / app.py
# File-viewer residue kept for reference (not part of the program):
#   Dacthex's picture · Update app.py · 5485ce0 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# -----------------------------
# Load CPU-friendly AI model
# -----------------------------
# Checkpoint shared by the tokenizer and the seq2seq model; both are loaded
# once at import time and reused by every request handled below.
model_name = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# -----------------------------
# AI response function
# -----------------------------
# Prompt instruction and card header colour for each analysis type, keyed by
# the lower-cased agent name.
_AGENT_PROFILES = {
    "diagnostic": (
        """
Task: Provide a detailed **diagnostic report** for malaria.
Include:
1. Risk assessment based on symptoms and travel history
2. Suggested diagnostic tests (blood smear, rapid test, PCR)
3. Differential diagnoses
4. Severity classification if malaria is suspected
5. Red flags or warning signs to monitor
Format the response with bullet points and headings.
""",
        "#2563eb",  # blue
    ),
    "treatment": (
        """
Task: Provide a detailed **treatment recommendation**.
Include:
1. First-line treatment options based on suspected malaria type and severity
2. Dosage guidance based on age/weight
3. Alternative treatments for drug-resistant strains
4. Supportive care (hydration, fever management)
5. Monitoring and follow-up instructions
Format clearly with bullet points and headings.
""",
        "#16a34a",  # green
    ),
    "prognostic": (
        """
Task: Provide a detailed **prognostic report**.
Include:
1. Expected clinical course and recovery timeline
2. Risk factors for severe complications
3. Recommended follow-up schedule
4. Preventive measures for future malaria episodes
Format clearly with bullet points and headings.
""",
        "#f97316",  # orange
    ),
}

# Used when the agent is missing or not one of the keys above:
# no extra instruction and a grey header.
_DEFAULT_PROFILE = ("", "#6b7280")


def malaria_ai(age, gender, location, travel_endemic, travel_details,
               symptoms, temperature, blood_pressure, heart_rate,
               previous_malaria, medications, additional_notes, agent):
    """Build a prompt from one patient record and return the model's answer as an HTML card.

    Args:
        age: Patient age; ``None`` (or an empty string) means "not provided".
        gender: Free-text/dropdown gender; falsy values render as "Not specified".
        location: Patient location; falsy values render as "Not specified".
        travel_endemic: Truthy if the patient recently travelled to an endemic area.
        travel_details: Free-text travel description; falsy renders as "None".
        symptoms: Iterable of symptom names; falsy renders as "No symptoms reported".
        temperature: Body temperature in °C; ``None`` renders as "Not recorded".
        blood_pressure: Blood pressure text; falsy renders as "Not recorded".
        heart_rate: Heart rate; falsy renders as "Not recorded".
        previous_malaria: Truthy if the patient had earlier malaria episodes.
        medications: Medications/allergies text; falsy renders as "None".
        additional_notes: Extra notes; falsy renders as "None".
        agent: Analysis type ("Diagnostic", "Treatment" or "Prognostic",
            case-insensitive). Any other value, including ``None``, falls back
            to a prompt without task instructions and a grey header.

    Returns:
        An HTML string: either a red validation message when neither age nor
        symptoms were supplied, or a styled card containing the generated text.
    """
    import html  # function-scope import keeps this block self-contained

    # Only a missing age counts as "not provided"; 0 is a legitimate age
    # for an infant and must not trigger the validation message.
    age_missing = age is None or age == ""
    if age_missing and not symptoms:
        return "<p style='color:red;'>Please provide at least age or symptoms for analysis.</p>"

    symptoms_list = ", ".join(symptoms) if symptoms else "No symptoms reported"
    age_text = "Not specified" if age_missing else age
    # Avoid rendering the literal text "None°C" when no temperature was entered.
    temperature_text = "Not recorded" if temperature is None else f"{temperature}°C"
    travelled = "Yes" if travel_endemic else "No"
    had_malaria = "Yes" if previous_malaria else "No"

    patient_info = f"""
Patient Information:
- Age: {age_text}
- Gender: {gender or 'Not specified'}
- Location: {location or 'Not specified'}
- Recent travel to malaria-endemic areas: {travelled}
- Travel details: {travel_details or 'None'}
- Symptoms: {symptoms_list}
- Temperature: {temperature_text}
- Blood Pressure: {blood_pressure or 'Not recorded'}
- Heart Rate: {heart_rate or 'Not recorded'}
- Previous malaria episodes: {had_malaria}
- Medications/Allergies: {medications or 'None'}
- Additional Notes: {additional_notes or 'None'}
"""

    # Agent-specific instructions; tolerate a missing agent instead of
    # raising AttributeError on None.lower().
    agent_key = str(agent or "").lower()
    instruction, header_color = _AGENT_PROFILES.get(agent_key, _DEFAULT_PROFILE)

    prompt = patient_info + instruction + "\nNote: For educational purposes only. Consult a healthcare professional."

    # Generate AI response
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = model.generate(
        **inputs,
        max_new_tokens=700,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The generated text can echo user-supplied fields, so escape it (and the
    # agent label) before embedding it in markup; only then add line breaks.
    response_html = html.escape(response).replace("\n", "<br>")
    agent_label = html.escape(str(agent)) if agent else "General"

    # -----------------------------
    # Fixed AI response card styling
    # -----------------------------
    formatted_response = f"""
<div style="
border-radius:12px;
overflow:hidden;
box-shadow:0 5px 15px rgba(0,0,0,0.1);
font-family:sans-serif;
">
<div style="
background-color:{header_color};
color:white;
font-weight:bold;
padding:10px 15px;
font-size:16px;
">
{agent_label} Analysis
</div>
<div style="
background-color:#f0f9ff; /* light blue */
color:#111; /* dark text */
padding:15px;
max-height:400px;
overflow-y:auto;
">
{response_html}
</div>
</div>
"""
    return formatted_response
# -----------------------------
# Gradio dashboard interface
# -----------------------------
# Choice lists for the selection widgets, named so the layout code below
# reads as structure rather than data.
gender_choices = ["", "Male", "Female", "Other"]
symptom_choices = ["Fever", "Chills", "Headache", "Nausea/Vomiting", "Muscle aches", "Fatigue"]
agent_choices = ["Diagnostic", "Treatment", "Prognostic"]

with gr.Blocks() as demo:
    gr.Markdown("## 🦟 Malaria AI Assistant – Dashboard Style\nDiagnostic, treatment, and prognostic analysis")

    with gr.Row():
        # Left column: the patient intake form, pre-filled with a sample case.
        with gr.Column(scale=1):
            gr.Markdown("### 🧾 Demographics")
            age = gr.Number(label="Age", value=25)
            gender = gr.Dropdown(choices=gender_choices, label="Gender", value="Male")
            location = gr.Textbox(label="Location", value="Lagos, Nigeria")

            gr.Markdown("### 🌍 Travel History")
            travel_endemic = gr.Checkbox(
                label="Recent travel to malaria-endemic areas",
                value=True,
            )
            travel_details = gr.Textbox(
                label="Travel Details",
                value="Visited rural Northern Nigeria for 2 weeks",
            )

            gr.Markdown("### 🤒 Symptoms")
            symptoms = gr.CheckboxGroup(
                choices=symptom_choices,
                label="Symptoms",
                value=["Fever", "Chills", "Headache"],
            )

            gr.Markdown("### ❤️ Vital Signs")
            temperature = gr.Number(label="Temperature (°C)", value=38.5)
            blood_pressure = gr.Textbox(label="Blood Pressure", value="120/80")
            heart_rate = gr.Number(label="Heart Rate (bpm)", value=88)

            gr.Markdown("### 🏥 Medical History")
            previous_malaria = gr.Checkbox(label="Previous malaria episodes", value=True)
            medications = gr.Textbox(label="Medications/Allergies", value="None")

            gr.Markdown("### 📝 Additional Notes")
            additional_notes = gr.Textbox(
                label="Additional Information",
                value="Patient shows early signs of fatigue.",
            )

            agent = gr.Radio(choices=agent_choices, label="AI Analysis Type", value="Diagnostic")
            submit_btn = gr.Button("Run Analysis")

        # Right column: where the rendered analysis card appears.
        with gr.Column(scale=1):
            output = gr.HTML(label="AI Analysis Result")

    # Order must match the positional parameters of malaria_ai.
    form_inputs = [
        age, gender, location, travel_endemic, travel_details,
        symptoms, temperature, blood_pressure, heart_rate,
        previous_malaria, medications, additional_notes, agent,
    ]
    submit_btn.click(fn=malaria_ai, inputs=form_inputs, outputs=output)

demo.launch()