"""Streamlit page for the Epilepsy Management AI Agent.

Loads the device catalog from ``data/devices.csv``, makes sure every catalog
row is present in the ``devices`` vector collection returned by
``get_chroma``, and then lets the user describe a patient.  The description
is used to query the collection; the retrieved candidates are passed through
``apply_rules`` and the top-K survivors are rendered in retrieval order.
"""
import os
import sys
from typing import Dict, List

import pandas as pd
import streamlit as st

# Add parent directory to path for utils imports.
# ``abspath`` keeps ``dirname`` from collapsing to "" for a relative
# ``__file__``; the membership test stops the script (which is re-executed on
# every UI interaction) from growing ``sys.path`` with duplicates.
current_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file)
project_root = os.path.dirname(parent_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from utils.embeddings import get_chroma  # noqa: E402
from utils.rules import apply_rules  # noqa: E402

# Upper bound on candidates pulled from the vector index before filtering.
MAX_CANDIDATES = 10

# Catalog columns this page reads unconditionally.
REQUIRED_COLUMNS = (
    "device_id",
    "name",
    "category",
    "indication",
    "invasiveness",
    "approvals",
    "summary",
)

st.set_page_config(page_title="Epilepsy AI Agent", page_icon="🧠", layout="wide")
st.title("🧠 Epilepsy Management AI Agent")
st.caption("Personalized, evidence-informed recommendations for epilepsy management devices and techniques.")

# Load catalog - use project root for data paths
CATALOG_PATH = os.path.join(project_root, "data", "devices.csv")
try:
    df = pd.read_csv(CATALOG_PATH)
except FileNotFoundError:
    # Show a readable message instead of a raw traceback, then halt the run.
    st.error(f"Device catalog not found at {CATALOG_PATH}.")
    st.stop()

missing_columns = [c for c in REQUIRED_COLUMNS if c not in df.columns]
if missing_columns:
    st.error(
        "Device catalog is missing required column(s): "
        + ", ".join(missing_columns)
    )
    st.stop()

# Build/attach vector index on first run
chroma_dir = os.path.join(project_root, ".chroma")
client, col = get_chroma(collection_name="devices", persist_dir=chroma_dir)


def ensure_index(df: pd.DataFrame) -> None:
    """Add catalog rows that are not yet in the vector collection.

    Each row is keyed by ``str(row["device_id"])``.  Rows whose key is already
    stored in ``col`` are skipped, as are later rows that repeat a key seen
    earlier in ``df`` (so one ``col.add`` call never carries duplicate ids).

    Args:
        df: Device catalog; must provide the columns in ``REQUIRED_COLUMNS``.
    """
    seen = set(col.get(ids=None)["ids"]) if col.count() else set()

    pending = []
    for _, row in df.iterrows():
        rid = str(row["device_id"])
        if rid in seen:
            continue
        seen.add(rid)
        pending.append((rid, row))

    if not pending:
        return

    ids = [rid for rid, _ in pending]
    docs = [
        f"{r['name']} | {r['category']} | {r['indication']} | {r['invasiveness']} | {r['approvals']} | {r['summary']}"
        for _, r in pending
    ]
    metadatas = [r.to_dict() for _, r in pending]
    col.add(ids=ids, documents=docs, metadatas=metadatas)


def _retrieve_candidates(query: str) -> List[Dict]:
    """Return metadata dicts for the nearest catalog entries, best match first.

    The request size is clamped to the number of indexed items, and an empty
    collection yields an empty list without issuing a query.  Missing or
    ``None`` metadata entries in the response are dropped.
    """
    indexed = col.count()
    if not indexed:
        return []
    res = col.query(query_texts=[query], n_results=min(MAX_CANDIDATES, indexed))
    batches = (res or {}).get("metadatas") or [[]]
    return [md for md in (batches[0] or []) if md]


def _ranked_id_keys(items: List[Dict], limit: int) -> List[str]:
    """Return up to ``limit`` distinct ``device_id`` keys (as strings), in order."""
    keys: List[str] = []
    for item in items:
        device_id = item.get("device_id")
        if device_id is None:
            continue
        key = str(device_id)
        if key not in keys:
            keys.append(key)
    return keys[:limit]


ensure_index(df)

with st.sidebar:
    st.header("Filters & Safety Rules")
    avoid_invasive = st.toggle("Avoid invasive devices", value=True)
    req_apps = st.text_input("Require approvals (comma separated)", value="")
    top_k = st.slider("Top K", 1, 10, 3)
    st.markdown("---")
    st.subheader("Catalog Snapshot")
    st.write(df[["device_id", "name", "category", "invasiveness", "approvals"]])

st.subheader("Patient Input")
patient_text = st.text_area(
    "Describe the patient (seizure type, frequency, age group, prior treatments, lifestyle).",
    height=140,
    placeholder=(
        "Example: 16-year-old with focal aware seizures 2–3 times/week, "
        "mostly nocturnal; wants non-invasive detection and caregiver alerts."
    ),
)

if st.button("Generate Recommendations", type="primary"):
    q = patient_text.strip()
    if not q:
        st.warning("Please enter a short patient description first.")
    else:
        # Retrieve candidates (ordered by similarity to the description)
        candidates = _retrieve_candidates(q)

        # Apply safety/business rules
        filtered = apply_rules(candidates, require_approvals=req_apps, avoid_invasive=avoid_invasive)

        # Keep retrieval order and limit to the requested number of results.
        ranked_ids = _ranked_id_keys(filtered, top_k)

        # Compare ids as strings on both sides so an int/str dtype difference
        # between the catalog and the stored metadata cannot hide matches.
        id_keys = df["device_id"].astype(str)
        show = df[id_keys.isin(ranked_ids)].copy()

        if show.empty:
            st.info("No items matched your filters. Try clearing approvals, or switch off 'Avoid invasive' when looking for implants.")
        else:
            # Boolean selection returns catalog order; restore similarity rank.
            rank_of = {key: pos for pos, key in enumerate(ranked_ids)}
            show["_rank"] = show["device_id"].astype(str).map(rank_of)
            show = show.sort_values("_rank", kind="stable").drop(columns="_rank")

            st.success(f"Showing {len(show)} recommendation(s).")

            # Merge in a simple "why" snippet
            why_map = {
                str(md.get("device_id")): md.get("summary", "")
                for md in candidates
            }
            show["why"] = show["device_id"].astype(str).map(why_map).fillna("")

            # Display
            for _, r in show.iterrows():
                with st.container(border=True):
                    st.markdown(f"### {r['name']}")
                    c1, c2, c3, c4 = st.columns([2, 1, 1, 2])
                    c1.write(f"**Category:** {r['category']}")
                    c2.write(f"**Invasiveness:** {r['invasiveness']}")
                    c3.write(f"**Approvals:** {r['approvals']}")
                    c4.write(f"**Indication:** {r['indication']}")

                    why = r["why"]
                    if isinstance(why, str) and why:
                        st.write(why)

                    # Only render the link when the catalog has a usable URL
                    # (the cell may be empty/NaN or the column absent).
                    url = r.get("source_url")
                    if isinstance(url, str) and url.strip():
                        st.link_button("Open Source Page", url, use_container_width=False)

st.markdown("---")
st.info("This tool provides informational recommendations only and is not a substitute for professional medical advice. Always consult a licensed clinician.")