File size: 4,290 Bytes
56020bd
 
 
8d48081
56020bd
8d48081
 
 
 
 
 
 
56020bd
 
 
 
 
 
 
 
8d48081
 
56020bd
 
 
8d48081
 
56020bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st
import pandas as pd
import os
import sys
from typing import List, Dict

# Make the project root (the directory above this script's folder) importable
# so the `utils` package resolves regardless of the launch directory.
# Resolve to an absolute path first: with a relative `__file__` such as
# "app.py" the two dirname() calls would collapse to "" and silently point at
# the current working directory instead.
current_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file)
project_root = os.path.dirname(parent_dir)
# The script body is executed again on each Streamlit rerun; drop any earlier
# entry before inserting so sys.path keeps a single copy, still at the front.
if project_root in sys.path:
    sys.path.remove(project_root)
sys.path.insert(0, project_root)

from utils.embeddings import get_chroma
from utils.rules import apply_rules

# Browser-tab settings first, then the on-page heading and tagline.
page_settings = {
    "page_title": "Epilepsy AI Agent",
    "page_icon": "🧠",
    "layout": "wide",
}
st.set_page_config(**page_settings)

st.title("🧠 Epilepsy Management AI Agent")
st.caption(
    "Personalized, evidence-informed recommendations for epilepsy management devices and techniques."
)

# Device catalog: a CSV read from <project_root>/data/devices.csv.
CATALOG_PATH = os.path.join(project_root, "data", "devices.csv")
df = pd.read_csv(CATALOG_PATH)

# Vector-store handles for the "devices" collection, persisted in a hidden
# folder at the project root (presumably created on first use by get_chroma —
# confirm in utils.embeddings).
chroma_dir = os.path.join(project_root, ".chroma")
client, col = get_chroma(persist_dir=chroma_dir, collection_name="devices")

def ensure_index(df: pd.DataFrame) -> None:
    """Add catalog rows that are missing from the vector collection.

    Reads the IDs already stored in the module-level ``col`` collection and
    adds every row of ``df`` whose ``device_id`` (compared as a string) is not
    stored yet. Rows that are already indexed are left untouched, so edits to
    an existing catalog row are not propagated to the collection.

    Args:
        df: Device catalog. Must provide the columns ``device_id``, ``name``,
            ``category``, ``indication``, ``invasiveness``, ``approvals`` and
            ``summary``.
    """
    # Start from the IDs the collection already holds; an empty collection
    # skips the fetch entirely.
    seen = set(col.get(ids=None)["ids"]) if col.count() else set()

    ids = []
    docs = []
    metadatas = []
    for _, row in df.iterrows():
        rid = str(row["device_id"])
        if rid in seen:
            continue
        # Recording the ID here also skips a device_id that is repeated inside
        # the catalog itself, so one add() call never receives the same ID twice.
        seen.add(rid)
        ids.append(rid)
        docs.append(
            f"{row['name']} | {row['category']} | {row['indication']} | {row['invasiveness']} | {row['approvals']} | {row['summary']}"
        )
        # The full row is stored as metadata so query results carry every column.
        metadatas.append(row.to_dict())

    if ids:
        col.add(ids=ids, documents=docs, metadatas=metadatas)


ensure_index(df)

# Sidebar: rule toggles that feed apply_rules(), the result-count slider, and
# a compact view of the catalog for reference.
sidebar = st.sidebar
sidebar.header("Filters & Safety Rules")
avoid_invasive = sidebar.toggle("Avoid invasive devices", value=True)
req_apps = sidebar.text_input("Require approvals (comma separated)", value="")
top_k = sidebar.slider("Top K", min_value=1, max_value=10, value=3)
sidebar.markdown("---")
sidebar.subheader("Catalog Snapshot")
snapshot_columns = ["device_id", "name", "category", "invasiveness", "approvals"]
sidebar.write(df[snapshot_columns])

st.subheader("Patient Input")
# Free-text patient description; it is later used as the retrieval query.
patient_text = st.text_area(
    label="Describe the patient (seizure type, frequency, age group, prior treatments, lifestyle).",
    height=140,
    placeholder="Example: 16-year-old with focal aware seizures 2–3 times/week, mostly nocturnal; wants non-invasive detection and caregiver alerts.",
)

# Recommendation flow: retrieve similar devices for the patient description,
# apply the sidebar rules, then render the top_k survivors in ranked order.
if st.button("Generate Recommendations", type="primary"):
    q = patient_text.strip()
    if not q:
        st.warning("Please enter a short patient description first.")
    else:
        # Retrieve candidates for the single query text; tolerate a missing or
        # empty "metadatas" entry instead of failing on the [0] lookup.
        res = col.query(query_texts=[q], n_results=10)
        candidates: List[Dict] = (res.get("metadatas") or [[]])[0] or []

        # Apply safety/business rules
        filtered = apply_rules(candidates, require_approvals=req_apps, avoid_invasive=avoid_invasive)

        # Key the catalog by device_id as text so lookups succeed whether the
        # CSV and the stored metadata hold the ID as an int or as a string.
        catalog = df.assign(_device_key=df["device_id"].astype(str))
        catalog = catalog.drop_duplicates(subset="_device_key").set_index("_device_key")

        # Keep the filtered candidates in the order apply_rules returned them
        # (assumed to be the retrieval order — confirm in utils.rules), skip
        # repeats and IDs that are absent from the catalog, then cut to top_k.
        ranked_ids: List[str] = []
        for cand in filtered:
            key = str(cand.get("device_id"))
            if key in catalog.index and key not in ranked_ids:
                ranked_ids.append(key)
        ranked_ids = ranked_ids[:top_k]

        if not ranked_ids:
            st.info("No items matched your filters. Try clearing approvals, or switch off 'Avoid invasive' when looking for implants.")
        else:
            st.success(f"Showing {len(ranked_ids)} recommendation(s)." )
            # Selecting with .loc and a list keeps the rows in ranked order; a
            # boolean isin() mask would fall back to catalog order.
            show = catalog.loc[ranked_ids].copy()
            # Merge in a simple "why" snippet taken from the retrieved metadata.
            why_map = {str(md.get("device_id")): md.get("summary", "") for md in candidates}
            show["why"] = [why_map.get(key, "") for key in show.index]
            # Display one bordered card per recommendation.
            for _, r in show.iterrows():
                with st.container(border=True):
                    st.markdown(f"### {r['name']}")
                    c1, c2, c3, c4 = st.columns([2, 1, 1, 2])
                    c1.write(f"**Category:** {r['category']}")
                    c2.write(f"**Invasiveness:** {r['invasiveness']}")
                    c3.write(f"**Approvals:** {r['approvals']}")
                    c4.write(f"**Indication:** {r['indication']}")
                    st.write(r["why"])
                    # An empty CSV cell is read as NaN, which is not a usable
                    # link target; only render the button for a real URL string.
                    url = r.get("source_url")
                    if isinstance(url, str) and url.strip():
                        st.link_button("Open Source Page", url, use_container_width=False)

# Closing rule and medical disclaimer, rendered on every run.
disclaimer = "This tool provides informational recommendations only and is not a substitute for professional medical advice. Always consult a licensed clinician."
st.markdown("---")
st.info(disclaimer)