| import streamlit as st |
| import pandas as pd |
| import os |
| import sys |
| from typing import List, Dict |
|
|
| |
# Resolve the project root (the parent of this file's directory) as an
# absolute path. __file__ may be relative depending on how the script was
# launched, and os.path.dirname("app.py") == "" would silently collapse the
# root to whatever the current working directory happens to be.
current_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file)
project_root = os.path.dirname(parent_dir)

# Put the project root first on sys.path so the local `utils` package wins.
# Streamlit re-executes this script on every rerun, so remove any earlier
# entry first instead of letting sys.path grow by one duplicate per rerun.
while project_root in sys.path:
    sys.path.remove(project_root)
sys.path.insert(0, project_root)
|
|
| from utils.embeddings import get_chroma |
| from utils.rules import apply_rules |
|
|
# Page chrome. set_page_config has to be the first Streamlit command executed.
st.set_page_config(
    page_title="Epilepsy AI Agent",
    page_icon="🧠",
    layout="wide",
)

st.title("🧠 Epilepsy Management AI Agent")
st.caption(
    "Personalized, evidence-informed recommendations for epilepsy management "
    "devices and techniques."
)
|
|
| |
CATALOG_PATH = os.path.join(project_root, "data", "devices.csv")


@st.cache_data
def _load_catalog(path: str, mtime: float) -> pd.DataFrame:
    """Parse the device catalog CSV at *path*.

    Streamlit re-executes the whole script on every widget interaction, so
    the parse is cached. ``mtime`` is not used in the body; it is part of the
    cache key, so an edited CSV is re-read on the next rerun instead of a
    stale copy being served. ``st.cache_data`` hands back a copy per call,
    so callers may mutate the result freely.
    """
    return pd.read_csv(path)


# os.path.getmtime raises FileNotFoundError for a missing catalog, the same
# exception pd.read_csv raised before.
df = _load_catalog(CATALOG_PATH, os.path.getmtime(CATALOG_PATH))


chroma_dir = os.path.join(project_root, ".chroma")


@st.cache_resource
def _open_chroma(collection_name: str, persist_dir: str):
    """Open the Chroma client/collection pair once per process.

    Returns whatever ``get_chroma`` returns (unpacked below as
    ``client, col``). Cached as a shared resource so reruns reuse the same
    client instead of constructing a new one each time.
    """
    return get_chroma(collection_name=collection_name, persist_dir=persist_dir)


client, col = _open_chroma("devices", chroma_dir)
|
|
# Catalog columns concatenated (in this order) into the text that gets embedded.
_DOC_FIELDS = ("name", "category", "indication", "invasiveness", "approvals", "summary")


def _metadata_value(value):
    """Return *value* as a plain Python scalar suitable for Chroma metadata.

    Missing cells (NaN/None, as produced by ``pd.read_csv`` for empty fields)
    become ``""``, so they are neither embedded as the literal text "nan" nor
    stored as float NaN. NumPy scalars are unboxed to native Python values via
    ``.item()``.
    """
    if pd.isna(value):
        return ""
    if hasattr(value, "item"):
        return value.item()
    return value


def ensure_index(df: pd.DataFrame, collection=None) -> None:
    """Add catalog rows whose ``device_id`` is not yet in the vector collection.

    Rows are keyed by ``str(device_id)``. Rows already present are left
    untouched, so calling this repeatedly is cheap and adds nothing new.

    Args:
        df: Catalog with at least ``device_id`` and the ``_DOC_FIELDS`` columns.
            Every column of a row is stored as that entry's metadata.
        collection: Object exposing ``count()``, ``get(include=...)`` and
            ``add(ids=, documents=, metadatas=)``. Defaults to the module-level
            ``col``.
    """
    collection = col if collection is None else collection
    # Only the ids are needed here; skip fetching documents and metadata.
    existing = set(collection.get(include=[])["ids"]) if collection.count() else set()

    ids, docs, metadatas = [], [], []
    for _, row in df.iterrows():
        rid = str(row["device_id"])
        if rid in existing:
            continue
        # Record the id immediately: a duplicate device_id later in the CSV
        # would otherwise put the same id into one add() call.
        existing.add(rid)
        meta = {key: _metadata_value(val) for key, val in row.items()}
        ids.append(rid)
        docs.append(" | ".join(str(meta[field]) for field in _DOC_FIELDS))
        metadatas.append(meta)

    if ids:
        collection.add(ids=ids, documents=docs, metadatas=metadatas)
|
|
# Sync the vector index with the catalog: embed any rows it does not hold yet.
ensure_index(df=df)
|
|
# Sidebar: safety filters that feed apply_rules, plus a read-only catalog view.
_sidebar = st.sidebar
_snapshot_columns = ["device_id", "name", "category", "invasiveness", "approvals"]

_sidebar.header("Filters & Safety Rules")
avoid_invasive = _sidebar.toggle("Avoid invasive devices", value=True)
req_apps = _sidebar.text_input("Require approvals (comma separated)", value="")
top_k = _sidebar.slider("Top K", 1, 10, 3)
_sidebar.markdown("---")
_sidebar.subheader("Catalog Snapshot")
_sidebar.write(df[_snapshot_columns])
|
|
# Free-text patient description used as the retrieval query.
_PATIENT_PROMPT = "Describe the patient (seizure type, frequency, age group, prior treatments, lifestyle)."
_PATIENT_EXAMPLE = (
    "Example: 16-year-old with focal aware seizures 2–3 times/week, mostly nocturnal; "
    "wants non-invasive detection and caregiver alerts."
)

st.subheader("Patient Input")
patient_text = st.text_area(
    _PATIENT_PROMPT,
    placeholder=_PATIENT_EXAMPLE,
    height=140,
)
|
|
if st.button("Generate Recommendations", type="primary"):
    query = patient_text.strip()
    if not query:
        st.warning("Please enter a short patient description first.")
    else:
        # Ask for up to 10 neighbours, but never more than the index holds:
        # some chromadb versions reject n_results larger than the collection,
        # and an empty collection cannot be queried at all.
        n_results = min(10, col.count())
        candidates: List[Dict] = []
        if n_results > 0:
            res = col.query(query_texts=[query], n_results=n_results)
            candidates = (res.get("metadatas") or [[]])[0]

        filtered = apply_rules(candidates, require_approvals=req_apps, avoid_invasive=avoid_invasive)

        # Catalog keyed by str(device_id), matching how the index ids are
        # built, so ids compare equal whatever the CSV column's dtype is.
        catalog = (
            df.assign(_rid=df["device_id"].astype(str))
            .drop_duplicates("_rid")
            .set_index("_rid")
        )
        # Keep Chroma's relevance order (a DataFrame isin() filter would
        # reorder by catalog position), drop repeats and ids of devices no
        # longer in the CSV (stale index entries), then cut to top_k.
        ranked_ids = [
            rid
            for rid in dict.fromkeys(str(c.get("device_id")) for c in filtered)
            if rid in catalog.index
        ][:top_k]

        if not ranked_ids:
            st.info("No items matched your filters. Try clearing approvals, or switch off 'Avoid invasive' when looking for implants.")
        else:
            st.success(f"Showing {len(ranked_ids)} recommendation(s).")

            # Per-device explanation text = the summary stored in the index metadata.
            why_map = {str(md.get("device_id")): md.get("summary", "") for md in candidates}

            for rid in ranked_ids:
                r = catalog.loc[rid]
                why = why_map.get(rid, "")
                if pd.isna(why):
                    why = ""
                with st.container(border=True):
                    st.markdown(f"### {r['name']}")
                    c1, c2, c3, c4 = st.columns([2, 1, 1, 2])
                    c1.write(f"**Category:** {r['category']}")
                    c2.write(f"**Invasiveness:** {r['invasiveness']}")
                    c3.write(f"**Approvals:** {r['approvals']}")
                    c4.write(f"**Indication:** {r['indication']}")
                    st.write(why)
                    # Empty/missing URLs come back from pandas as NaN; skip the
                    # button rather than pass a non-URL to link_button.
                    url = r.get("source_url")
                    if isinstance(url, str) and url.strip():
                        st.link_button("Open Source Page", url, use_container_width=False)
|
|
# Medical disclaimer, always rendered at the bottom of the page.
_DISCLAIMER = (
    "This tool provides informational recommendations only and is not a substitute "
    "for professional medical advice. Always consult a licensed clinician."
)
st.markdown("---")
st.info(_DISCLAIMER)
|
|