Suraj Prasai
committed on
Commit
·
458c8e2
0
Parent(s):
added initial
Browse files- .gitattributes +35 -0
- .gitignore +1 -0
- __init__.py +0 -0
- __pycache__/app.cpython-311.pyc +0 -0
- __pycache__/test_app.cpython-311.pyc +0 -0
- agent/ChatbotHandler.py +189 -0
- agent/__init__.py +0 -0
- agent/__pycache__/ChatbotHandler.cpython-311.pyc +0 -0
- agent/__pycache__/__init__.cpython-311.pyc +0 -0
- agent/__pycache__/agent_graph.cpython-311.pyc +0 -0
- agent/__pycache__/runtime_ctx.cpython-311.pyc +0 -0
- agent/__pycache__/tools.cpython-311.pyc +0 -0
- agent/agent_graph.py +215 -0
- agent/runtime_ctx.py +175 -0
- agent/simple_chat.py +103 -0
- agent/tools.py +354 -0
- app.py +532 -0
- app_studio.py +543 -0
- app_test.py +867 -0
- backend_client.py +42 -0
- data/Aubrie.csv +0 -0
- data/JS00001_filtered.csv +0 -0
- data/Lisette.csv +0 -0
- ecg_analyzer.py +198 -0
- ecg_visualization.py +67 -0
- examples/main.py +27 -0
- pipeline/__init__.py +0 -0
- pipeline/__pycache__/__init__.cpython-311.pyc +0 -0
- pipeline/__pycache__/deduplication.cpython-311.pyc +0 -0
- pipeline/__pycache__/featurizer.cpython-311.pyc +0 -0
- pipeline/__pycache__/issues.cpython-311.pyc +0 -0
- pipeline/__pycache__/pipeline.cpython-311.pyc +0 -0
- pipeline/__pycache__/utils.cpython-311.pyc +0 -0
- pipeline/__pycache__/utils_cool.cpython-311.pyc +0 -0
- pipeline/deduplication.py +185 -0
- pipeline/featurizer.py +358 -0
- pipeline/issues.py +157 -0
- pipeline/pipeline.py +63 -0
- pipeline/utils_cool.py +208 -0
- requirements.txt +126 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.env
|
__init__.py
ADDED
|
File without changes
|
__pycache__/app.cpython-311.pyc
ADDED
|
Binary file (23.3 kB). View file
|
|
|
__pycache__/test_app.cpython-311.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
agent/ChatbotHandler.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from langchain_core.messages import HumanMessage
|
| 6 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 7 |
+
|
| 8 |
+
from agent.agent_graph import build_app
|
| 9 |
+
from pipeline.utils_cool import df_to_payload, parse_user_choice
|
| 10 |
+
|
| 11 |
+
from .runtime_ctx import get_df_summary
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
class ChatbotHandler:
    """Bridge between the Gradio chat UI and the LangGraph data-quality agent.

    Keeps one mutable context dict (``self.ctx``) holding the compiled graph
    and the persistent agent state threaded through every chat turn.
    """

    def __init__(self):
        # Persistent container: compiled graph + state mirrored across turns.
        self.ctx: Dict[str, Any] = {
            "graph_app": None,  # LangGraph app
            "state": {  # mirrors your prior STATE
                "df_payload": None,
                "results": [],
                "steps_taken": 0,
                "confirmed_step": None,
                "confirmed_params": {},
                "last_task": None,
                "plan": None,
                "messages": [],
                "max_steps": 8,
            },
        }
        # Chat history lives in the Gradio component; we only stash the first
        # reply (produced at upload time) until the user sends a message.
        self._boot_text: Optional[str] = None  # first reply after upload

        # LLM used by the agent graph.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-lite",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

    def _format_summary(self, s: Dict[str, Any]) -> str:
        """Render a stored dataset-summary dict as concise Markdown bullets."""
        columns = s.get("columns") or []
        dtype_map = s.get("dtypes") or {}
        shape = s.get("shape") or (None, None)
        label = s.get("label_guess") or "None"
        task = s.get("task_guess") or "Unknown"
        problems = s.get("issues") or []

        # Show at most 8 dtype pairs, with an ellipsis marker when truncated.
        dtype_bits = [f"{k}: {v}" for k, v in list(dtype_map.items())[:8]]
        if len(dtype_map) > 8:
            dtype_bits.append("…")

        out = [
            "### Dataset summary",
            f"- Shape: {shape[0]} rows × {shape[1]} columns",
            f"- Columns: {', '.join(map(str, columns[:10]))}{'…' if len(columns) > 10 else ''}",
            f"- Dtypes: {', '.join(dtype_bits)}",
            f"- Label guess: {label}",
            f"- Task guess: {task}",
        ]
        if problems:
            out.append(f"- Potential issues: {('; '.join(problems[:3]))}{'…' if len(problems) > 3 else ''}")
        return "\n".join(out)

    # ---------------- Boot on upload: run inspect + SOTA + plan ----------------
    def update_context(self, file_path: Optional[str], data_type: Optional[str], df: Optional["pd.DataFrame"]):
        """(Re)initialize the graph for a freshly uploaded dataframe, run the
        opening graph turn, and return the boot message shown to the user."""
        if df is None:
            return ""

        # Rebuild the graph and reset per-dataset state.
        self.ctx["graph_app"] = build_app(self.llm)
        state = self.ctx["state"]
        state.update({
            "df_payload": df_to_payload(df),
            "results": [],
            "steps_taken": 0,
            "confirmed_step": None,
            "confirmed_params": {},
            "last_task": None,
            "plan": None,
            "messages": [HumanMessage(content="A new dataset was uploaded. Start the workflow.")],
            "max_steps": 8,
        })

        final = self.ctx["graph_app"].invoke(state)
        for key in ["df_payload", "results", "steps_taken", "confirmed_step",
                    "confirmed_params", "last_task", "plan", "messages"]:
            state[key] = final.get(key, state.get(key))

        # The stored summary is authoritative; render it for the boot message.
        summary = get_df_summary() or {}
        summary_text = self._format_summary(summary)

        # Ask for confirmation only when the guesses are incomplete.
        task = (summary.get("task_guess") or "").lower()
        label = summary.get("label_guess")
        needs_task = task not in {"classification", "regression", "unsupervised"}
        needs_label = (task in {"classification", "regression"}) and (not label)

        if needs_task or needs_label:
            ask = "\n\nPlease confirm the task" + (" and label column" if needs_label else "") + \
                  ". For example: `task=classification label=noisy_letter_grade`."
        else:
            ask = f"\n\nIf that looks right, say `confirm task={task}" + \
                  (f" label={label}`" if label else "`") + \
                  " and I’ll fetch SOTA and propose a plan."

        self._boot_text = summary_text + ask
        return self._boot_text

    # ---------------- One chat turn → graph turn ----------------
    def respond(self, message: str, history: List):
        """Run one user message through the graph; returns (history, "")."""
        if history is None:
            history = []
        text = (message or "").strip()
        if not text:
            return history, ""

        # Surface the prepared boot reply (from upload) before the first turn.
        if self._boot_text and len(history) == 0:
            history.append(("[system]", self._boot_text))
            self._boot_text = None

        # A compiled graph is required before anything else can happen.
        if self.ctx.get("graph_app") is None:
            history.append((text, "Please upload a dataset first."))
            return history, ""

        state = self.ctx["state"]

        # Fast path: `run X a=b`-style commands confirm a step pre-graph.
        step, params = parse_user_choice(text)
        if step:
            state["confirmed_step"] = step
            state["confirmed_params"] = {**(state.get("confirmed_params") or {}), **params}

        # Assemble this turn's input state for the graph.
        turn_state = {
            "messages": (state.get("messages") or []) + [HumanMessage(content=text)],
            "df_payload": state.get("df_payload"),
            "results": state.get("results", []),
            "steps_taken": state.get("steps_taken", 0),
            "max_steps": max(8, state.get("steps_taken", 0) + 4),
            "confirmed_step": state.get("confirmed_step"),
            "confirmed_params": state.get("confirmed_params", {}),
            "last_task": state.get("last_task"),
            "plan": state.get("plan"),
        }

        final = self.ctx["graph_app"].invoke(turn_state)

        # Persist the graph's view of the state for the next turn.
        for key in ["df_payload", "results", "steps_taken", "confirmed_step",
                    "confirmed_params", "last_task", "plan", "messages"]:
            state[key] = final.get(key, turn_state.get(key, state.get(key)))

        reply = self._extract_ai_text(final.get("messages", [])) or "Done."
        history.append((text, reply))
        return history, ""

    # ---------------- helper: extract last AI string ----------------
    def _extract_ai_text(self, messages: List[Any]) -> str:
        """Walk messages newest-first and return the last assistant text."""

        def _as_text(content: Any) -> str:
            # Coerce string / list-of-parts / anything else into plain text.
            if content is None:
                return ""
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                pieces = []
                for item in content:
                    if isinstance(item, dict):
                        pieces.append(str(item.get("text") or item.get("content") or item.get("data") or ""))
                    else:
                        pieces.append(str(item))
                return " ".join(p for p in pieces if p)
            return str(content)

        for m in reversed(messages or []):
            role = getattr(m, "type", None) or getattr(m, "role", None)
            if role in ("ai", "assistant", "aimessage"):
                return _as_text(getattr(m, "content", None))
            if isinstance(m, dict):
                r = (m.get("role") or m.get("type") or "").lower()
                if r in ("assistant", "ai", "aimessage"):
                    return _as_text(m.get("content"))
        return ""
|
agent/__init__.py
ADDED
|
File without changes
|
agent/__pycache__/ChatbotHandler.cpython-311.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
agent/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (162 Bytes). View file
|
|
|
agent/__pycache__/agent_graph.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
agent/__pycache__/runtime_ctx.cpython-311.pyc
ADDED
|
Binary file (9.63 kB). View file
|
|
|
agent/__pycache__/tools.cpython-311.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
agent/agent_graph.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# agent/agent_graph.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
from typing import Any, Dict, List, Optional, TypedDict
|
| 6 |
+
|
| 7 |
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
| 8 |
+
from langgraph.graph import END, START, StateGraph
|
| 9 |
+
from langgraph.prebuilt import ToolNode
|
| 10 |
+
|
| 11 |
+
from .runtime_ctx import set_df_payload, set_df_summary, set_sota_bundled
|
| 12 |
+
from .tools import (
|
| 13 |
+
tool_describe_step,
|
| 14 |
+
tool_inspect_dataset,
|
| 15 |
+
tool_list_steps,
|
| 16 |
+
tool_list_versions,
|
| 17 |
+
tool_propose_plan,
|
| 18 |
+
tool_reset_to_version,
|
| 19 |
+
tool_run_step,
|
| 20 |
+
tool_sota_preprocessing,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _to_text(content: Any, limit: int = 4000) -> str:
    """Coerce any message content to a string that Gemini will accept.

    None becomes "", strings pass through unchanged, and anything else is
    JSON-encoded (falling back to str()) then truncated to *limit* chars.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    try:
        text = json.dumps(content, default=str, ensure_ascii=False)
    except Exception:
        text = str(content)
    # Lightly truncate huge tool dumps.
    if len(text) > limit:
        return text[:limit] + " …"
    return text
|
| 36 |
+
|
| 37 |
+
def _sanitize_messages(msgs: list[Any]) -> list[Any]:
    """Keep only system/human/assistant messages and ensure content is str.

    Raw ToolMessages (which Gemini rejects) are folded into short assistant
    lines; unknown message objects are mapped by their role string or dropped.
    """
    sanitized = []
    for msg in msgs or []:
        role = getattr(msg, "type", None) or getattr(msg, "role", None) or ""

        # Compress tool outputs into an assistant line instead of passing
        # the raw tool role through.
        if isinstance(msg, ToolMessage) or role == "tool":
            txt = _to_text(getattr(msg, "content", None))
            if txt:
                sanitized.append(AIMessage(content=f"[Tool result] {txt}"))
            continue

        text = _to_text(getattr(msg, "content", None))
        if isinstance(msg, SystemMessage):
            sanitized.append(SystemMessage(content=text))
        elif isinstance(msg, HumanMessage):
            sanitized.append(HumanMessage(content=text))
        elif isinstance(msg, AIMessage):
            sanitized.append(AIMessage(content=text))
        else:
            # Unknown BaseMessage: best-effort mapping by its role string.
            r = str(role).lower()
            if r == "system":
                sanitized.append(SystemMessage(content=text))
            elif r in ("human", "user"):
                sanitized.append(HumanMessage(content=text))
            elif r in ("assistant", "ai", "aimessage"):
                sanitized.append(AIMessage(content=text))
            # anything else is ignored silently
    return sanitized
|
| 68 |
+
|
| 69 |
+
TOOLS = [tool_inspect_dataset, tool_sota_preprocessing, tool_list_steps, tool_describe_step, tool_propose_plan, tool_run_step, tool_list_versions, tool_reset_to_version]
|
| 70 |
+
|
| 71 |
+
SYSTEM_PRIMER = (
|
| 72 |
+
"You are a data-quality assistant.\n"
|
| 73 |
+
"\n"
|
| 74 |
+
"Workflow:\n"
|
| 75 |
+
"1) Call inspect_dataset() to summarize columns/dtypes and GUESS task/label.\n"
|
| 76 |
+
" • If you are NOT SURE about the task (or the label for supervised tasks), ASK the user to confirm and END THE TURN.\n"
|
| 77 |
+
" • Do NOT call sota_preprocessing until the user explicitly confirms the task (and label if supervised).\n"
|
| 78 |
+
" Acceptable confirmations include messages like: "
|
| 79 |
+
" 'task=classification label=HARDSHIP_INDEX', 'Task: regression', or 'Unsupervised'.\n"
|
| 80 |
+
"2) After the user confirms, call sota_preprocessing(task, modality, ...) and PRESENT a brief 'SOTA Evidence' section (3–6 bullets with titles and links from the tool).\n"
|
| 81 |
+
"3) Call list_steps() and map SOTA insights to the available tools. Produce a plan (no execution yet); cite up to 2 SOTA sources per step.\n"
|
| 82 |
+
"4) Ask: 'Which step should we execute first?' Do NOT call run_step until the user explicitly picks.\n"
|
| 83 |
+
"5) After the user picks, call describe_step(name) and list ONLY real parameters from the tool. Ask for missing/optional params and confirm them.\n"
|
| 84 |
+
"6) Execute with run_step(name, params_json). Version controls inside params_json when relevant:\n"
|
| 85 |
+
" • source: 'current' | 'prev' | 'base' | '@-1' | '@-2' | <int>\n"
|
| 86 |
+
" • dry_run: true|false (preview without mutating)\n"
|
| 87 |
+
" • new_version: true|false (create new snapshot vs replace current)\n"
|
| 88 |
+
" Avoid loops: if the same step+params just ran, ask to change parameters or source.\n"
|
| 89 |
+
"7) Summarize results; optionally call list_versions() and offer reset_to_version(spec). If helpful, research again before proposing next steps.\n"
|
| 90 |
+
"\n"
|
| 91 |
+
"Rules:\n"
|
| 92 |
+
"- Return exactly one tool call at a time.\n"
|
| 93 |
+
"- Never call sota_preprocessing before explicit task confirmation.\n"
|
| 94 |
+
"- Never call run_step without an explicit user choice.\n"
|
| 95 |
+
"- When users ask about parameters, use describe_step (or list_steps) and answer ONLY from tool output.\n"
|
| 96 |
+
"- Reject parameters that are not in the tool signature.\n"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class AgentState(TypedDict):
    """Mutable state threaded through the LangGraph nodes."""
    messages: List[Any]                   # full conversation history
    df_payload: Optional[Dict[str, Any]]  # current dataset snapshot
    results: List[Dict[str, Any]]         # per-step stats accumulated so far
    steps_taken: int                      # number of executed steps
    max_steps: int                        # budget before forcing an end
    confirmed_step: Optional[str]         # step explicitly chosen by the user
    confirmed_params: Dict[str, Any]      # params confirmed for that step
    last_task: Optional[str]              # task guess from the last inspect
    plan: Optional[Dict[str, Any]]        # latest proposed plan payload
|
| 110 |
+
|
| 111 |
+
def make_agent_node(llm):
    """Build the 'agent' node: the LLM emits tool calls; history is sanitized
    first and an AIMessage is ALWAYS appended to the state."""
    llm_with_tools = llm.bind_tools(TOOLS)

    def _node(state: AgentState) -> AgentState:
        # Remind the model of the current dataset shape on every turn.
        data = (state.get("df_payload") or {}).get("data", {})
        n_rows = len(data.get("data", []) or [])
        n_cols = len(data.get("columns", []) or [])
        shape_note = SystemMessage(content=f"Current dataset shape: {n_rows} rows × {n_cols} columns.")

        history = _sanitize_messages(state.get("messages", []))
        ai = llm_with_tools.invoke([SystemMessage(content=SYSTEM_PRIMER), *history, shape_note])

        # Guard: make sure what we append really is an AIMessage object.
        if not isinstance(ai, AIMessage):
            ai = AIMessage(content=_to_text(getattr(ai, "content", ai)))

        state["messages"] = state["messages"] + [ai]
        return state

    return _node
|
| 135 |
+
|
| 136 |
+
def tools_exec_node():
    """
    Build the 'tools' node: executes requested tools only here, after
    injecting df_payload into the runtime context, then folds tool outputs
    (summary/SOTA/plan/step_result) back into the state.
    """
    tool_node = ToolNode(TOOLS)

    def _node(state: AgentState) -> AgentState:
        # Tools read the dataset from runtime context, so inject it first.
        set_df_payload(state.get("df_payload"))

        # No dataset at all: reply gently and stop this turn.
        if state.get("df_payload") is None:
            # NOTE(review): the reply reuses the class of the last message —
            # presumably to stay provider-compatible; confirm intent.
            state["messages"].append(type(state["messages"][-1])(content="I don't have a dataset yet. Please upload one."))
            return state

        # Hard gate: refuse run_step for a step the user has not confirmed.
        last = state["messages"][-1]
        for call in (getattr(last, "tool_calls", None) or []):
            if call.get("name") == "run_step":
                wanted = (call.get("args") or {}).get("name")
                if wanted and wanted != state.get("confirmed_step"):
                    state["messages"].append(type(last)(content="I have a plan ready. Which step should we run first?"))
                    return state

        # Execute the tool(s) requested by the last assistant message.
        out = tool_node.invoke({"messages": state["messages"]})

        # Append ONLY the new ToolMessages; never overwrite the conversation.
        new_msgs = [m for m in out["messages"] if isinstance(m, ToolMessage)]
        if not new_msgs:
            # Fallback: some providers return the whole list — take the tail.
            if len(out["messages"]) > len(state["messages"]):
                new_msgs = out["messages"][len(state["messages"]):]
            else:
                new_msgs = out["messages"]

        state["messages"] = state["messages"] + new_msgs

        # Fold the most recent tool payload (dict in .content) into state.
        payload = new_msgs[-1].content if new_msgs else None
        if isinstance(payload, dict):
            kind = payload.get("type")
            if kind == "dataset_summary":
                set_df_summary(payload)
                state["last_task"] = payload.get("task_guess")
            elif kind == "sota":
                set_sota_bundled(payload.get("bundled_results") or [])
            elif kind == "plan":
                state["plan"] = payload
            elif kind == "step_result":
                state["df_payload"] = payload["df"]
                set_df_payload(state["df_payload"])
                state["results"].append({"name": payload["name"], "stats": payload["stats"]})
                state["steps_taken"] += 1
                # Consume the confirmation: the next step needs a fresh pick.
                state["confirmed_step"] = None
                state["confirmed_params"] = {}

        return state

    return _node
|
| 198 |
+
|
| 199 |
+
def should_continue(state: AgentState) -> str:
    """Route after the agent node: 'continue' → tools node, 'end' → stop."""
    # Stop once the step budget is exhausted.
    if state.get("steps_taken", 0) >= state.get("max_steps", 8):
        return "end"
    # Otherwise keep going only when the assistant requested tool calls.
    newest = state["messages"][-1]
    return "continue" if getattr(newest, "tool_calls", None) else "end"
|
| 205 |
+
|
| 206 |
+
def build_app(llm):
    """Compile the two-node agent ↔ tools LangGraph loop for *llm*."""
    graph = StateGraph(AgentState)
    graph.add_node("agent", make_agent_node(llm))
    graph.add_node("tools", tools_exec_node())

    # agent decides; tools execute; control always returns to the agent.
    graph.add_edge(START, "agent")
    graph.add_conditional_edges("agent", should_continue, {"continue": "tools", "end": END})
    graph.add_edge("tools", "agent")

    return graph.compile()
|
agent/runtime_ctx.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# agent/runtime_ctx.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from contextvars import ContextVar
|
| 5 |
+
from typing import Any, Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
# ---------------- Versioned dataset store ----------------
|
| 8 |
+
# Each new mutating step can create a new "version".
|
| 9 |
+
# You can address versions as: "current", "base", "prev", "@-1", "@-2", "@3" (0-based).
|
| 10 |
+
_VERSIONS_CV: ContextVar[Optional[List[Dict[str, Any]]]] = ContextVar("VERSIONS", default=None)
|
| 11 |
+
_CUR_INDEX_CV: ContextVar[int] = ContextVar("CUR_INDEX", default=-1)
|
| 12 |
+
_VERS_META_CV: ContextVar[Optional[List[Dict[str, Any]]]] = ContextVar("VERS_META", default=None)
|
| 13 |
+
|
| 14 |
+
# Legacy singletons (fallback across tasks)
|
| 15 |
+
_STORE: Dict[str, Any] = {
|
| 16 |
+
"versions": [], # list of df_payloads
|
| 17 |
+
"version_meta": [], # parallel list of metadata dicts
|
| 18 |
+
"cur_index": -1,
|
| 19 |
+
# kept for backward compat with old getters:
|
| 20 |
+
"df_payload": None, # alias of current
|
| 21 |
+
"base_df_payload": None, # alias of versions[0]
|
| 22 |
+
"sota_bundled": None,
|
| 23 |
+
"df_summary": None,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
_SOTA_BUNDLED_CV: ContextVar[Optional[list]] = ContextVar("SOTA_BUNDLED", default=None)
|
| 27 |
+
_DF_SUMMARY_CV: ContextVar[Optional[Dict[str, Any]]] = ContextVar("DF_SUMMARY", default=None)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# -------- internal helpers --------
|
| 31 |
+
def _get_versions() -> List[Dict[str, Any]]:
    """Version list: context-local value when set, else the legacy store."""
    return _VERSIONS_CV.get() or _STORE["versions"]

def _get_meta() -> List[Dict[str, Any]]:
    """Metadata list parallel to the version list (same fallback rule)."""
    return _VERS_META_CV.get() or _STORE["version_meta"]

def _set_versions(vers: List[Dict[str, Any]], meta: List[Dict[str, Any]], cur: int) -> None:
    """Write versions/meta/current-index to BOTH stores and refresh aliases."""
    _VERSIONS_CV.set(vers)
    _VERS_META_CV.set(meta)
    _CUR_INDEX_CV.set(cur)
    _STORE["versions"] = vers
    _STORE["version_meta"] = meta
    _STORE["cur_index"] = cur
    # Keep the legacy aliases pointing at current/base.
    _STORE["df_payload"] = vers[cur] if (0 <= cur < len(vers)) else None
    _STORE["base_df_payload"] = vers[0] if vers else None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# =========================
|
| 50 |
+
# Init / Set / Annotate
|
| 51 |
+
# =========================
|
| 52 |
+
def init_dataset(p: Optional[Dict[str, Any]]) -> None:
    """Initialize the version stack with a single BASE version (or empty)."""
    if p is None:
        _set_versions([], [], -1)
    else:
        _set_versions([p], [dict(tag="base")], 0)

def set_df_payload(p: Optional[Dict[str, Any]], *, new_version: bool = True) -> None:
    """
    Set the CURRENT dataset.

    - new_version=True: truncate forward history and append p (like a commit).
    - new_version=False: replace the current version in place (no snapshot).
    """
    vers = list(_get_versions())
    meta = list(_get_meta())
    cur = _CUR_INDEX_CV.get() if _CUR_INDEX_CV.get() is not None else _STORE["cur_index"]

    # Not initialized yet → seed the stack instead.
    if cur < 0 or not vers:
        init_dataset(p)
        return

    if new_version:
        # Linear history only: discard anything after the current pointer.
        del vers[cur + 1:]
        del meta[cur + 1:]
        vers.append(p)
        meta.append({})
        cur = len(vers) - 1
    else:
        vers[cur] = p

    _set_versions(vers, meta, cur)

def annotate_current(**kv) -> None:
    """Merge metadata (e.g. step/params/stats) into the current version."""
    vers = list(_get_versions())
    meta = list(_get_meta())
    cur = _CUR_INDEX_CV.get() if _CUR_INDEX_CV.get() is not None else _STORE["cur_index"]
    if 0 <= cur < len(meta):
        meta[cur] = {**meta[cur], **kv}
        _set_versions(vers, meta, cur)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# =========================
|
| 97 |
+
# Getters / Navigation
|
| 98 |
+
# =========================
|
| 99 |
+
def _resolve_index(spec: Union[str, int, None]) -> int:
|
| 100 |
+
vers = _get_versions()
|
| 101 |
+
cur = _CUR_INDEX_CV.get() if _CUR_INDEX_CV.get() is not None else _STORE["cur_index"]
|
| 102 |
+
if spec is None or spec == "current":
|
| 103 |
+
return cur
|
| 104 |
+
if spec == "base":
|
| 105 |
+
return 0 if vers else -1
|
| 106 |
+
if spec == "prev":
|
| 107 |
+
return max(-1, cur - 1)
|
| 108 |
+
if isinstance(spec, int):
|
| 109 |
+
idx = spec if spec >= 0 else len(vers) + spec
|
| 110 |
+
return idx
|
| 111 |
+
# strings like "@-1", "@3"
|
| 112 |
+
if isinstance(spec, str) and spec.startswith("@"):
|
| 113 |
+
try:
|
| 114 |
+
n = int(spec[1:])
|
| 115 |
+
except Exception:
|
| 116 |
+
return cur
|
| 117 |
+
idx = n if n >= 0 else len(vers) + n
|
| 118 |
+
return idx
|
| 119 |
+
return cur
|
| 120 |
+
|
| 121 |
+
def get_df_payload(version: Union[str, int, None] = None) -> Optional[Dict[str, Any]]:
    """Return the dataset payload for the requested version (default: current)."""
    versions = _get_versions()
    idx = _resolve_index(version)
    if 0 <= idx < len(versions):
        return versions[idx]
    # fall back to the legacy single-slot store
    return _STORE["df_payload"]
|
| 129 |
+
|
| 130 |
+
def get_base_df_payload() -> Optional[Dict[str, Any]]:
    """Payload of the very first (base) version."""
    return get_df_payload(version="base")
|
| 132 |
+
|
| 133 |
+
def get_prev_df_payload() -> Optional[Dict[str, Any]]:
    """Payload of the version just before the current one."""
    return get_df_payload(version="prev")
|
| 135 |
+
|
| 136 |
+
def list_versions() -> Dict[str, Any]:
    """Lightweight overview of the version stack for debugging/UI."""
    versions = _get_versions()
    metas = _get_meta()
    cv_idx = _CUR_INDEX_CV.get()
    current = cv_idx if cv_idx is not None else _STORE["cur_index"]
    overview = {
        "count": len(versions),
        "current_index": current,
        "has_base": len(versions) > 0,
        "meta": metas,  # [{tag:..., step:..., params:...}, ...]
    }
    return overview
|
| 147 |
+
|
| 148 |
+
def reset_current_to(version: Union[str, int]) -> None:
    """Move the current pointer to a prior version (history is kept, nothing deleted)."""
    versions = _get_versions()
    metas = _get_meta()
    target = _resolve_index(version)
    if not (0 <= target < len(versions)):
        # out-of-range specs are ignored silently
        return
    _set_versions(versions, metas, target)
|
| 155 |
+
|
| 156 |
+
def reset_current_to_base() -> None:
    """Point CURRENT back at the very first recorded version."""
    reset_current_to(version="base")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# =========================
|
| 161 |
+
# SOTA / Summary passthrough
|
| 162 |
+
# =========================
|
| 163 |
+
def set_sota_bundled(b: Optional[list]) -> None:
    """Record SOTA search results in both the contextvar and the legacy store."""
    _STORE["sota_bundled"] = b
    _SOTA_BUNDLED_CV.set(b)
|
| 166 |
+
|
| 167 |
+
def get_sota_bundled() -> Optional[list]:
    """Return SOTA results for this context, falling back to the global store.

    Uses an explicit ``is not None`` test: with the previous ``or`` fallback,
    a context that deliberately stored an empty list ``[]`` would be shadowed
    by whatever another context last wrote into the shared ``_STORE``.
    """
    val = _SOTA_BUNDLED_CV.get()
    return val if val is not None else _STORE["sota_bundled"]
|
| 169 |
+
|
| 170 |
+
def set_df_summary(s: Optional[Dict[str, Any]]) -> None:
    """Record the dataset summary in both the contextvar and the legacy store."""
    _STORE["df_summary"] = s
    _DF_SUMMARY_CV.set(s)
|
| 173 |
+
|
| 174 |
+
def get_df_summary() -> Optional[Dict[str, Any]]:
    """Return the dataset summary for this context, falling back to the store.

    Uses an explicit ``is not None`` test: with the previous ``or`` fallback,
    an intentionally-set empty summary ``{}`` in the current context would be
    shadowed by the globally shared ``_STORE`` value.
    """
    val = _DF_SUMMARY_CV.get()
    return val if val is not None else _STORE["df_summary"]
|
agent/simple_chat.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 4 |
+
from langchain_core.messages import HumanMessage
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def simple_chat(
    question: str,
    context: str,
    image_paths: Optional[List[str]] = None
) -> str:
    """
    Simple chat function that answers questions based on context and optional images.

    Args:
        question: User's question
        context: Context information (e.g., dataset summary, analysis results)
        image_paths: Optional list of image file paths to include; paths that
            do not exist on disk are silently skipped.

    Returns:
        AI response as a string, or an "Error: ..." string on any failure
        (this function never raises).
    """
    # Hoisted out of the per-image loop: re-importing on every iteration was
    # wasteful. Kept function-local so the module import surface is unchanged.
    import base64

    try:
        # Initialize LLM
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-exp",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

        # Build the prompt
        prompt = f"""You are a helpful data analysis assistant.

Context:
{context}

User Question: {question}

Please provide a clear, concise answer based on the context provided."""

        # Handle images if provided
        if image_paths:
            content = [{"type": "text", "text": prompt}]

            for img_path in image_paths:
                if not os.path.exists(img_path):
                    continue  # skip missing files instead of failing the call
                with open(img_path, "rb") as f:
                    img_data = base64.b64encode(f.read()).decode()

                # Determine image MIME type from the file extension
                ext = os.path.splitext(img_path)[1].lower()
                mime_type = {
                    '.png': 'image/png',
                    '.jpg': 'image/jpeg',
                    '.jpeg': 'image/jpeg',
                    '.gif': 'image/gif',
                    '.webp': 'image/webp'
                }.get(ext, 'image/png')

                content.append({
                    "type": "image_url",
                    "image_url": f"data:{mime_type};base64,{img_data}"
                })

            message = HumanMessage(content=content)
        else:
            message = HumanMessage(content=prompt)

        # Get response
        response = llm.invoke([message])

        # Extract text from response (AIMessage exposes .content)
        if hasattr(response, 'content'):
            return str(response.content)
        return str(response)

    except Exception as e:
        return f"Error: {str(e)}"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Example usage:
|
| 86 |
+
# Example usage (manual smoke test — requires GOOGLE_API_KEY in the env):
if __name__ == "__main__":
    # Simple text-only example
    context = """
    Dataset: Customer Sales Data
    - 1000 rows, 15 columns
    - Label: purchase_made (binary)
    - Task: Classification
    - Missing values: 5% in age column
    """

    question = "What's the main task for this dataset?"
    response = simple_chat(question, context)
    print(response)

    # With images
    # NOTE: the path below is a placeholder; simple_chat skips nonexistent
    # files, so this exercises the multimodal branch only with a real image.
    question2 = "What do you see in the visualization?"
    response2 = simple_chat(question2, context, image_paths=["/path/to/plot.png"])
    print(response2)
|
agent/tools.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
from langchain.tools import tool
|
| 7 |
+
from langchain_community.tools.tavily_search import TavilySearchResults
|
| 8 |
+
|
| 9 |
+
from pipeline.deduplication import find_near_duplicates
|
| 10 |
+
from pipeline.featurizer import custom_featurizer
|
| 11 |
+
from pipeline.issues import find_issues
|
| 12 |
+
from pipeline.utils_cool import (
|
| 13 |
+
df_from_payload,
|
| 14 |
+
df_to_payload,
|
| 15 |
+
get_signature_dict,
|
| 16 |
+
guess_task_and_label,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from .runtime_ctx import (
|
| 20 |
+
get_df_payload, # now supports version spec (None|'current'|'prev'|'base'|'@-1'|int)
|
| 21 |
+
get_df_summary,
|
| 22 |
+
get_sota_bundled,
|
| 23 |
+
set_df_payload, # commit new dataset version (or replace)
|
| 24 |
+
set_df_summary,
|
| 25 |
+
set_sota_bundled,
|
| 26 |
+
)
|
| 27 |
+
from .runtime_ctx import (
|
| 28 |
+
list_versions as _list_versions_state,
|
| 29 |
+
)
|
| 30 |
+
from .runtime_ctx import (
|
| 31 |
+
reset_current_to as _reset_current_to,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Registry of runnable steps (names used by the agent/UI).
# Keys are the public step names exposed through list_steps/describe_step/run_step;
# values are the pipeline callables they dispatch to.
STEP_FUNCS = {
    "dedup": find_near_duplicates,        # flag/remove near-duplicate rows
    "featurize": custom_featurizer,       # impute/encode/scale features
    "find_label_issues": find_issues,     # detect likely label problems
}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@tool("inspect_dataset", return_direct=True)
|
| 43 |
+
def tool_inspect_dataset() -> Dict[str, Any]:
|
| 44 |
+
"""
|
| 45 |
+
Summarize the CURRENT dataset (no arguments required).
|
| 46 |
+
|
| 47 |
+
Behavior:
|
| 48 |
+
• Reads the dataset from the runtime context (set by the graph).
|
| 49 |
+
• Returns a compact summary of columns, dtypes, shape, and a guessed label/task.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
{
|
| 53 |
+
"type": "dataset_summary",
|
| 54 |
+
"columns": [...],
|
| 55 |
+
"dtypes": {col: dtype, ...},
|
| 56 |
+
"shape": (rows, cols),
|
| 57 |
+
"label_guess": "<name or None>",
|
| 58 |
+
"task_guess": "classification|regression|unsupervised",
|
| 59 |
+
"issues": [ ... ] # e.g., missing labels, single-class, etc.
|
| 60 |
+
}
|
| 61 |
+
"""
|
| 62 |
+
df_payload = get_df_payload() # default: current version
|
| 63 |
+
if df_payload is None:
|
| 64 |
+
raise RuntimeError("inspect_dataset: no dataset available in runtime context.")
|
| 65 |
+
df = df_from_payload(df_payload)
|
| 66 |
+
summary = guess_task_and_label(df)
|
| 67 |
+
# keep context fresh for downstream tools
|
| 68 |
+
set_df_summary(summary)
|
| 69 |
+
return {"type": "dataset_summary", **summary}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@tool("sota_preprocessing", return_direct=True)
|
| 73 |
+
def tool_sota_preprocessing(
|
| 74 |
+
task: Optional[str] = None,
|
| 75 |
+
modality: Optional[str] = None,
|
| 76 |
+
domain: Optional[str] = None,
|
| 77 |
+
target: Optional[str] = None,
|
| 78 |
+
) -> Dict[str, Any]:
|
| 79 |
+
"""
|
| 80 |
+
Search state-of-the-art preprocessing best practices (modality-aware).
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
task: e.g., "classification", "regression", "segmentation", "NER", "ASR", "forecasting".
|
| 84 |
+
If omitted, inferred from the dataset summary if available.
|
| 85 |
+
modality: one of {"tabular","text","image","audio","video","time_series","graph","multimodal"}.
|
| 86 |
+
domain: optional domain context (e.g., "clinical", "finance").
|
| 87 |
+
target: optional target structure (e.g., "segmentation masks", "bounding boxes").
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
{
|
| 91 |
+
"type": "sota",
|
| 92 |
+
"task": ...,
|
| 93 |
+
"modality": ...,
|
| 94 |
+
"domain": ...,
|
| 95 |
+
"target": ...,
|
| 96 |
+
"queries": [...],
|
| 97 |
+
"bundled_results": [{ "query": q, "results": <tavily-results> }, ...],
|
| 98 |
+
"results": <first tavily-results batch>
|
| 99 |
+
}
|
| 100 |
+
"""
|
| 101 |
+
df_summary = get_df_summary() or {}
|
| 102 |
+
if not task:
|
| 103 |
+
task = df_summary.get("task_guess") or "classification"
|
| 104 |
+
|
| 105 |
+
yr = "2024 2025"
|
| 106 |
+
m = (modality or "").lower().strip()
|
| 107 |
+
|
| 108 |
+
modality_terms = {
|
| 109 |
+
"tabular": ["imputation", "encoding", "scaling", "outliers", "leakage prevention"],
|
| 110 |
+
"text": ["tokenization", "normalization", "subword", "BPE", "SentencePiece", "stopwords", "lemmatization", "augmentation"],
|
| 111 |
+
"image": ["normalization", "resizing", "color space", "augmentation", "RandAugment", "MixUp", "CutMix"],
|
| 112 |
+
"audio": ["resampling", "log-mel spectrogram", "MFCC", "pre-emphasis", "SpecAugment", "denoising"],
|
| 113 |
+
"time_series": ["resampling", "windowing", "detrending", "imputation", "outlier detection", "scaling"],
|
| 114 |
+
"video": ["frame sampling", "temporal augmentation", "clip normalization", "optical flow"],
|
| 115 |
+
"graph": ["feature normalization", "self-loops", "adjacency normalization", "sparsification"],
|
| 116 |
+
"multimodal": ["alignment", "synchronization", "fusion", "tokenization"],
|
| 117 |
+
}
|
| 118 |
+
m_terms = modality_terms.get(m, [])
|
| 119 |
+
|
| 120 |
+
# Build candidate queries
|
| 121 |
+
queries: List[str] = []
|
| 122 |
+
queries.append(f"state of the art preprocessing {task} {yr}")
|
| 123 |
+
queries.append(f"best practices data preprocessing {task} {yr}")
|
| 124 |
+
if m:
|
| 125 |
+
queries.append(f"{m} {task} preprocessing best practices {yr}")
|
| 126 |
+
if domain:
|
| 127 |
+
queries.append(f"{domain} {m or ''} {task} preprocessing best practices {yr}".strip())
|
| 128 |
+
if target:
|
| 129 |
+
queries.append(f"{m or ''} {task} {target} preprocessing pipeline {yr}".strip())
|
| 130 |
+
if m_terms:
|
| 131 |
+
queries.append(f"{m} {task} preprocessing {' '.join(m_terms)} {yr}")
|
| 132 |
+
|
| 133 |
+
# Deduplicate, preserve order
|
| 134 |
+
seen = set()
|
| 135 |
+
queries = [q for q in (q.strip() for q in queries) if q and (q not in seen and not seen.add(q))]
|
| 136 |
+
|
| 137 |
+
tavily = TavilySearchResults(k=6)
|
| 138 |
+
bundled: List[Dict[str, Any]] = [{"query": q, "results": tavily.invoke({"query": q})} for q in queries]
|
| 139 |
+
flat_first = bundled[0]["results"] if (bundled and "results" in bundled[0]) else []
|
| 140 |
+
|
| 141 |
+
# persist for planning
|
| 142 |
+
set_sota_bundled(bundled)
|
| 143 |
+
|
| 144 |
+
return {
|
| 145 |
+
"type": "sota",
|
| 146 |
+
"task": task,
|
| 147 |
+
"modality": m or "unknown",
|
| 148 |
+
"domain": domain,
|
| 149 |
+
"target": target,
|
| 150 |
+
"queries": queries,
|
| 151 |
+
"bundled_results": bundled,
|
| 152 |
+
"results": flat_first,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
@tool("describe_step", return_direct=True)
|
| 156 |
+
def tool_describe_step(name: str) -> Dict[str, Any]:
|
| 157 |
+
"""
|
| 158 |
+
Return the exact docstring + parameter schema for a single step by name.
|
| 159 |
+
This prevents the model from inventing params.
|
| 160 |
+
"""
|
| 161 |
+
if name not in STEP_FUNCS:
|
| 162 |
+
raise ValueError(f"Unknown step '{name}'. Available: {list(STEP_FUNCS)}")
|
| 163 |
+
fn = STEP_FUNCS[name]
|
| 164 |
+
sig = get_signature_dict(fn) # your util that introspects defaults/annotations
|
| 165 |
+
return {"type": "step_description", "name": name, **sig}
|
| 166 |
+
|
| 167 |
+
@tool("list_steps", return_direct=True)
|
| 168 |
+
def tool_list_steps() -> Dict[str, Any]:
|
| 169 |
+
"""
|
| 170 |
+
List available pipeline steps (name, docstring, and signature).
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
{
|
| 174 |
+
"type": "steps",
|
| 175 |
+
"steps": [
|
| 176 |
+
{
|
| 177 |
+
"name": "dedup" | "featurize" | "find_label_issues",
|
| 178 |
+
"doc": "<docstring>",
|
| 179 |
+
"params": [{"name": "...", "default": ..., "annotation": "...", "kind": "..."}]
|
| 180 |
+
}, ...
|
| 181 |
+
]
|
| 182 |
+
}
|
| 183 |
+
"""
|
| 184 |
+
return {
|
| 185 |
+
"type": "steps",
|
| 186 |
+
"steps": [{"name": n, **get_signature_dict(fn)} for n, fn in STEP_FUNCS.items()],
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@tool("propose_plan", return_direct=True)
|
| 191 |
+
def tool_propose_plan(
|
| 192 |
+
task: Optional[str] = None,
|
| 193 |
+
modality: Optional[str] = None,
|
| 194 |
+
) -> Dict[str, Any]:
|
| 195 |
+
"""
|
| 196 |
+
Propose an ordered preprocessing plan grounded in SOTA + dataset summary.
|
| 197 |
+
(Planning only — does not execute steps.)
|
| 198 |
+
"""
|
| 199 |
+
df_summary = get_df_summary() or {}
|
| 200 |
+
bundled = get_sota_bundled() or []
|
| 201 |
+
|
| 202 |
+
if not task:
|
| 203 |
+
task = df_summary.get("task_guess") or "classification"
|
| 204 |
+
label_guess = df_summary.get("label_guess")
|
| 205 |
+
|
| 206 |
+
KEYWORDS = {
|
| 207 |
+
"dedup": {"duplicate", "near-duplicate", "near duplicate", "dupe", "dedup", "similarity", "knn", "kNN"},
|
| 208 |
+
"featurize": {
|
| 209 |
+
"impute", "imputation", "encoding", "one-hot", "scale", "scaling",
|
| 210 |
+
"normalize", "normalization", "standardize", "tfidf", "tokenization",
|
| 211 |
+
"lemmatization", "augmentation"
|
| 212 |
+
},
|
| 213 |
+
"find_label_issues": {"label noise", "noisy labels", "cleanlab", "confident learning", "label issues", "weak labels"},
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
def _score(text: str, keys: set[str]) -> int:
|
| 217 |
+
t = (text or "").lower()
|
| 218 |
+
return sum(1 for k in keys if k in t)
|
| 219 |
+
|
| 220 |
+
hits = {"dedup": 0, "featurize": 0, "find_label_issues": 0}
|
| 221 |
+
evidence: Dict[str, List[Dict[str, str]]] = {"dedup": [], "featurize": [], "find_label_issues": []}
|
| 222 |
+
|
| 223 |
+
for pack in bundled:
|
| 224 |
+
q = pack.get("query", "")
|
| 225 |
+
for item in (pack.get("results") or []):
|
| 226 |
+
title = item.get("title", "")
|
| 227 |
+
content = item.get("content", "")
|
| 228 |
+
url = item.get("url", "")
|
| 229 |
+
for step, keys in KEYWORDS.items():
|
| 230 |
+
s = _score(f"{q} {title} {content}", keys)
|
| 231 |
+
if s > 0:
|
| 232 |
+
hits[step] += s
|
| 233 |
+
if len(evidence[step]) < 5:
|
| 234 |
+
evidence[step].append({"query": q, "title": title, "url": url})
|
| 235 |
+
|
| 236 |
+
options: List[Dict[str, Any]] = []
|
| 237 |
+
|
| 238 |
+
if hits["dedup"] > 0 or modality in {None, "tabular", "text", "image", "time_series"}:
|
| 239 |
+
options.append(
|
| 240 |
+
{
|
| 241 |
+
"reason": "SOTA emphasizes handling near-duplicates early" if hits["dedup"] else "Practical first step to prevent leakage/skew",
|
| 242 |
+
"step": "dedup",
|
| 243 |
+
"params": {"threshold": 0.95, "metric": "cosine"},
|
| 244 |
+
"evidence": evidence["dedup"][:3],
|
| 245 |
+
}
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
options.append(
|
| 249 |
+
{
|
| 250 |
+
"reason": "SOTA emphasizes robust imputation/encoding/scaling" if hits["featurize"] else "Prepare features based on modality",
|
| 251 |
+
"step": "featurize",
|
| 252 |
+
"params": {"nan_strategy": "impute"},
|
| 253 |
+
"evidence": evidence["featurize"][:3],
|
| 254 |
+
}
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
if (task == "classification" and label_guess) or hits["find_label_issues"] > 0:
|
| 258 |
+
options.append(
|
| 259 |
+
{
|
| 260 |
+
"reason": "SOTA recommends checking noisy labels" if hits["find_label_issues"] else "Check label quality before training",
|
| 261 |
+
"step": "find_label_issues",
|
| 262 |
+
"params": {"label": label_guess or "<CONFIRM>"},
|
| 263 |
+
"evidence": evidence["find_label_issues"][:3],
|
| 264 |
+
}
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
if not options:
|
| 268 |
+
options = [{"reason": "Generic best practice", "step": "featurize", "params": {"nan_strategy": "impute"}, "evidence": []}]
|
| 269 |
+
|
| 270 |
+
return {
|
| 271 |
+
"type": "plan",
|
| 272 |
+
"task": task,
|
| 273 |
+
"modality": modality,
|
| 274 |
+
"label_guess": label_guess,
|
| 275 |
+
"options": options,
|
| 276 |
+
"keyword_hits": hits,
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@tool("run_step", return_direct=True)
|
| 281 |
+
def tool_run_step(name: str, params_json: str = "") -> Dict[str, Any]:
|
| 282 |
+
"""
|
| 283 |
+
Execute a single pipeline step on the CURRENT dataset (no df argument).
|
| 284 |
+
Returns ONLY a compact summary; the updated df is stored in runtime context.
|
| 285 |
+
"""
|
| 286 |
+
df_payload = get_df_payload()
|
| 287 |
+
if df_payload is None:
|
| 288 |
+
raise RuntimeError("run_step: no dataset available in runtime context.")
|
| 289 |
+
if name not in STEP_FUNCS:
|
| 290 |
+
raise ValueError(f"Unknown step '{name}'. Available: {list(STEP_FUNCS)}")
|
| 291 |
+
|
| 292 |
+
params = json.loads(params_json) if params_json else {}
|
| 293 |
+
if not isinstance(params, dict):
|
| 294 |
+
raise ValueError("params_json must decode to a JSON object")
|
| 295 |
+
|
| 296 |
+
df = df_from_payload(df_payload)
|
| 297 |
+
df_out, stats = STEP_FUNCS[name](df=df, **params)
|
| 298 |
+
df_next = df_out if df_out is not None else df
|
| 299 |
+
|
| 300 |
+
# ✅ update runtime dataset, but DO NOT send it back in the tool message
|
| 301 |
+
set_df_payload(df_to_payload(df_next))
|
| 302 |
+
|
| 303 |
+
# Build a tiny, safe summary for the model
|
| 304 |
+
shape_before = (len(df), len(df.columns))
|
| 305 |
+
shape_after = (len(df_next), len(df_next.columns))
|
| 306 |
+
compact_stats = {k: stats.get(k) for k in [
|
| 307 |
+
"n_rows_before_dedup", "n_near_dupe_pairs", "n_groups",
|
| 308 |
+
"n_rows_flagged_duplicates", "n_rows_after_dedup",
|
| 309 |
+
"metric", "threshold", "k", "total_time_sec"
|
| 310 |
+
] if k in stats}
|
| 311 |
+
|
| 312 |
+
return {
|
| 313 |
+
"type": "step_result",
|
| 314 |
+
"name": name,
|
| 315 |
+
"params_used": params,
|
| 316 |
+
"shape_before": shape_before,
|
| 317 |
+
"shape_after": shape_after,
|
| 318 |
+
"stats": compact_stats, # small dict only
|
| 319 |
+
"note": "Dataset updated in runtime context; use list_versions/reset_to_version if needed."
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
# ---------------------------
|
| 323 |
+
# Optional helpers for version control from chat/agent
|
| 324 |
+
# ---------------------------
|
| 325 |
+
@tool("list_versions", return_direct=True)
|
| 326 |
+
def tool_list_versions() -> Dict[str, Any]:
|
| 327 |
+
"""
|
| 328 |
+
Return a lightweight view of the version stack:
|
| 329 |
+
{ count, current_index, has_base, meta: [{...}, ...] }
|
| 330 |
+
"""
|
| 331 |
+
return {"type": "versions", **_list_versions_state()}
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
@tool("reset_to_version", return_direct=True)
|
| 335 |
+
def tool_reset_to_version(spec: str) -> Dict[str, Any]:
|
| 336 |
+
"""
|
| 337 |
+
Move CURRENT pointer to a prior version without deleting history.
|
| 338 |
+
spec can be: "base" | "prev" | "@-1" | "@-2" | "3"
|
| 339 |
+
"""
|
| 340 |
+
# accept int or @-k in string form
|
| 341 |
+
try:
|
| 342 |
+
if spec.isdigit():
|
| 343 |
+
_reset_current_to(int(spec))
|
| 344 |
+
else:
|
| 345 |
+
_reset_current_to(spec)
|
| 346 |
+
except Exception as e:
|
| 347 |
+
raise RuntimeError(f"reset_to_version: {e}")
|
| 348 |
+
|
| 349 |
+
df_payload = get_df_payload()
|
| 350 |
+
return {
|
| 351 |
+
"type": "reset",
|
| 352 |
+
"current": _list_versions_state(),
|
| 353 |
+
"df": df_payload,
|
| 354 |
+
}
|
app.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Optional, Tuple, List, Dict, Any
|
| 5 |
+
import io
|
| 6 |
+
import base64
|
| 7 |
+
from tqdm.auto import tqdm
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
from pipeline.deduplication import find_near_duplicates
|
| 13 |
+
from pipeline.featurizer import custom_featurizer
|
| 14 |
+
from pipeline.issues import find_issues
|
| 15 |
+
from pipeline.pipeline import make_step, run_pipeline
|
| 16 |
+
|
| 17 |
+
from ecg_analyzer import ECGAnalyzer
|
| 18 |
+
from agent.simple_chat import simple_chat
|
| 19 |
+
# ============================================================================
|
| 20 |
+
# ANALYSIS TASK CONFIGURATION
|
| 21 |
+
# ============================================================================
|
| 22 |
+
|
| 23 |
+
@dataclass
class TaskConfig:
    """Configuration for each analysis task"""
    # Display name of the task (matches the keys in TaskRegistry).
    name: str
    # Data family the task belongs to, e.g. "EHR Data" or "ECG Data".
    data_type: str
    # Whether the UI must render parameter widgets before running the task.
    requires_params: bool
    # Gradio component specs: each dict carries "type", "label", "elem_id".
    param_components: List[Dict[str, Any]]
    # Which result tabs this task populates (e.g. "original", "summary").
    output_tabs: List[str]
|
| 31 |
+
|
| 32 |
+
class TaskRegistry:
    """Registry mapping tasks to their configurations"""

    @staticmethod
    def get_config(data_type: str, task_name: str) -> Optional[TaskConfig]:
        """Get configuration for a specific task.

        Returns None when either the data type or the task name is unknown.
        """
        # Static, hand-maintained table: data_type -> task_name -> TaskConfig.
        # Rebuilt on every call; cheap enough at this size.
        configs = {
            "EHR Data": {
                "Near-Duplicate Detection": TaskConfig(
                    name="Near-Duplicate Detection",
                    data_type="EHR Data",
                    requires_params=True,
                    param_components=[
                        {"type": "dropdown", "label": "Label Column", "elem_id": "ndd_label"}
                    ],
                    output_tabs=["original", "processed", "summary"]
                ),
                "Find Mislabeled Data": TaskConfig(
                    name="Find Mislabeled Data",
                    data_type="EHR Data",
                    requires_params=True,
                    param_components=[
                        {"type": "dropdown", "label": "Label Column", "elem_id": "mislabel_label"}
                    ],
                    output_tabs=["original", "summary"]
                )
            },
            "ECG Data": {
                "ECG Visualization": TaskConfig(
                    name="ECG Visualization",
                    data_type="ECG Data",
                    requires_params=True,
                    param_components=[
                        {"type": "checkboxgroup", "label": "Select Leads", "elem_id": "ecg_leads"},
                        {"type": "checkboxgroup", "label": "Visualization Types", "elem_id": "ecg_viz_types"}
                    ],
                    output_tabs=["visualization", "summary"]
                ),
                "Statistical Summary": TaskConfig(
                    name="Statistical Summary",
                    data_type="ECG Data",
                    requires_params=True,
                    param_components=[
                        {"type": "checkboxgroup", "label": "Select Leads", "elem_id": "ecg_stats_leads"}
                    ],
                    output_tabs=["summary", "visualization"]
                )
            }
        }
        return configs.get(data_type, {}).get(task_name)

    @staticmethod
    def get_tasks_for_data_type(data_type: str) -> List[str]:
        """Get available tasks for a data type (empty list for unknown types).

        NOTE(review): this task list duplicates the keys of the table in
        get_config above — keep the two in sync when adding tasks.
        """
        tasks = {
            "EHR Data": ["Near-Duplicate Detection", "Find Mislabeled Data"],
            "ECG Data": ["ECG Visualization", "Statistical Summary"]
        }
        return tasks.get(data_type, [])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ============================================================================
|
| 94 |
+
# ANALYSIS EXECUTION
|
| 95 |
+
# ============================================================================
|
| 96 |
+
|
| 97 |
+
class AnalysisExecutor:
    """Executes analysis tasks and returns ``(status_message, results_dict)`` pairs.

    Every ``execute_*`` method is a static entry point for one UI task. Each
    catches ``Exception`` broadly and folds the failure into the status string
    so the Gradio front end never sees an unhandled traceback.
    """

    @staticmethod
    def execute_near_duplicate_detection(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Run the dedup -> featurize -> find-issues pipeline on *df*.

        Args:
            df: Input dataframe.
            label: Name of the label column (required; empty string rejected).

        Returns:
            Status string and a dict with keys ``original``/``processed``/``summary``.
        """
        failure = {"original": df, "processed": None, "summary": None}
        try:
            if not label:
                return "⚠ Label column required", failure

            # One shared progress bar threaded through all pipeline steps.
            bar = tqdm(total=100, leave=False, desc="Pipeline Progress")
            steps = [
                make_step(find_near_duplicates, name="dedup")(progress=bar),
                make_step(custom_featurizer, name="featurize")(
                    label=label, nan_strategy="impute", on_pipeline_error="drop", progress=bar
                ),
                make_step(find_issues, name="find_label_issues")(label=label, progress=bar),
            ]
            results_df, summary_list = run_pipeline(steps, df=df)
            bar.close()

            return "✓ Near-duplicate detection completed", {
                "original": df, "processed": results_df, "summary": summary_list
            }
        except Exception as e:
            return f"✗ Error: {e}", failure

    @staticmethod
    def execute_find_mislabeled(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Report a mislabeled-data summary for *df*.

        NOTE(review): this is currently a stub — no detection model runs and
        ``suspicious_samples`` is hard-coded to 0.
        """
        try:
            if not label:
                return "⚠ Label column required", {"original": df, "summary": None}

            summary = {
                "task": "Find Mislabeled Data", "label_column": label, "total_samples": len(df),
                "suspicious_samples": 0, "message": "Mislabeled detection analysis completed"
            }
            return "✓ Mislabeled data analysis completed", {"original": df, "summary": summary}
        except Exception as e:
            return f"✗ Error: {e}", {"original": df, "summary": None}

    @staticmethod
    def execute_ecg_visualization(df: pd.DataFrame, leads: List[str] = None, viz_types: List[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Render ECG visualizations via :class:`ECGAnalyzer`.

        Args:
            df: Input dataframe containing ECG signal columns.
            leads: Lead column names; when falsy, every auto-detected lead is used.
            viz_types: Plot kinds; defaults to ["Signal Waveform", "Histogram"].

        Returns:
            Status string and a dict with keys ``visualization``/``summary``.
        """
        try:
            # Fall back to auto-detected leads when none were chosen.
            leads = leads or ECGAnalyzer.detect_leads(df) or []
            if not leads:
                return "⚠ No ECG leads found in data", {"visualization": None, "summary": None}

            viz_types = viz_types or ["Signal Waveform", "Histogram"]

            viz_html = ECGAnalyzer.create_all_visualizations(df, leads, viz_types)
            stats = ECGAnalyzer.generate_statistics(df, leads)

            summary = {
                "task": "ECG Visualization",
                "samples": len(df),
                "leads_analyzed": leads,
                "visualizations": viz_types,
                "statistics": stats
            }

            return "✓ ECG visualization created", {"visualization": viz_html, "summary": summary}
        except Exception as e:
            return f"✗ Error: {e}", {"visualization": None, "summary": None}

    @staticmethod
    def execute_statistical_summary(df: pd.DataFrame, leads: List[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Produce a per-lead statistics table (HTML) plus a summary dict.

        Args:
            df: Input dataframe.
            leads: Columns to summarize; when falsy, falls back to detected ECG
                leads, then to all numeric columns.

        Returns:
            Status string and a dict with keys ``summary``/``visualization``
            (the latter is an HTML table string).
        """
        try:
            if not leads:
                leads = ECGAnalyzer.detect_leads(df) or list(df.select_dtypes(include=[np.number]).columns)
            if not leads:
                return "⚠ No numeric columns found", {"summary": None, "visualization": None}

            stats = ECGAnalyzer.generate_statistics(df, leads)

            # Build one <tr> per lead; column order must match the <thead> below.
            stat_keys = ("mean", "std", "min", "q25", "median", "q75", "max")
            body_rows = [
                "<tr><td><strong>{}</strong></td>{}</tr>".format(
                    lead, "".join(f"<td>{lead_stats[k]:.4f}</td>" for k in stat_keys)
                )
                for lead, lead_stats in stats.items()
            ]
            table_html = (
                "<table class='preview-table' style='margin: 20px auto; max-width: 900px;'>"
                "<thead><tr><th>Lead</th><th>Mean</th><th>Std</th><th>Min</th><th>Q25</th>"
                "<th>Median</th><th>Q75</th><th>Max</th></tr></thead>"
                "<tbody>" + "".join(body_rows) + "</tbody></table>"
            )
            summary_html = f"<div style='overflow-x:auto;'><h3 style='text-align:center;'>Statistical Summary</h3>{table_html}</div>"

            summary = {
                "task": "Statistical Summary",
                "rows": len(df),
                "leads_analyzed": leads,
                "statistics": stats
            }

            return "✓ Statistical summary generated", {"summary": summary, "visualization": summary_html}
        except Exception as e:
            return f"✗ Error: {e}", {"summary": None, "visualization": None}
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# ============================================================================
|
| 226 |
+
# UI MANAGER - Handles all UI state and updates
|
| 227 |
+
# ============================================================================
|
| 228 |
+
|
| 229 |
+
class UIManager:
    """Manages UI state and dynamic updates.

    Holds the currently loaded dataframe and the chatbot context, and exposes
    the callbacks wired to Gradio events. Several callbacks return long tuples
    of ``gr.update(...)`` values whose ORDER must exactly match the ``outputs``
    lists declared in ``create_interface`` — do not reorder items here without
    updating the wiring there.
    """

    def __init__(self):
        # Most recently loaded dataframe (None until a CSV is uploaded).
        self.current_df = None
        # Currently selected data type; mirrors the "Data Type" dropdown.
        self.current_data_type = "EHR Data"
        # Context dict shared with the chatbot (file, type, df, summary, visualization).
        self.chatbot_context = {}
        # Shorthand command names -> canonical UI labels.
        # NOTE(review): command_map is defined but not referenced anywhere in
        # this class's visible code — presumably for a JSON-command feature; verify.
        self.command_map = {
            "data_type": {"ehr": "EHR Data", "ecg": "ECG Data"},
            "task": {
                "deduplication": "Near-Duplicate Detection", "mislabeled": "Find Mislabeled Data",
                "visualize_ecg": "ECG Visualization", "stats": "Statistical Summary"
            }
        }

    def load_csv(self, file) -> Tuple[str, Optional[pd.DataFrame]]:
        """Load the uploaded CSV file; returns (status message, dataframe or None)."""
        if file is None: return "⚠ No file uploaded", None
        try:
            df = pd.read_csv(file.name)
            # Cache for later callbacks (e.g. data-type change, chatbot).
            self.current_df = df
            return f"✓ Loaded {len(df)} rows, {len(df.columns)} columns", df
        except Exception as e:
            return f"✗ Error: {str(e)}", None

    def on_file_upload(self, file, data_type: str):
        """Handle file upload - returns updates for all components.

        The 17-tuple order is: status, data preview, task selector,
        4 parameter-group visibilities, 2 label dropdowns, ECG-viz leads,
        ECG-viz types, ECG-stats leads, a second viz-types update, 4 tab
        visibilities — matching the ``outputs`` list of ``file_input.change``.
        """
        status, df = self.load_csv(file)
        if df is None:
            # Reset everything to an empty/hidden state on failed load.
            return (
                status, gr.update(value=None), gr.update(choices=[], value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                gr.update(choices=[]), gr.update(choices=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            )

        # Seed the chatbot context with the freshly loaded data.
        self.chatbot_context = {"file": file.name, "type": data_type, "df": df}
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)
        col_choices = list(df.columns)

        # Detect ECG leads if ECG data
        ecg_leads = ECGAnalyzer.detect_leads(df) if data_type == "ECG Data" else []
        viz_types = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

        # Only the first 200 rows are shown in the preview table.
        return (
            status, gr.update(value=df.head(200)), gr.update(choices=available_tasks, value=[], interactive=True),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(choices=col_choices, value=None), gr.update(choices=col_choices, value=None),
            gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
            gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform"]),
            gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
        )

    def on_data_type_change(self, data_type: str, file):
        """Handle data type change; refresh task list and ECG controls."""
        self.current_data_type = data_type
        if file and self.current_df is not None:
            self.chatbot_context["type"] = data_type

            # Update ECG lead choices if switching to ECG data
            ecg_leads = ECGAnalyzer.detect_leads(self.current_df) if data_type == "ECG Data" else []
            viz_types = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

            return (
                gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
                gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform"]),
                f"Data type changed to: {data_type}"
            )

        # No file loaded yet: only refresh the task list, leave ECG controls as-is.
        return (
            gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[]),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(), gr.update(), gr.update(), gr.update(),
            f"Data type changed to: {data_type}"
        )

    def on_tasks_change(self, selected_tasks: List[str]):
        """Handle task selection change - show/hide parameter groups."""
        show_ndd = "Near-Duplicate Detection" in selected_tasks
        show_mislabel = "Find Mislabeled Data" in selected_tasks
        show_ecg_viz = "ECG Visualization" in selected_tasks
        # Statistical Summary parameters only make sense for ECG data.
        show_ecg_stats = "Statistical Summary" in selected_tasks and self.current_data_type == "ECG Data"
        return (
            gr.update(visible=show_ndd),
            gr.update(visible=show_mislabel),
            gr.update(visible=show_ecg_viz),
            gr.update(visible=show_ecg_stats)
        )

    def process_analysis(self, file, data_type: str, selected_tasks: List[str],
                         ndd_label: str, mislabel_label: str,
                         ecg_viz_leads: List[str], ecg_viz_types: List[str],
                         ecg_stats_leads: List[str]):
        """Process analysis tasks based on UI inputs.

        Re-reads the CSV (so a stale in-memory df can't be analyzed), bundles
        the per-task parameters into a dict, and delegates to ``_run_analysis``.
        """
        status, df = self.load_csv(file)
        if df is None:
            return (status, None, None, None, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False))

        params = {
            "ndd_label": ndd_label,
            "mislabel_label": mislabel_label,
            "ecg_viz_leads": ecg_viz_leads,
            "ecg_viz_types": ecg_viz_types,
            "ecg_stats_leads": ecg_stats_leads
        }
        return self._run_analysis(df, data_type, selected_tasks, params)

    def _run_analysis(self, df: pd.DataFrame, data_type: str, selected_tasks: List[str], params: Dict[str, Any]):
        """Centralized analysis executor, callable from UI or chatbot.

        Returns a 9-tuple: status text, original/processed/summary/visualization
        payloads, then 4 tab-visibility updates (same order as ``analysis_outputs``).
        """
        if not selected_tasks:
            return (
                "⚠ No tasks selected", df.head(200), None, None, None,
                gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
            )

        # Accumulators: which tabs to show and the merged per-task results.
        all_tabs = set(); all_results = {"original": df.head(200), "processed": None, "summary": [], "visualization": ""}
        status_messages = []; executor = AnalysisExecutor()

        for task_name in selected_tasks:
            config = TaskRegistry.get_config(data_type, task_name)
            if not config:
                status_messages.append(f"✗ Unknown task: {task_name}"); continue
            all_tabs.update(config.output_tabs)

            # Dispatch to the matching executor method with its own parameters.
            if task_name == "Near-Duplicate Detection":
                status_msg, results = executor.execute_near_duplicate_detection(df, params.get("ndd_label"))
            elif task_name == "Find Mislabeled Data":
                status_msg, results = executor.execute_find_mislabeled(df, params.get("mislabel_label"))
            elif task_name == "ECG Visualization":
                status_msg, results = executor.execute_ecg_visualization(
                    df,
                    params.get("ecg_viz_leads"),
                    params.get("ecg_viz_types")
                )
            elif task_name == "Statistical Summary":
                if data_type == "ECG Data":
                    status_msg, results = executor.execute_statistical_summary(df, params.get("ecg_stats_leads"))
                else:
                    status_msg, results = executor.execute_statistical_summary(df)
            else:
                status_msg, results = "✗ Task not implemented", {}

            status_messages.append(f"{task_name}: {status_msg}")
            # Merge: last processed df wins, HTML visualizations concatenate,
            # summaries collect per task.
            if results.get("processed") is not None: all_results["processed"] = results["processed"]
            if results.get("visualization"): all_results["visualization"] += results["visualization"]
            if results.get("summary") is not None: all_results["summary"].append({"task": task_name, "data": results["summary"]})

        # Expose latest results to the chatbot.
        self.chatbot_context["summary"] = all_results["summary"] or None
        self.chatbot_context["visualization"] = all_results["visualization"] or None

        return (
            "\n".join(status_messages), all_results["original"], all_results["processed"],
            all_results["summary"] or None, all_results["visualization"] or None,
            gr.update(visible="original" in all_tabs), gr.update(visible="processed" in all_tabs),
            gr.update(visible="summary" in all_tabs), gr.update(visible="visualization" in all_tabs)
        )

    def chatbot_respond(self, message: str, history: List):
        """Handle chatbot messages, parsing for commands or responding to queries.

        Returns (history, cleared input) plus 9 no-op UI updates so the event
        signature matches the analysis outputs wired in ``create_interface``.
        """
        history = history or []; df = self.chatbot_context.get("df")

        # Only the most recent task summary is fed to the chat model.
        summary = json.dumps(self.chatbot_context.get("summary")[-1]) if self.chatbot_context.get("summary") else ""

        # visualization = self.chatbot_context.get("visualization")
        # Visualization HTML is intentionally excluded from the chat context.
        visualization = ''

        # NOTE(review): leftover debug prints — consider removing or switching
        # to logging before production.
        print("history:", history)
        print("# ============================================================================\n ")
        print("message:", message)
        print("# ============================================================================\n ")
        print("summary:", summary)
        print("# ============================================================================\n ")
        print("visualization:", visualization)
        print("# ============================================================================\n ")

        ui_updates = tuple([gr.update()] * 9)  # status, 4 outputs, 4 tabs

        command = message
        context = summary + visualization if summary or visualization else ""
        # simple_chat is presumably imported at module level — verify against the
        # file's import block (not visible in this chunk).
        response = simple_chat(command, context)

        history.append((message, response))
        return (history, "") + ui_updates
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# ============================================================================
|
| 419 |
+
# GRADIO INTERFACE
|
| 420 |
+
# ============================================================================
|
| 421 |
+
|
| 422 |
+
def create_interface():
    """Build the Gradio interface.

    Layout: a header row (file upload + data type), then three columns —
    task/parameter controls, tabbed results, and an AI chat panel. All event
    callbacks are methods on a single ``UIManager`` instance; the ``outputs``
    lists below must stay in the exact order the callbacks return their tuples.
    """
    ui_manager = UIManager()
    # Full-viewport flex layout; sticky table headers for data previews.
    custom_css = """
    * { box-sizing: border-box; } html, body { margin: 0; padding: 0; height: 100vh; overflow: hidden; }
    .gradio-container { height: 100vh !important; max-width: 100% !important; padding: 0 !important; }
    #app-container { height: 100vh; display: flex; flex-direction: column; padding: 0.75rem; gap: 0.75rem; }
    #main-row { flex: 1; min-height: 0; display: flex; gap: 0.75rem; }
    #left-panel { display: flex; flex-direction: column; height: 100%; background: #f9fafb; border-radius: 10px; padding: 0.75rem; gap: 0.5rem; }
    #task-section { flex: 1; min-height: 0; overflow-y: auto; display: flex; flex-direction: column; gap: 0.5rem; }
    #middle-panel, #chat-panel { display: flex; flex-direction: column; height: 100%; }
    #tabs-container { flex: 1; min-height: 0; display: flex; flex-direction: column; }
    #tabs-container .tabitem { flex: 1; min-height: 0; overflow: auto; }
    #chat-history { flex: 1; min-height: 0; overflow-y: auto; margin-bottom: 0.5rem; }
    #chat-input-row { flex-shrink: 0; display: flex; gap: 0.5rem; }
    .preview-table { border-collapse: collapse; width: 100%; font-size: 0.875rem; }
    .preview-table th { background-color: #3498db; color: white; padding: 8px; text-align: left; position: sticky; top: 0; }
    """

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Medical Data Analysis Platform") as demo:
        with gr.Column(elem_id="app-container"):
            gr.Markdown("# 🏥 Medical Data Analysis Platform")
            with gr.Row():
                file_input = gr.File(label="Upload CSV", file_types=[".csv"], scale=2)
                data_type = gr.Dropdown(choices=["EHR Data", "ECG Data"], value="EHR Data", label="Data Type", scale=1)

            with gr.Row(elem_id="main-row"):
                # Left column: task checkboxes plus per-task parameter groups
                # (hidden until the matching task is selected).
                with gr.Column(scale=2, elem_id="left-panel"):
                    with gr.Group(elem_id="task-section"):
                        gr.Markdown("#### Analysis Tasks")
                        task_selector = gr.CheckboxGroup(choices=TaskRegistry.get_tasks_for_data_type("EHR Data"), label=None)
                        with gr.Group(visible=False) as ndd_param_group:
                            gr.Markdown("**Near-Duplicate Detection Parameters**")
                            ndd_label_dropdown = gr.Dropdown(choices=[], label="Label Column")
                        with gr.Group(visible=False) as mislabel_param_group:
                            gr.Markdown("**Find Mislabeled Data Parameters**")
                            mislabel_label_dropdown = gr.Dropdown(choices=[], label="Label Column")
                        with gr.Group(visible=False) as ecg_viz_param_group:
                            gr.Markdown("**ECG Visualization Parameters**")
                            ecg_viz_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                            ecg_viz_types = gr.CheckboxGroup(
                                choices=["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"],
                                label="Visualization Types",
                                value=["Signal Waveform", "Histogram"]
                            )
                        with gr.Group(visible=False) as ecg_stats_param_group:
                            gr.Markdown("**Statistical Summary Parameters**")
                            ecg_stats_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                    process_btn = gr.Button("▶ Process", variant="primary")
                    status_output = gr.Textbox(label="Status", interactive=False, lines=2)

                # Middle column: result tabs, initially hidden until analysis runs.
                with gr.Column(scale=7, elem_id="middle-panel"):
                    with gr.Tabs(elem_id="tabs-container"):
                        with gr.TabItem("Original Data", visible=False) as tab_original:
                            original_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Processed Data", visible=False) as tab_processed:
                            processed_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Summary", visible=False) as tab_summary:
                            summary_output = gr.JSON()
                        with gr.TabItem("Visualization", visible=False) as tab_viz:
                            viz_output = gr.HTML()

                # Right column: chat assistant.
                with gr.Column(scale=3, elem_id="chat-panel"):
                    gr.Markdown("### 💬 AI Assistant")
                    chatbot = gr.Chatbot(elem_id="chat-history", height="100%")
                    with gr.Row(elem_id="chat-input-row"):
                        msg_input = gr.Textbox(placeholder="Ask or send a JSON command...", scale=4, container=False)
                        send_btn = gr.Button("Send", scale=1)

            # Shared output list for analysis and chat events — order matters and
            # mirrors the 9-tuples returned by UIManager._run_analysis.
            analysis_outputs = [
                status_output, original_df_output, processed_df_output, summary_output, viz_output,
                tab_original, tab_processed, tab_summary, tab_viz
            ]

            # NOTE(review): ecg_viz_types appears TWICE in this outputs list
            # (there is no separate viz-types component for the stats group), so
            # the second update simply overwrites the first — confirm intended.
            file_input.change(
                fn=ui_manager.on_file_upload, inputs=[file_input, data_type],
                outputs=[status_output, original_df_output, task_selector,
                         ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group,
                         ndd_label_dropdown, mislabel_label_dropdown,
                         ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types,
                         tab_original, tab_processed, tab_summary, tab_viz]
            )
            # NOTE(review): same duplicated ecg_viz_types output slot here.
            data_type.change(
                fn=ui_manager.on_data_type_change, inputs=[data_type, file_input],
                outputs=[task_selector, ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group,
                         ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types, status_output]
            )
            task_selector.change(
                fn=ui_manager.on_tasks_change, inputs=[task_selector],
                outputs=[ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group]
            )
            process_btn.click(
                fn=ui_manager.process_analysis,
                inputs=[file_input, data_type, task_selector, ndd_label_dropdown, mislabel_label_dropdown,
                        ecg_viz_leads, ecg_viz_types, ecg_stats_leads],
                outputs=analysis_outputs
            )

            # Both the Send button and Enter in the textbox trigger the same chat handler.
            chat_submit_args = {"fn": ui_manager.chatbot_respond, "inputs": [msg_input, chatbot], "outputs": [chatbot, msg_input] + analysis_outputs}
            send_btn.click(**chat_submit_args)
            msg_input.submit(**chat_submit_args)

    return demo
| 525 |
+
|
| 526 |
+
# ============================================================================
|
| 527 |
+
# LAUNCH
|
| 528 |
+
# ============================================================================
|
| 529 |
+
|
| 530 |
+
if __name__ == "__main__":
    # Bind on all interfaces on port 7890; share=False keeps the app local
    # (no public Gradio tunnel).
    demo = create_interface()
    demo.launch(share=False, server_name="0.0.0.0", server_port=7890)
|
app_studio.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Optional, Tuple, List, Dict, Any
|
| 5 |
+
import io
|
| 6 |
+
import base64
|
| 7 |
+
from tqdm.auto import tqdm
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
from pipeline.deduplication import find_near_duplicates
|
| 13 |
+
from pipeline.featurizer import custom_featurizer
|
| 14 |
+
from pipeline.issues import find_issues
|
| 15 |
+
from pipeline.pipeline import make_step, run_pipeline
|
| 16 |
+
|
| 17 |
+
from ecg_analyzer import ECGAnalyzer
|
| 18 |
+
|
| 19 |
+
# ============================================================================
|
| 20 |
+
# ANALYSIS TASK CONFIGURATION
|
| 21 |
+
# ============================================================================
|
| 22 |
+
|
| 23 |
+
@dataclass
class TaskConfig:
    """Configuration for each analysis task."""
    # Display name of the task (matches the checkbox label in the UI).
    name: str
    # Data category this task applies to ("EHR Data" or "ECG Data").
    data_type: str
    # Whether the task exposes a parameter group in the left panel.
    requires_params: bool
    # Declarative descriptions of the parameter widgets
    # (dicts with "type"/"label"/"elem_id" keys).
    param_components: List[Dict[str, Any]]
    # Which result tabs this task populates
    # (subset of: "original", "processed", "summary", "visualization").
    output_tabs: List[str]
|
| 31 |
+
|
| 32 |
+
class TaskRegistry:
|
| 33 |
+
"""Registry mapping tasks to their configurations"""
|
| 34 |
+
|
| 35 |
+
@staticmethod
|
| 36 |
+
def get_config(data_type: str, task_name: str) -> Optional[TaskConfig]:
|
| 37 |
+
"""Get configuration for a specific task"""
|
| 38 |
+
configs = {
|
| 39 |
+
"EHR Data": {
|
| 40 |
+
"Near-Duplicate Detection": TaskConfig(
|
| 41 |
+
name="Near-Duplicate Detection",
|
| 42 |
+
data_type="EHR Data",
|
| 43 |
+
requires_params=True,
|
| 44 |
+
param_components=[
|
| 45 |
+
{"type": "dropdown", "label": "Label Column", "elem_id": "ndd_label"}
|
| 46 |
+
],
|
| 47 |
+
output_tabs=["original", "processed", "summary"]
|
| 48 |
+
),
|
| 49 |
+
"Find Mislabeled Data": TaskConfig(
|
| 50 |
+
name="Find Mislabeled Data",
|
| 51 |
+
data_type="EHR Data",
|
| 52 |
+
requires_params=True,
|
| 53 |
+
param_components=[
|
| 54 |
+
{"type": "dropdown", "label": "Label Column", "elem_id": "mislabel_label"}
|
| 55 |
+
],
|
| 56 |
+
output_tabs=["original", "summary"]
|
| 57 |
+
)
|
| 58 |
+
},
|
| 59 |
+
"ECG Data": {
|
| 60 |
+
"ECG Visualization": TaskConfig(
|
| 61 |
+
name="ECG Visualization",
|
| 62 |
+
data_type="ECG Data",
|
| 63 |
+
requires_params=True,
|
| 64 |
+
param_components=[
|
| 65 |
+
{"type": "checkboxgroup", "label": "Select Leads", "elem_id": "ecg_leads"},
|
| 66 |
+
{"type": "checkboxgroup", "label": "Visualization Types", "elem_id": "ecg_viz_types"}
|
| 67 |
+
],
|
| 68 |
+
output_tabs=["visualization", "summary"]
|
| 69 |
+
),
|
| 70 |
+
"Statistical Summary": TaskConfig(
|
| 71 |
+
name="Statistical Summary",
|
| 72 |
+
data_type="ECG Data",
|
| 73 |
+
requires_params=True,
|
| 74 |
+
param_components=[
|
| 75 |
+
{"type": "checkboxgroup", "label": "Select Leads", "elem_id": "ecg_stats_leads"}
|
| 76 |
+
],
|
| 77 |
+
output_tabs=["summary", "visualization"]
|
| 78 |
+
)
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
return configs.get(data_type, {}).get(task_name)
|
| 82 |
+
|
| 83 |
+
@staticmethod
|
| 84 |
+
def get_tasks_for_data_type(data_type: str) -> List[str]:
|
| 85 |
+
"""Get available tasks for a data type"""
|
| 86 |
+
tasks = {
|
| 87 |
+
"EHR Data": ["Near-Duplicate Detection", "Find Mislabeled Data"],
|
| 88 |
+
"ECG Data": ["ECG Visualization", "Statistical Summary"]
|
| 89 |
+
}
|
| 90 |
+
return tasks.get(data_type, [])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ============================================================================
|
| 94 |
+
# ANALYSIS EXECUTION
|
| 95 |
+
# ============================================================================
|
| 96 |
+
|
| 97 |
+
class AnalysisExecutor:
|
| 98 |
+
"""Executes analysis tasks and returns results"""
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def execute_near_duplicate_detection(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
|
| 102 |
+
"""Execute near-duplicate detection pipeline"""
|
| 103 |
+
try:
|
| 104 |
+
if not label:
|
| 105 |
+
return "⚠ Label column required", {"original": df, "processed": None, "summary": None}
|
| 106 |
+
|
| 107 |
+
bar = tqdm(total=100, leave=False, desc="Pipeline Progress")
|
| 108 |
+
steps = [
|
| 109 |
+
make_step(find_near_duplicates, name="dedup")(progress=bar),
|
| 110 |
+
make_step(custom_featurizer, name="featurize")(
|
| 111 |
+
label=label, nan_strategy="impute", on_pipeline_error="drop", progress=bar
|
| 112 |
+
),
|
| 113 |
+
make_step(find_issues, name="find_label_issues")(label=label, progress=bar),
|
| 114 |
+
]
|
| 115 |
+
results_df, summary_list = run_pipeline(steps, df=df)
|
| 116 |
+
bar.close()
|
| 117 |
+
|
| 118 |
+
return "✓ Near-duplicate detection completed", {
|
| 119 |
+
"original": df, "processed": results_df, "summary": summary_list
|
| 120 |
+
}
|
| 121 |
+
except Exception as e:
|
| 122 |
+
return f"✗ Error: {str(e)}", {"original": df, "processed": None, "summary": None}
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
def execute_find_mislabeled(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
|
| 126 |
+
"""Execute mislabeled data detection"""
|
| 127 |
+
try:
|
| 128 |
+
if not label:
|
| 129 |
+
return "⚠ Label column required", {"original": df, "summary": None}
|
| 130 |
+
|
| 131 |
+
summary = {
|
| 132 |
+
"task": "Find Mislabeled Data", "label_column": label, "total_samples": len(df),
|
| 133 |
+
"suspicious_samples": 0, "message": "Mislabeled detection analysis completed"
|
| 134 |
+
}
|
| 135 |
+
return "✓ Mislabeled data analysis completed", {"original": df, "summary": summary}
|
| 136 |
+
except Exception as e:
|
| 137 |
+
return f"✗ Error: {str(e)}", {"original": df, "summary": None}
|
| 138 |
+
|
| 139 |
+
@staticmethod
|
| 140 |
+
def execute_ecg_visualization(df: pd.DataFrame, leads: List[str] = None, viz_types: List[str] = None) -> Tuple[str, Dict[str, Any]]:
|
| 141 |
+
"""Execute ECG visualization using ECGAnalyzer"""
|
| 142 |
+
try:
|
| 143 |
+
# Detect available leads
|
| 144 |
+
available_leads = ECGAnalyzer.detect_leads(df)
|
| 145 |
+
|
| 146 |
+
# Use provided leads or default to all available
|
| 147 |
+
if not leads:
|
| 148 |
+
leads = available_leads if available_leads else []
|
| 149 |
+
|
| 150 |
+
if not leads:
|
| 151 |
+
return "⚠ No ECG leads found in data", {"visualization": None, "summary": None}
|
| 152 |
+
|
| 153 |
+
# Default visualization types
|
| 154 |
+
if not viz_types:
|
| 155 |
+
viz_types = ["Signal Waveform", "Histogram"]
|
| 156 |
+
|
| 157 |
+
# Create visualizations
|
| 158 |
+
viz_html = ECGAnalyzer.create_all_visualizations(df, leads, viz_types)
|
| 159 |
+
|
| 160 |
+
# Generate statistics
|
| 161 |
+
stats = ECGAnalyzer.generate_statistics(df, leads)
|
| 162 |
+
|
| 163 |
+
summary = {
|
| 164 |
+
"task": "ECG Visualization",
|
| 165 |
+
"samples": len(df),
|
| 166 |
+
"leads_analyzed": leads,
|
| 167 |
+
"visualizations": viz_types,
|
| 168 |
+
"statistics": stats
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
return "✓ ECG visualization created", {"visualization": viz_html, "summary": summary}
|
| 172 |
+
except Exception as e:
|
| 173 |
+
return f"✗ Error: {str(e)}", {"visualization": None, "summary": None}
|
| 174 |
+
|
| 175 |
+
@staticmethod
|
| 176 |
+
def execute_statistical_summary(df: pd.DataFrame, leads: List[str] = None) -> Tuple[str, Dict[str, Any]]:
|
| 177 |
+
"""Execute statistical summary using ECGAnalyzer"""
|
| 178 |
+
try:
|
| 179 |
+
# Detect available leads
|
| 180 |
+
available_leads = ECGAnalyzer.detect_leads(df)
|
| 181 |
+
|
| 182 |
+
# Use provided leads or default to all available
|
| 183 |
+
if not leads:
|
| 184 |
+
leads = available_leads if available_leads else list(df.select_dtypes(include=[np.number]).columns)
|
| 185 |
+
|
| 186 |
+
if not leads:
|
| 187 |
+
return "⚠ No numeric columns found", {"summary": None, "visualization": None}
|
| 188 |
+
|
| 189 |
+
# Generate statistics
|
| 190 |
+
stats = ECGAnalyzer.generate_statistics(df, leads)
|
| 191 |
+
|
| 192 |
+
# Create HTML table for statistics
|
| 193 |
+
html_rows = []
|
| 194 |
+
html_rows.append("<table class='preview-table' style='margin: 20px auto; max-width: 900px;'>")
|
| 195 |
+
html_rows.append("<thead><tr><th>Lead</th><th>Mean</th><th>Std</th><th>Min</th><th>Q25</th><th>Median</th><th>Q75</th><th>Max</th></tr></thead>")
|
| 196 |
+
html_rows.append("<tbody>")
|
| 197 |
+
|
| 198 |
+
for lead, lead_stats in stats.items():
|
| 199 |
+
html_rows.append(f"<tr>")
|
| 200 |
+
html_rows.append(f"<td><strong>{lead}</strong></td>")
|
| 201 |
+
html_rows.append(f"<td>{lead_stats['mean']:.4f}</td>")
|
| 202 |
+
html_rows.append(f"<td>{lead_stats['std']:.4f}</td>")
|
| 203 |
+
html_rows.append(f"<td>{lead_stats['min']:.4f}</td>")
|
| 204 |
+
html_rows.append(f"<td>{lead_stats['q25']:.4f}</td>")
|
| 205 |
+
html_rows.append(f"<td>{lead_stats['median']:.4f}</td>")
|
| 206 |
+
html_rows.append(f"<td>{lead_stats['q75']:.4f}</td>")
|
| 207 |
+
html_rows.append(f"<td>{lead_stats['max']:.4f}</td>")
|
| 208 |
+
html_rows.append(f"</tr>")
|
| 209 |
+
|
| 210 |
+
html_rows.append("</tbody></table>")
|
| 211 |
+
summary_html = f"<div style='overflow-x:auto;'><h3 style='text-align:center;'>Statistical Summary</h3>{''.join(html_rows)}</div>"
|
| 212 |
+
|
| 213 |
+
summary = {
|
| 214 |
+
"task": "Statistical Summary",
|
| 215 |
+
"rows": len(df),
|
| 216 |
+
"leads_analyzed": leads,
|
| 217 |
+
"statistics": stats
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
return "✓ Statistical summary generated", {"summary": summary, "visualization": summary_html}
|
| 221 |
+
except Exception as e:
|
| 222 |
+
return f"✗ Error: {str(e)}", {"summary": None, "visualization": None}
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# ============================================================================
|
| 226 |
+
# UI MANAGER - Handles all UI state and updates
|
| 227 |
+
# ============================================================================
|
| 228 |
+
|
| 229 |
+
class UIManager:
    """Manages UI state and dynamic updates.

    Holds the currently loaded DataFrame and data-type selection, translates
    chatbot JSON commands into analysis runs, and produces the ``gr.update``
    tuples consumed by the event handlers wired up in ``create_interface()``.
    Tuple ordering in every handler MUST match the ``outputs`` lists there.
    """

    def __init__(self):
        # Most recently loaded DataFrame (None until a CSV is uploaded).
        self.current_df = None
        # Mirrors the "Data Type" dropdown in the UI.
        self.current_data_type = "EHR Data"
        # Context shared with the chatbot; populated on upload with
        # {"file": path, "type": data_type, "df": DataFrame}.
        self.chatbot_context = {}
        # Maps short chatbot-command tokens to canonical UI names.
        self.command_map = {
            "data_type": {"ehr": "EHR Data", "ecg": "ECG Data"},
            "task": {
                "deduplication": "Near-Duplicate Detection", "mislabeled": "Find Mislabeled Data",
                "visualize_ecg": "ECG Visualization", "stats": "Statistical Summary"
            }
        }

    def load_csv(self, file) -> Tuple[str, Optional[pd.DataFrame]]:
        """Load the uploaded CSV and cache it on self; returns (status, df-or-None)."""
        if file is None: return "⚠ No file uploaded", None
        try:
            df = pd.read_csv(file.name)
            self.current_df = df
            return f"✓ Loaded {len(df)} rows, {len(df.columns)} columns", df
        except Exception as e:
            return f"✗ Error: {str(e)}", None

    def on_file_upload(self, file, data_type: str):
        """Handle file upload - returns updates for all components.

        Returns a 17-tuple matching the ``file_input.change`` outputs list:
        status, data preview, task selector, 4 parameter-group visibilities,
        2 label dropdowns, 4 ECG lead/viz-type choice updates, 4 tab visibilities.
        """
        status, df = self.load_csv(file)
        if df is None:
            # Load failed: clear the preview/task UI and hide all panels and tabs.
            return (
                status, gr.update(value=None), gr.update(choices=[], value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                gr.update(choices=[]), gr.update(choices=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            )

        self.chatbot_context = {"file": file.name, "type": data_type, "df": df}
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)
        col_choices = list(df.columns)

        # Detect ECG leads only when the ECG data type is selected.
        ecg_leads = ECGAnalyzer.detect_leads(df) if data_type == "ECG Data" else []
        viz_types = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

        # Preview capped at 200 rows to keep the DataFrame component responsive.
        # NOTE(review): the wiring in create_interface() lists ecg_viz_types
        # twice in this handler's outputs, so the later update (value
        # ["Signal Waveform"]) overrides the earlier one — confirm intent.
        return (
            status, gr.update(value=df.head(200)), gr.update(choices=available_tasks, value=[], interactive=True),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(choices=col_choices, value=None), gr.update(choices=col_choices, value=None),
            gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
            gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform"]),
            gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
        )

    def on_data_type_change(self, data_type: str, file):
        """Handle data type change: refresh task list and ECG parameter choices."""
        self.current_data_type = data_type
        if file and self.current_df is not None:
            self.chatbot_context["type"] = data_type

            # Re-detect ECG lead choices when switching to ECG data.
            ecg_leads = ECGAnalyzer.detect_leads(self.current_df) if data_type == "ECG Data" else []
            viz_types = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

            return (
                gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
                gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform"]),
                f"Data type changed to: {data_type}"
            )

        # No data loaded yet: only swap the task list; leave ECG widgets untouched.
        return (
            gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[]),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(), gr.update(), gr.update(), gr.update(),
            f"Data type changed to: {data_type}"
        )

    def on_tasks_change(self, selected_tasks: List[str]):
        """Handle task selection change - show/hide the four parameter groups."""
        show_ndd = "Near-Duplicate Detection" in selected_tasks
        show_mislabel = "Find Mislabeled Data" in selected_tasks
        show_ecg_viz = "ECG Visualization" in selected_tasks
        # Stats parameters only make sense for ECG data (lead selection).
        show_ecg_stats = "Statistical Summary" in selected_tasks and self.current_data_type == "ECG Data"
        return (
            gr.update(visible=show_ndd),
            gr.update(visible=show_mislabel),
            gr.update(visible=show_ecg_viz),
            gr.update(visible=show_ecg_stats)
        )

    def process_analysis(self, file, data_type: str, selected_tasks: List[str],
                         ndd_label: str, mislabel_label: str,
                         ecg_viz_leads: List[str], ecg_viz_types: List[str],
                         ecg_stats_leads: List[str]):
        """Process analysis tasks based on UI inputs (the "Process" button handler).

        Re-reads the CSV (rather than trusting cached state) so stale uploads
        cannot be analyzed, then delegates to _run_analysis.
        """
        status, df = self.load_csv(file)
        if df is None:
            return (status, None, None, None, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False))

        params = {
            "ndd_label": ndd_label,
            "mislabel_label": mislabel_label,
            "ecg_viz_leads": ecg_viz_leads,
            "ecg_viz_types": ecg_viz_types,
            "ecg_stats_leads": ecg_stats_leads
        }
        return self._run_analysis(df, data_type, selected_tasks, params)

    def _run_analysis(self, df: pd.DataFrame, data_type: str, selected_tasks: List[str], params: Dict[str, Any]):
        """Centralized analysis executor, callable from UI or chatbot.

        Runs every selected task, merges their results, and returns the 9-tuple
        (status, original, processed, summary, visualization, 4 tab visibilities)
        expected by the ``analysis_outputs`` wiring.
        """
        if not selected_tasks:
            return (
                "⚠ No tasks selected", df.head(200), None, None, None,
                gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
            )

        # Accumulators: which tabs to reveal and the merged task outputs.
        all_tabs = set(); all_results = {"original": df.head(200), "processed": None, "summary": [], "visualization": ""}
        status_messages = []; executor = AnalysisExecutor()

        for task_name in selected_tasks:
            config = TaskRegistry.get_config(data_type, task_name)
            if not config:
                status_messages.append(f"✗ Unknown task: {task_name}"); continue
            all_tabs.update(config.output_tabs)

            # Dispatch to the matching executor; each returns (status, results-dict).
            if task_name == "Near-Duplicate Detection":
                status_msg, results = executor.execute_near_duplicate_detection(df, params.get("ndd_label"))
            elif task_name == "Find Mislabeled Data":
                status_msg, results = executor.execute_find_mislabeled(df, params.get("mislabel_label"))
            elif task_name == "ECG Visualization":
                status_msg, results = executor.execute_ecg_visualization(
                    df,
                    params.get("ecg_viz_leads"),
                    params.get("ecg_viz_types")
                )
            elif task_name == "Statistical Summary":
                if data_type == "ECG Data":
                    status_msg, results = executor.execute_statistical_summary(df, params.get("ecg_stats_leads"))
                else:
                    # EHR data has no lead selection; summarize all numeric columns.
                    status_msg, results = executor.execute_statistical_summary(df)
            else:
                status_msg, results = "✗ Task not implemented", {}

            status_messages.append(f"{task_name}: {status_msg}")
            # Merge: last processed df wins, HTML is concatenated, summaries are listed.
            if results.get("processed") is not None: all_results["processed"] = results["processed"]
            if results.get("visualization"): all_results["visualization"] += results["visualization"]
            if results.get("summary") is not None: all_results["summary"].append({"task": task_name, "data": results["summary"]})

        return (
            "\n".join(status_messages), all_results["original"], all_results["processed"],
            all_results["summary"] or None, all_results["visualization"] or None,
            gr.update(visible="original" in all_tabs), gr.update(visible="processed" in all_tabs),
            gr.update(visible="summary" in all_tabs), gr.update(visible="visualization" in all_tabs)
        )

    def chatbot_respond(self, message: str, history: List):
        """Handle chatbot messages, parsing for commands or responding to queries.

        If *message* is a JSON object with a "task" key it is executed through
        _run_analysis and the resulting UI updates are forwarded; otherwise a
        small keyword-based Q&A answers about rows/columns/help.
        Returns (history, cleared_input) + 9 analysis/tab updates.
        NOTE(review): history entries use the legacy (user, bot) tuple format
        of pre-4.x gr.Chatbot — confirm against the installed Gradio version.
        """
        history = history or []; df = self.chatbot_context.get("df")
        ui_updates = tuple([gr.update()] * 9)  # status, 4 outputs, 4 tabs

        try:
            command = json.loads(message)
            if isinstance(command, dict) and "task" in command:
                if df is None:
                    history.append((message, "Please upload a data file before running a command."))
                    return (history, "") + ui_updates

                # Translate short command tokens to canonical names; fall back
                # to the currently selected data type when none is given.
                data_type = self.command_map["data_type"].get(command.get("data_type", "").lower(), self.current_data_type)
                task_name = self.command_map["task"].get(command.get("task", "").lower())

                if not task_name:
                    response = f"Unknown task: '{command['task']}'. Valid: {list(self.command_map['task'].keys())}"
                    history.append((message, response))
                    return (history, "") + ui_updates

                params = command.get("parameters", {})
                # A single "label" parameter serves both label-based tasks.
                analysis_params = {"ndd_label": params.get("label"), "mislabel_label": params.get("label")}

                analysis_updates = self._run_analysis(df, data_type, [task_name], analysis_params)
                history.append((message, f"Command executed: Running '{task_name}'."))
                return (history, "") + analysis_updates
        except (json.JSONDecodeError, TypeError):
            pass  # Not a JSON command, proceed with standard logic

        # Keyword-based Q&A fallback for plain-text messages.
        if "column" in message.lower():
            response = (f"Dataset has {len(df.columns)} columns: {', '.join(map(str, df.columns))}" if df is not None else "Please upload a file first.")
        elif "row" in message.lower():
            response = f"Dataset has {len(df)} rows." if df is not None else "Please upload a file first."
        elif "help" in message.lower():
            response = "Ask about 'columns' or 'rows'. To run a task, send JSON, e.g., `{\"task\": \"stats\"}` or `{\"task\": \"deduplication\", \"parameters\": {\"label\": \"your_column\"}}`"
        else:
            response = "I can help with data queries or run tasks via JSON commands. Try asking 'help'."

        history.append((message, response))
        return (history, "") + ui_updates
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
# ============================================================================
|
| 430 |
+
# GRADIO INTERFACE
|
| 431 |
+
# ============================================================================
|
| 432 |
+
|
| 433 |
+
def create_interface():
    """Build the Gradio interface.

    Lays out a three-column Blocks app (task panel / results tabs / chatbot)
    and wires the UIManager handlers. Handler return tuples are positional, so
    every ``outputs`` list here must match the corresponding handler exactly.
    """
    ui_manager = UIManager()
    # Full-viewport flexbox layout; .preview-table styles the stats HTML table.
    custom_css = """
    * { box-sizing: border-box; } html, body { margin: 0; padding: 0; height: 100vh; overflow: hidden; }
    .gradio-container { height: 100vh !important; max-width: 100% !important; padding: 0 !important; }
    #app-container { height: 100vh; display: flex; flex-direction: column; padding: 0.75rem; gap: 0.75rem; }
    #main-row { flex: 1; min-height: 0; display: flex; gap: 0.75rem; }
    #left-panel { display: flex; flex-direction: column; height: 100%; background: #f9fafb; border-radius: 10px; padding: 0.75rem; gap: 0.5rem; }
    #task-section { flex: 1; min-height: 0; overflow-y: auto; display: flex; flex-direction: column; gap: 0.5rem; }
    #middle-panel, #chat-panel { display: flex; flex-direction: column; height: 100%; }
    #tabs-container { flex: 1; min-height: 0; display: flex; flex-direction: column; }
    #tabs-container .tabitem { flex: 1; min-height: 0; overflow: auto; }
    #chat-history { flex: 1; min-height: 0; overflow-y: auto; margin-bottom: 0.5rem; }
    #chat-input-row { flex-shrink: 0; display: flex; gap: 0.5rem; }
    .preview-table { border-collapse: collapse; width: 100%; font-size: 0.875rem; }
    .preview-table th { background-color: #3498db; color: white; padding: 8px; text-align: left; position: sticky; top: 0; }
    """

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Medical Data Analysis Platform") as demo:
        with gr.Column(elem_id="app-container"):
            gr.Markdown("# 🏥 Medical Data Analysis Platform")
            with gr.Row():
                file_input = gr.File(label="Upload CSV", file_types=[".csv"], scale=2)
                data_type = gr.Dropdown(choices=["EHR Data", "ECG Data"], value="EHR Data", label="Data Type", scale=1)

            with gr.Row(elem_id="main-row"):
                # -- Left panel: task selection + per-task parameter groups --
                with gr.Column(scale=2, elem_id="left-panel"):
                    with gr.Group(elem_id="task-section"):
                        gr.Markdown("#### Analysis Tasks")
                        task_selector = gr.CheckboxGroup(choices=TaskRegistry.get_tasks_for_data_type("EHR Data"), label=None)
                        # Parameter groups start hidden; on_tasks_change toggles them.
                        with gr.Group(visible=False) as ndd_param_group:
                            gr.Markdown("**Near-Duplicate Detection Parameters**")
                            ndd_label_dropdown = gr.Dropdown(choices=[], label="Label Column")
                        with gr.Group(visible=False) as mislabel_param_group:
                            gr.Markdown("**Find Mislabeled Data Parameters**")
                            mislabel_label_dropdown = gr.Dropdown(choices=[], label="Label Column")
                        with gr.Group(visible=False) as ecg_viz_param_group:
                            gr.Markdown("**ECG Visualization Parameters**")
                            ecg_viz_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                            ecg_viz_types = gr.CheckboxGroup(
                                choices=["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"],
                                label="Visualization Types",
                                value=["Signal Waveform", "Histogram"]
                            )
                        with gr.Group(visible=False) as ecg_stats_param_group:
                            gr.Markdown("**Statistical Summary Parameters**")
                            ecg_stats_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                    process_btn = gr.Button("▶ Process", variant="primary")
                    status_output = gr.Textbox(label="Status", interactive=False, lines=2)

                # -- Middle panel: result tabs (hidden until an analysis produces output) --
                with gr.Column(scale=7, elem_id="middle-panel"):
                    with gr.Tabs(elem_id="tabs-container"):
                        with gr.TabItem("Original Data", visible=False) as tab_original:
                            original_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Processed Data", visible=False) as tab_processed:
                            processed_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Summary", visible=False) as tab_summary:
                            summary_output = gr.JSON()
                        with gr.TabItem("Visualization", visible=False) as tab_viz:
                            viz_output = gr.HTML()

                # -- Right panel: chatbot (plain Q&A or JSON commands) --
                with gr.Column(scale=3, elem_id="chat-panel"):
                    gr.Markdown("### 💬 AI Assistant")
                    chatbot = gr.Chatbot(elem_id="chat-history", height="100%")
                    with gr.Row(elem_id="chat-input-row"):
                        msg_input = gr.Textbox(placeholder="Ask or send a JSON command...", scale=4, container=False)
                        send_btn = gr.Button("Send", scale=1)

            # Shared output list for _run_analysis-shaped 9-tuples.
            analysis_outputs = [
                status_output, original_df_output, processed_df_output, summary_output, viz_output,
                tab_original, tab_processed, tab_summary, tab_viz
            ]

            # NOTE(review): ecg_viz_types appears twice in this outputs list
            # (and in data_type.change below); the second returned update
            # overrides the first, leaving value ["Signal Waveform"]. Confirm
            # whether the duplicate should be a different component.
            file_input.change(
                fn=ui_manager.on_file_upload, inputs=[file_input, data_type],
                outputs=[status_output, original_df_output, task_selector,
                         ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group,
                         ndd_label_dropdown, mislabel_label_dropdown,
                         ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types,
                         tab_original, tab_processed, tab_summary, tab_viz]
            )
            data_type.change(
                fn=ui_manager.on_data_type_change, inputs=[data_type, file_input],
                outputs=[task_selector, ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group,
                         ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types, status_output]
            )
            task_selector.change(
                fn=ui_manager.on_tasks_change, inputs=[task_selector],
                outputs=[ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group]
            )
            process_btn.click(
                fn=ui_manager.process_analysis,
                inputs=[file_input, data_type, task_selector, ndd_label_dropdown, mislabel_label_dropdown,
                        ecg_viz_leads, ecg_viz_types, ecg_stats_leads],
                outputs=analysis_outputs
            )

            # Chat submits via button click or Enter; both use identical wiring.
            chat_submit_args = {"fn": ui_manager.chatbot_respond, "inputs": [msg_input, chatbot], "outputs": [chatbot, msg_input] + analysis_outputs}
            send_btn.click(**chat_submit_args)
            msg_input.submit(**chat_submit_args)

    return demo
|
| 536 |
+
|
| 537 |
+
# ============================================================================
|
| 538 |
+
# LAUNCH
|
| 539 |
+
# ============================================================================
|
| 540 |
+
|
| 541 |
+
if __name__ == "__main__":
    demo = create_interface()
    # Bind to all interfaces so the app is reachable inside a container;
    # share=False keeps it off Gradio's public tunnel.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7890)
|
app_test.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ADDING NEW ANALYSIS TASKS:
|
| 3 |
+
==========================
|
| 4 |
+
|
| 5 |
+
1. Add task configuration in TaskRegistry.get_config():
|
| 6 |
+
"Your Data Type": {
|
| 7 |
+
"Your Task Name": TaskConfig(
|
| 8 |
+
name="Your Task Name",
|
| 9 |
+
data_type="Your Data Type",
|
| 10 |
+
requires_params=True, # Set to True if needs parameters
|
| 11 |
+
param_components=[...], # Define parameter types
|
| 12 |
+
output_tabs=["original", "summary", ...] # Which tabs to show
|
| 13 |
+
)
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
2. Add execution method in AnalysisExecutor:
|
| 17 |
+
@staticmethod
|
| 18 |
+
def execute_your_task(df, param1, param2):
|
| 19 |
+
# Your analysis logic
|
| 20 |
+
return "✓ Done", {
|
| 21 |
+
"original": df,
|
| 22 |
+
"summary": summary_data,
|
| 23 |
+
...
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
3. In create_interface(), add parameter group in the LEFT panel:
|
| 27 |
+
with gr.Group(visible=False) as your_task_param_group:
|
| 28 |
+
gr.Markdown("**Your Task Parameters**")
|
| 29 |
+
your_param1 = gr.Slider(minimum=0, maximum=1, label="Threshold")
|
| 30 |
+
your_param2 = gr.Dropdown(choices=["A", "B"], label="Option")
|
| 31 |
+
|
| 32 |
+
4. In UIManager.on_tasks_change()
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import pandas as pd
|
| 36 |
+
import matplotlib.pyplot as plt
|
| 37 |
+
import numpy as np
|
| 38 |
+
from typing import Optional, Tuple, List, Dict, Any
|
| 39 |
+
import io
|
| 40 |
+
import base64
|
| 41 |
+
from tqdm.auto import tqdm
|
| 42 |
+
from dataclasses import dataclass
|
| 43 |
+
import gradio as gr
|
| 44 |
+
# Direct imports (your existing modules)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
from pipeline.deduplication import find_near_duplicates
|
| 48 |
+
from pipeline.featurizer import custom_featurizer
|
| 49 |
+
from pipeline.issues import find_issues
|
| 50 |
+
from pipeline.pipeline import make_step, run_pipeline
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
from pipeline import make_step, run_pipeline
|
| 55 |
+
|
| 56 |
+
# ============================================================================
|
| 57 |
+
# ANALYSIS TASK CONFIGURATION
|
| 58 |
+
# ============================================================================
|
| 59 |
+
|
| 60 |
+
@dataclass
class TaskConfig:
    """Configuration for each analysis task.

    Instances are created in TaskRegistry.get_config(); `output_tabs` drives
    which result tabs the UI reveals after the task runs.
    """
    name: str  # display name; also the lookup key inside TaskRegistry
    data_type: str  # owning data type, e.g. "EHR Data" or "ECG Data"
    requires_params: bool  # Does it need additional parameters?
    param_components: List[Dict[str, Any]]  # List of parameter UI components
    output_tabs: List[str]  # Which tabs to show: "original", "processed", "summary", "visualization"
|
| 68 |
+
|
| 69 |
+
class TaskRegistry:
    """Registry mapping (data type, task name) pairs to task configurations."""

    # Task names per data type — the single source of truth for the task lists.
    _TASKS: Dict[str, List[str]] = {
        "EHR Data": ["Near-Duplicate Detection", "Find Mislabeled Data"],
        "ECG Data": ["ECG Visualization", "Statistical Summary"],
    }

    # Lazily built {data_type: {task_name: TaskConfig}} table. Built on first
    # get_config() call instead of on every call (the original rebuilt all
    # TaskConfig objects per lookup).
    _configs: Optional[Dict[str, Dict[str, "TaskConfig"]]] = None

    @staticmethod
    def _build_configs() -> Dict[str, Dict[str, "TaskConfig"]]:
        """Construct the full configuration table (called once, then cached)."""
        return {
            "EHR Data": {
                "Near-Duplicate Detection": TaskConfig(
                    name="Near-Duplicate Detection",
                    data_type="EHR Data",
                    requires_params=True,
                    param_components=[
                        {"type": "dropdown", "label": "Label Column", "elem_id": "ndd_label"}
                    ],
                    output_tabs=["original", "processed", "summary"]
                ),
                "Find Mislabeled Data": TaskConfig(
                    name="Find Mislabeled Data",
                    data_type="EHR Data",
                    requires_params=True,
                    param_components=[
                        {"type": "dropdown", "label": "Label Column", "elem_id": "mislabel_label"}
                    ],
                    output_tabs=["original", "summary"]
                )
            },
            "ECG Data": {
                "ECG Visualization": TaskConfig(
                    name="ECG Visualization",
                    data_type="ECG Data",
                    requires_params=False,
                    param_components=[],
                    output_tabs=["visualization", "summary"]
                ),
                "Statistical Summary": TaskConfig(
                    name="Statistical Summary",
                    data_type="ECG Data",
                    requires_params=False,
                    param_components=[],
                    output_tabs=["summary"]
                )
            }
        }

    @staticmethod
    def get_config(data_type: str, task_name: str) -> Optional["TaskConfig"]:
        """Get configuration for a specific task, or None when unknown."""
        if TaskRegistry._configs is None:
            TaskRegistry._configs = TaskRegistry._build_configs()
        return TaskRegistry._configs.get(data_type, {}).get(task_name)

    @staticmethod
    def get_tasks_for_data_type(data_type: str) -> List[str]:
        """Get available task names for a data type (empty list when unknown).

        Returns a copy so callers cannot mutate the registry's task lists.
        """
        return list(TaskRegistry._TASKS.get(data_type, []))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ============================================================================
|
| 126 |
+
# ANALYSIS EXECUTION
|
| 127 |
+
# ============================================================================
|
| 128 |
+
|
| 129 |
+
class AnalysisExecutor:
    """Executes analysis tasks and returns results.

    Each ``execute_*`` method is a static function taking the loaded
    DataFrame (plus any task parameters) and returning a tuple of
    ``(status_message, results_dict)``.  The results dict is keyed by tab
    name ("original", "processed", "summary", "visualization") so the UI
    layer can route each piece to the matching output component.  Errors
    never propagate: every method catches ``Exception`` and reports it via
    a "✗ Error: ..." status string with ``None`` result slots.
    """

    @staticmethod
    def execute_near_duplicate_detection(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Run the dedup → featurize → find-issues pipeline on an EHR frame.

        Args:
            df: Input EHR dataframe.
            label: Name of the label column; required (empty → warning status).

        Returns:
            (status message, {"original", "processed", "summary"} dict).
        """
        if not label:
            return "⚠ Label column required", {
                "original": df,
                "processed": None,
                "summary": None
            }

        bar = tqdm(total=100, leave=False, desc="Pipeline Progress")
        try:
            steps = [
                make_step(find_near_duplicates, name="dedup")(progress=bar),
                make_step(custom_featurizer, name="featurize")(
                    label=label, nan_strategy="impute", on_pipeline_error="drop", progress=bar
                ),
                make_step(find_issues, name="find_label_issues")(label=label, progress=bar),
            ]
            results_df, summary_list = run_pipeline(steps, df=df)

            return "✓ Near-duplicate detection completed", {
                "original": df,
                "processed": results_df,
                "summary": summary_list
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {
                "original": df,
                "processed": None,
                "summary": None
            }
        finally:
            # Bug fix: the original only closed the bar on success, leaking a
            # dangling tqdm instance whenever the pipeline raised.
            bar.close()

    @staticmethod
    def execute_find_mislabeled(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Execute mislabeled data detection.

        Currently a placeholder: it validates the label column and returns a
        summary stub with ``suspicious_samples`` fixed at 0.

        Args:
            df: Input EHR dataframe.
            label: Name of the label column; required (empty → warning status).

        Returns:
            (status message, {"original", "summary"} dict).
        """
        try:
            if not label:
                return "⚠ Label column required", {
                    "original": df,
                    "summary": None
                }

            # Placeholder for actual mislabeled detection logic
            summary = {
                "task": "Find Mislabeled Data",
                "label_column": label,
                "total_samples": len(df),
                "suspicious_samples": 0,
                "message": "Mislabeled detection analysis completed"
            }

            return "✓ Mislabeled data analysis completed", {
                "original": df,
                "summary": summary
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {
                "original": df,
                "summary": None
            }

    @staticmethod
    def execute_ecg_visualization(df: pd.DataFrame) -> Tuple[str, Dict[str, Any]]:
        """Render ECG signal(s) as an inline base64 ``<img>`` snippet.

        With more than one column, the first column is used as the x axis
        (assumed to be time — TODO confirm against the loaded CSV layout)
        and every remaining column is plotted as a channel; with a single
        column, it is plotted against the row index.

        Returns:
            (status message, {"visualization": html, "summary": dict}).
        """
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(12, 6))
            if len(df.columns) > 1:
                for col in df.columns[1:]:
                    ax.plot(df.iloc[:, 0], df[col], label=str(col), linewidth=0.8)
                ax.set_xlabel('Time (ms)')
                ax.set_ylabel('Amplitude (mV)')
                ax.set_title('ECG Signal Visualization')
                ax.legend()
            else:
                ax.plot(df.iloc[:, 0], linewidth=0.8)
                ax.set_xlabel('Sample Index')
                ax.set_ylabel('Amplitude')
                ax.set_title('ECG Signal')
            ax.grid(True, alpha=0.3)
            plt.tight_layout()

            # Serialize the figure to base64 so it can be embedded in HTML.
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
            buf.seek(0)
            img_base64 = base64.b64encode(buf.read()).decode('utf-8')

            viz_html = f'<div style="text-align:center;"><img src="data:image/png;base64,{img_base64}" style="max-width:100%;"/></div>'

            summary = {
                "task": "ECG Visualization",
                "samples": len(df),
                "channels": len(df.columns) - 1 if len(df.columns) > 1 else 1
            }

            return "✓ ECG visualization created", {
                "visualization": viz_html,
                "summary": summary
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {
                "visualization": None,
                "summary": None
            }
        finally:
            # Bug fix: the original called plt.close(fig) only on the success
            # path, leaking matplotlib figures whenever plotting or saving
            # raised.  Closing in `finally` covers both paths exactly once.
            if fig is not None:
                plt.close(fig)

    @staticmethod
    def execute_statistical_summary(df: pd.DataFrame) -> Tuple[str, Dict[str, Any]]:
        """Compute ``df.describe()`` and wrap it as an HTML table.

        Returns:
            (status message, {"summary": dict incl. rendered html,
            "visualization": same html string}).
        """
        try:
            stats = df.describe().to_html(classes='preview-table')
            summary_html = f"<h3>Statistical Summary</h3><div style='overflow-x:auto;'>{stats}</div>"

            summary = {
                "task": "Statistical Summary",
                "rows": len(df),
                "columns": len(df.columns),
                "numeric_columns": len(df.select_dtypes(include=[np.number]).columns),
                "html": summary_html
            }

            return "✓ Statistical summary generated", {
                "summary": summary,
                "visualization": summary_html
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {
                "summary": None,
                "visualization": None
            }
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ============================================================================
|
| 267 |
+
# UI MANAGER - Handles all UI state and updates
|
| 268 |
+
# ============================================================================
|
| 269 |
+
|
| 270 |
+
class UIManager:
    """Manages UI state and dynamic updates.

    Holds the most recently loaded dataframe, the active data type and a
    small context dict consumed by the chatbot handler.  Every ``on_*`` /
    ``process_*`` method is a Gradio event callback: the ORDER of each
    returned tuple must match the ``outputs=[...]`` list it is wired to in
    ``create_interface`` — do not reorder return elements.
    """

    def __init__(self):
        # Most recently loaded dataframe (None until a CSV is uploaded).
        self.current_df = None
        # Mirrors the "Data Type" dropdown selection.
        self.current_data_type = "EHR Data"
        # Context shared with the chatbot: keys "file", "type", "df".
        self.chatbot_context = {}

    def load_csv(self, file) -> Tuple[str, Optional[pd.DataFrame]]:
        """Load CSV file.

        Args:
            file: Gradio file object exposing a ``.name`` path, or None.

        Returns:
            (status message, dataframe) — dataframe is None on any failure.
        """
        if file is None:
            return "⚠ No file uploaded", None
        try:
            df = pd.read_csv(file.name)
            # Cache for later callbacks (chatbot, data-type change).
            self.current_df = df
            return f"✓ Loaded {len(df)} rows, {len(df.columns)} columns", df
        except Exception as e:
            return f"✗ Error: {str(e)}", None

    def on_file_upload(self, file, data_type: str):
        """Handle file upload - returns updates for all components.

        Returns an 11-tuple matching the ``file_input.change`` outputs:
        (status, original_df, task_selector, ndd_param_group,
        mislabel_param_group, ndd_label_dropdown, mislabel_label_dropdown,
        tab_original, tab_processed, tab_summary, tab_viz).
        """
        status, df = self.load_csv(file)

        if df is None:
            # Failed load: clear the preview, empty the task list and label
            # dropdowns, and hide all parameter groups and result tabs.
            return (
                status,  # status
                gr.update(value=None),  # original_df
                gr.update(choices=[], value=[], interactive=True),  # task_checkboxes
                gr.update(visible=False),  # ndd_param_group
                gr.update(visible=False),  # mislabel_param_group
                gr.update(choices=[]),  # ndd_label_dropdown
                gr.update(choices=[]),  # mislabel_label_dropdown
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),  # tabs
            )

        # Successful load: refresh chatbot context, task choices and the
        # label-column dropdowns from the new dataframe.
        self.chatbot_context = {"file": file.name, "type": data_type, "df": df}
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)
        col_choices = list(df.columns)

        return (
            status,
            gr.update(value=df.head(200)),  # preview only the first 200 rows
            gr.update(choices=available_tasks, value=[], interactive=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(choices=col_choices, value=None),
            gr.update(choices=col_choices, value=None),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
        )

    def on_data_type_change(self, data_type: str, file):
        """Handle data type change.

        Resets the task selection to the tasks for the new type, hides both
        parameter groups, and reports the change in the status box.
        """
        self.current_data_type = data_type
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)

        # Keep the chatbot context in sync when a file is already loaded.
        if file and self.current_df is not None:
            self.chatbot_context["type"] = data_type

        return (
            gr.update(choices=available_tasks, value=[]),  # Reset task selection
            gr.update(visible=False),  # Hide ndd params
            gr.update(visible=False),  # Hide mislabel params
            f"Data type changed to: {data_type}"
        )

    def on_tasks_change(self, selected_tasks: List[str], data_type: str, df_columns: List[str]):
        """Handle task selection change - show/hide parameter groups for all selected tasks.

        ``data_type`` and ``df_columns`` are currently unused here; they are
        wired in by the caller and kept for future task-specific logic.
        """
        if not selected_tasks:
            # Hide all parameter groups
            return (
                gr.update(visible=False),  # ndd_param_group
                gr.update(visible=False),  # mislabel_param_group
            )

        # Check which tasks need which parameters
        show_ndd_params = "Near-Duplicate Detection" in selected_tasks
        show_mislabel_params = "Find Mislabeled Data" in selected_tasks

        return (
            gr.update(visible=show_ndd_params),
            gr.update(visible=show_mislabel_params),
        )

    def process_analysis(self, file, data_type: str, selected_tasks: List[str], ndd_label: str, mislabel_label: str):
        """Process selected analysis tasks - handles multiple tasks.

        Runs every selected task, merges their results into shared output
        slots, and returns a 9-tuple matching the ``process_btn.click``
        outputs: (status, original, processed, summary, visualization,
        tab_original, tab_processed, tab_summary, tab_viz).
        """
        # Re-load from disk so we always analyze the file currently attached
        # to the upload widget, not a stale cached dataframe.
        status, df = self.load_csv(file)

        if df is None:
            return (
                status,
                None, None, None, None,
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
            )

        if not selected_tasks:
            return (
                "⚠ No tasks selected",
                df.head(200), None, None, None,
                gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
            )

        # Track which tabs to show (union of all selected tasks)
        all_tabs = set()
        all_results = {
            "original": df.head(200),
            "processed": None,
            "summary": [],          # list of {"task": name, "data": ...} entries
            "visualization": ""     # HTML snippets are concatenated across tasks
        }

        status_messages = []
        executor = AnalysisExecutor()

        # Process each selected task
        for task_name in selected_tasks:
            config = TaskRegistry.get_config(data_type, task_name)

            if not config:
                status_messages.append(f"✗ Unknown task: {task_name}")
                continue

            # Add this task's tabs to the set
            all_tabs.update(config.output_tabs)

            # Execute the task with appropriate parameters
            if task_name == "Near-Duplicate Detection":
                status_msg, results = executor.execute_near_duplicate_detection(df, ndd_label)
                status_messages.append(f"{task_name}: {status_msg}")
                if results.get("processed") is not None:
                    all_results["processed"] = results["processed"]
                if results.get("summary") is not None:
                    all_results["summary"].append({"task": task_name, "data": results["summary"]})

            elif task_name == "Find Mislabeled Data":
                status_msg, results = executor.execute_find_mislabeled(df, mislabel_label)
                status_messages.append(f"{task_name}: {status_msg}")
                if results.get("summary") is not None:
                    all_results["summary"].append({"task": task_name, "data": results["summary"]})

            elif task_name == "ECG Visualization":
                status_msg, results = executor.execute_ecg_visualization(df)
                status_messages.append(f"{task_name}: {status_msg}")
                if results.get("visualization"):
                    all_results["visualization"] += results["visualization"]
                if results.get("summary") is not None:
                    all_results["summary"].append({"task": task_name, "data": results["summary"]})

            elif task_name == "Statistical Summary":
                status_msg, results = executor.execute_statistical_summary(df)
                status_messages.append(f"{task_name}: {status_msg}")
                if results.get("visualization"):
                    all_results["visualization"] += results["visualization"]
                if results.get("summary") is not None:
                    all_results["summary"].append({"task": task_name, "data": results["summary"]})

        # Format final outputs based on all_tabs
        show_original = gr.update(visible="original" in all_tabs)
        show_processed = gr.update(visible="processed" in all_tabs)
        show_summary = gr.update(visible="summary" in all_tabs)
        show_viz = gr.update(visible="visualization" in all_tabs)

        # Combine status messages
        final_status = "\n".join(status_messages)

        return (
            final_status,
            all_results["original"],
            all_results["processed"],
            all_results["summary"] if all_results["summary"] else None,
            all_results["visualization"] if all_results["visualization"] else None,
            show_original,
            show_processed,
            show_summary,
            show_viz
        )

    def _format_results(self, status: str, output_tabs: List[str], results: Dict[str, Any]):
        """Format results based on output tabs configuration.

        Single-task formatter: maps a task's results dict onto the 9-tuple
        output shape, showing only the tabs listed in ``output_tabs``.
        Not called by ``process_analysis`` above; kept as the formatting
        pattern referenced by the extension notes at the end of the module.
        """
        # Default all outputs to None
        original_out = None
        processed_out = None
        summary_out = None
        viz_out = None

        # Default all tabs hidden
        show_original = gr.update(visible=False)
        show_processed = gr.update(visible=False)
        show_summary = gr.update(visible=False)
        show_viz = gr.update(visible=False)

        # Populate based on output_tabs
        if "original" in output_tabs:
            original_out = results.get("original")
            show_original = gr.update(visible=True)

        if "processed" in output_tabs:
            processed_out = results.get("processed")
            show_processed = gr.update(visible=True)

        if "summary" in output_tabs:
            summary_data = results.get("summary")
            if results.get("summary_html"):  # For statistical summary
                summary_out = results.get("summary_html")
            else:
                summary_out = summary_data
            show_summary = gr.update(visible=True)

        if "visualization" in output_tabs:
            viz_out = results.get("visualization")
            show_viz = gr.update(visible=True)

        return (
            status,
            original_out,
            processed_out,
            summary_out,
            viz_out,
            show_original,
            show_processed,
            show_summary,
            show_viz
        )

    def chatbot_respond(self, message: str, history: List):
        """Chatbot response.

        Rule-based assistant: answers "column"/"row"/"help" questions from
        the cached dataframe, otherwise echoes the question.  Returns the
        updated (message, response) history plus an empty string to clear
        the input box.
        """
        if history is None:
            history = []
        if not message or message.strip() == "":
            return history, ""

        # May be None when nothing has been uploaded yet; each branch below
        # guards with "if df is not None" before touching it.
        df = self.chatbot_context.get("df")

        if "column" in message.lower():
            response = (f"Dataset has {len(df.columns)} columns: {', '.join(map(str, df.columns))}"
                        if df is not None else "Please upload a file first.")
        elif "row" in message.lower():
            response = f"Dataset has {len(df)} rows." if df is not None else "Please upload a file first."
        elif "help" in message.lower():
            response = "Ask about 'columns', 'rows', or your data analysis."
        else:
            response = f"You asked: '{message}'. Analyzing: {self.chatbot_context.get('type', 'No data')}."

        history.append((message, response))
        return history, ""
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
# ============================================================================
|
| 517 |
+
# GRADIO INTERFACE
|
| 518 |
+
# ============================================================================
|
| 519 |
+
|
| 520 |
+
def create_interface():
    """Build the Gradio interface.

    Layout: a header row (file upload + data type dropdown), then a
    three-column main row — task panel | tabbed results | chat assistant.
    All event wiring delegates to a single ``UIManager`` instance; the
    ``outputs=[...]`` lists below must stay in the same order as the
    tuples returned by the corresponding UIManager methods.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """

    ui_manager = UIManager()

    # Pins the app to the viewport height and makes each panel a flex
    # column so only the inner panes scroll (the chat-scrolling fix noted
    # at the end of the module).
    custom_css = """
    * { box-sizing: border-box; }
    html, body { margin: 0; padding: 0; height: 100vh; overflow: hidden; }
    .gradio-container { height: 100vh !important; max-width: 100% !important; padding: 0 !important; }

    #app-container {
        height: 100vh;
        display: flex;
        flex-direction: column;
        padding: 0.75rem;
        gap: 0.75rem;
    }

    #main-row {
        flex: 1;
        min-height: 0;
        display: flex;
        gap: 0.75rem;
    }

    #left-panel {
        display: flex;
        flex-direction: column;
        height: 100%;
        background: #f9fafb;
        border-radius: 10px;
        padding: 0.75rem;
        gap: 0.5rem;
    }

    #task-section {
        flex: 1;
        min-height: 0;
        overflow-y: auto;
        display: flex;
        flex-direction: column;
        gap: 0.5rem;
    }

    #action-section {
        flex-shrink: 0;
    }

    #middle-panel {
        display: flex;
        flex-direction: column;
        height: 100%;
        min-width: 0;
    }

    #tabs-container {
        flex: 1;
        min-height: 0;
        display: flex;
        flex-direction: column;
    }

    #tabs-container .tabs {
        height: 100%;
        display: flex;
        flex-direction: column;
    }

    #tabs-container .tab-nav {
        flex-shrink: 0;
    }

    #tabs-container .tabitem {
        flex: 1;
        min-height: 0;
        overflow: auto;
    }

    #chat-panel {
        display: flex;
        flex-direction: column;
        height: 100%;
        background: #f9fafb;
        border-radius: 10px;
        padding: 0.75rem;
    }

    #chat-header {
        flex-shrink: 0;
        margin-bottom: 0.5rem;
    }

    #chat-history {
        flex: 1;
        min-height: 0;
        overflow-y: auto;
        margin-bottom: 0.5rem;
    }

    #chat-input-row {
        flex-shrink: 0;
        display: flex;
        gap: 0.5rem;
        align-items: flex-end;
    }

    #chat-input-row .textbox {
        flex: 1;
    }

    #chat-input-row button {
        flex-shrink: 0;
    }

    .preview-table {
        border-collapse: collapse;
        width: 100%;
        font-size: 0.875rem;
    }
    .preview-table th {
        background-color: #3498db;
        color: white;
        padding: 8px;
        text-align: left;
        position: sticky;
        top: 0;
    }
    .preview-table td {
        padding: 6px;
        border-bottom: 1px solid #ddd;
    }

    .compact-header { font-size: 0.95rem; margin: 0; }
    """

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Medical Data Analysis Platform") as demo:

        with gr.Column(elem_id="app-container"):
            # Header
            gr.Markdown("# 🏥 Medical Data Analysis Platform", elem_classes=["compact-header"])

            # Top controls
            with gr.Row():
                file_input = gr.File(label="Upload CSV", file_types=[".csv"], scale=2)
                data_type = gr.Dropdown(
                    choices=["EHR Data", "ECG Data"],
                    value="EHR Data",
                    label="Data Type",
                    scale=1
                )

            # Main row with 3 columns
            with gr.Row(elem_id="main-row"):

                # LEFT: Analysis tasks
                with gr.Column(scale=2, elem_id="left-panel"):
                    with gr.Group(elem_id="task-section"):
                        gr.Markdown("#### Analysis Tasks")
                        # Choices are refreshed by on_file_upload /
                        # on_data_type_change; "EHR Data" is just the default.
                        task_selector = gr.CheckboxGroup(
                            choices=TaskRegistry.get_tasks_for_data_type("EHR Data"),
                            label=None,
                            elem_id="task-checkboxes"
                        )

                        # Parameter groups (conditionally visible) - one for each task that needs params
                        with gr.Group(visible=False) as ndd_param_group:
                            gr.Markdown("**Near-Duplicate Detection Parameters**")
                            ndd_label_dropdown = gr.Dropdown(
                                choices=[],
                                label="Label Column",
                                elem_id="ndd-label-dropdown"
                            )

                        with gr.Group(visible=False) as mislabel_param_group:
                            gr.Markdown("**Find Mislabeled Data Parameters**")
                            mislabel_label_dropdown = gr.Dropdown(
                                choices=[],
                                label="Label Column",
                                elem_id="mislabel-label-dropdown"
                            )

                    with gr.Group(elem_id="action-section"):
                        process_btn = gr.Button("▶ Process", variant="primary", size="lg")
                        status_output = gr.Textbox(label="Status", interactive=False, lines=2)

                # MIDDLE: Results display
                with gr.Column(scale=7, elem_id="middle-panel"):
                    # Tabs start hidden; visibility is driven by each task's
                    # output_tabs config after processing.
                    with gr.Tabs(elem_id="tabs-container") as result_tabs:
                        with gr.TabItem("Original Data", visible=False) as tab_original:
                            original_df_output = gr.Dataframe(interactive=False)

                        with gr.TabItem("Processed Data", visible=False) as tab_processed:
                            processed_df_output = gr.Dataframe(interactive=False)

                        with gr.TabItem("Summary", visible=False) as tab_summary:
                            summary_output = gr.JSON()

                        with gr.TabItem("Visualization", visible=False) as tab_viz:
                            viz_output = gr.HTML()

                # RIGHT: Chat
                with gr.Column(scale=3, elem_id="chat-panel"):
                    gr.Markdown("### 💬 AI Assistant", elem_id="chat-header")
                    chatbot = gr.Chatbot(elem_id="chat-history", height=None, label=None)
                    with gr.Row(elem_id="chat-input-row"):
                        msg_input = gr.Textbox(
                            placeholder="Ask about your data...",
                            label="",
                            scale=4,
                            container=False,
                            lines=1
                        )
                        send_btn = gr.Button("Send", scale=1, size="sm")

        # ============ EVENT HANDLERS ============

        # File upload — output order must match UIManager.on_file_upload.
        file_input.change(
            fn=ui_manager.on_file_upload,
            inputs=[file_input, data_type],
            outputs=[
                status_output,
                original_df_output,
                task_selector,
                ndd_param_group,
                mislabel_param_group,
                ndd_label_dropdown,
                mislabel_label_dropdown,
                tab_original, tab_processed, tab_summary, tab_viz,
            ]
        )

        # Data type change
        data_type.change(
            fn=ui_manager.on_data_type_change,
            inputs=[data_type, file_input],
            outputs=[task_selector, ndd_param_group, mislabel_param_group, status_output]
        )

        # Task selection change
        task_selector.change(
            fn=ui_manager.on_tasks_change,
            inputs=[task_selector, data_type, ndd_label_dropdown],
            outputs=[ndd_param_group, mislabel_param_group]
        )

        # Process button — output order must match UIManager.process_analysis.
        process_btn.click(
            fn=ui_manager.process_analysis,
            inputs=[file_input, data_type, task_selector, ndd_label_dropdown, mislabel_label_dropdown],
            outputs=[
                status_output,
                original_df_output,
                processed_df_output,
                summary_output,
                viz_output,
                tab_original,
                tab_processed,
                tab_summary,
                tab_viz
            ]
        )

        # Chat — both the Send button and pressing Enter submit the message.
        send_btn.click(
            fn=ui_manager.chatbot_respond,
            inputs=[msg_input, chatbot],
            outputs=[chatbot, msg_input]
        )

        msg_input.submit(
            fn=ui_manager.chatbot_respond,
            inputs=[msg_input, chatbot],
            outputs=[chatbot, msg_input]
        )

    return demo
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
# ============================================================================
|
| 800 |
+
# LAUNCH
|
| 801 |
+
# ============================================================================
|
| 802 |
+
|
| 803 |
+
if __name__ == "__main__":
    # Entry point: build the UI and serve it on all interfaces at port 7860
    # (share=False keeps it off the public Gradio tunnel).
    demo = create_interface()
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
# ============================================================================
|
| 809 |
+
# HOW TO EXTEND
|
| 810 |
+
# ============================================================================
|
| 811 |
+
|
| 812 |
+
"""
|
| 813 |
+
ADDING NEW ANALYSIS TASKS:
|
| 814 |
+
==========================
|
| 815 |
+
|
| 816 |
+
1. Add task configuration in TaskRegistry:
|
| 817 |
+
- Define name, data_type, requires_params, param_components, output_tabs
|
| 818 |
+
|
| 819 |
+
2. Add execution method in AnalysisExecutor:
|
| 820 |
+
- Create execute_your_task() method
|
| 821 |
+
- Return (status_message, results_dict)
|
| 822 |
+
|
| 823 |
+
3. Add condition in UIManager.process_analysis():
|
| 824 |
+
- elif task_name == "Your Task Name":
|
| 825 |
+
- status_msg, results = executor.execute_your_task(df, params)
|
| 826 |
+
- return self._format_results(status_msg, config.output_tabs, results)
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
EXAMPLE - Adding a new ECG task with parameters:
|
| 830 |
+
================================================
|
| 831 |
+
|
| 832 |
+
In TaskRegistry.get_config():
|
| 833 |
+
"ECG Data": {
|
| 834 |
+
...
|
| 835 |
+
"ECG Quality Check": TaskConfig(
|
| 836 |
+
name="ECG Quality Check",
|
| 837 |
+
data_type="ECG Data",
|
| 838 |
+
requires_params=True,
|
| 839 |
+
param_components=[
|
| 840 |
+
{"type": "slider", "label": "Quality Threshold", "elem_id": "quality_thresh"}
|
| 841 |
+
],
|
| 842 |
+
output_tabs=["original", "summary", "visualization"]
|
| 843 |
+
)
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
In AnalysisExecutor:
|
| 847 |
+
@staticmethod
|
| 848 |
+
def execute_ecg_quality_check(df: pd.DataFrame, threshold: float) -> Tuple[str, Dict[str, Any]]:
|
| 849 |
+
# Your analysis logic here
|
| 850 |
+
return "✓ Quality check completed", {
|
| 851 |
+
"original": df,
|
| 852 |
+
"summary": quality_summary,
|
| 853 |
+
"visualization": viz_html
|
| 854 |
+
}
|
| 855 |
+
|
| 856 |
+
In UIManager.process_analysis():
|
| 857 |
+
elif task_name == "ECG Quality Check":
|
| 858 |
+
status_msg, results = executor.execute_ecg_quality_check(df, float(param_value))
|
| 859 |
+
return self._format_results(status_msg, config.output_tabs, results)
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
The architecture is now:
|
| 863 |
+
- Simple and clean
|
| 864 |
+
- Easy to extend with new tasks
|
| 865 |
+
- Dynamic UI based on task configuration
|
| 866 |
+
- Fixed chat scrolling issue
|
| 867 |
+
"""
|
backend_client.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# backend_client.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
import requests
|
| 4 |
+
from typing import Optional, List, Tuple
|
| 5 |
+
|
| 6 |
+
class LangGraphClient:
    """
    Minimal client to talk to your LangGraph app running on http://localhost:8010.
    Adjust ENDPOINTS to match your server routes/payloads.
    """

    def __init__(self, base_url: str = "http://localhost:8010"):
        # Normalize: strip any trailing slash so endpoint joins stay clean.
        self.base_url = base_url.rstrip("/")
        # Example endpoints — change if your server differs
        self.endpoints = {
            "chat": f"{self.base_url}/chat"  # expects JSON {session_id, message, history?}
        }

    def send_message(
        self,
        message: str,
        session_id: Optional[str] = "default",
        history: Optional[List[Tuple[str, str]]] = None,
        timeout: int = 30
    ) -> str:
        """
        Sends a message to the backend and returns the assistant reply text.
        """
        request_body = {
            "session_id": session_id,
            "message": message,
            "history": history or []  # [[user, assistant], ...] if your server uses it
        }
        try:
            response = requests.post(self.endpoints["chat"], json=request_body, timeout=timeout)
            response.raise_for_status()
            payload = response.json()
        except requests.RequestException as exc:
            # Network/HTTP failures are folded into the reply channel rather
            # than raised, so the UI can display them directly.
            return f"[Backend error] {exc}"
        # Accept a few common response shapes: first truthy value wins.
        # Expecting {"reply": "..."}; adjust if your API returns a different shape
        for field in ("reply", "message", "text"):
            value = payload.get(field)
            if value:
                return str(value)
        return ""
|
data/Aubrie.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/JS00001_filtered.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/Lisette.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ecg_analyzer.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ECG Analysis Module
|
| 3 |
+
Provides modular ECG visualization and analysis functions
|
| 4 |
+
"""
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import io
|
| 9 |
+
import base64
|
| 10 |
+
from typing import List, Optional, Tuple, Dict, Any
|
| 11 |
+
|
| 12 |
+
class ECGAnalyzer:
    """Handles ECG data analysis and visualization.

    Stateless collection of ``@staticmethod`` helpers operating on a pandas
    DataFrame whose columns are ECG lead names (see ``STANDARD_LEADS``) plus
    an optional time column. Every plotting helper returns a base64-encoded
    PNG string ready to embed in an ``<img src="data:image/png;base64,...">``
    tag; ``create_all_visualizations`` assembles those images into one HTML
    fragment.
    """

    # Standard 12-lead ECG lead names (limb, augmented, precordial).
    STANDARD_LEADS = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    @staticmethod
    def detect_leads(df: pd.DataFrame) -> List[str]:
        """Return the standard ECG leads present as columns in *df*, in STANDARD_LEADS order."""
        return [lead for lead in ECGAnalyzer.STANDARD_LEADS if lead in df.columns]

    @staticmethod
    def detect_time_column(df: pd.DataFrame) -> Optional[str]:
        """Return the first recognized time-like column name, or None.

        A ``None`` result means callers should fall back to the DataFrame index
        (see ``_resolve_x_axis``).
        """
        time_candidates = ['time', 'Time', 'TIME', 'timestamp', 'sample']
        for col in time_candidates:
            if col in df.columns:
                return col
        # No explicit time column found; callers use the index instead.
        return None

    @staticmethod
    def _resolve_x_axis(df: pd.DataFrame, time_col: Optional[str]) -> Tuple[Any, str]:
        """Return ``(x_data, x_label)``: the time column if usable, else the index.

        Shared by the waveform and rolling-average plots, which previously
        duplicated this logic inline.
        """
        if time_col and time_col in df.columns:
            # Label as milliseconds only when the column name mentions "time";
            # assumes time columns are in ms — TODO confirm with data source.
            return df[time_col], ('Time (ms)' if 'time' in time_col.lower() else time_col)
        return df.index, 'Sample Index'

    @staticmethod
    def create_signal_plot(df: pd.DataFrame, leads: List[str], time_col: Optional[str] = None) -> str:
        """Create the ECG signal waveform plot; returns a base64 PNG string."""
        fig, ax = plt.subplots(figsize=(14, 6))
        x_data, x_label = ECGAnalyzer._resolve_x_axis(df, time_col)

        # One distinct tab10 color per requested lead.
        colors = plt.cm.tab10(np.linspace(0, 1, len(leads)))

        for idx, lead in enumerate(leads):
            if lead in df.columns:  # tolerate leads missing from this recording
                ax.plot(x_data, df[lead], label=f'Lead {lead}',
                        linewidth=1.2, alpha=0.8, color=colors[idx])

        ax.set_xlabel(x_label, fontsize=11)
        ax.set_ylabel('Amplitude (mV)', fontsize=11)
        ax.set_title('ECG Signal Waveform', fontsize=13, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9, ncol=min(4, len(leads)))
        ax.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()

        return ECGAnalyzer._fig_to_base64(fig)

    @staticmethod
    def create_histogram(df: pd.DataFrame, leads: List[str]) -> str:
        """Create overlaid histograms of signal amplitudes; returns a base64 PNG string."""
        fig, ax = plt.subplots(figsize=(14, 6))

        colors = plt.cm.tab10(np.linspace(0, 1, len(leads)))

        for idx, lead in enumerate(leads):
            if lead in df.columns:
                # dropna() so missing samples do not distort the distribution.
                ax.hist(df[lead].dropna(), bins=50, alpha=0.6,
                        label=f'Lead {lead}', color=colors[idx], edgecolor='black')

        ax.set_xlabel('Amplitude (mV)', fontsize=11)
        ax.set_ylabel('Frequency', fontsize=11)
        ax.set_title('Distribution of Signal Amplitudes', fontsize=13, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9)
        ax.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()

        return ECGAnalyzer._fig_to_base64(fig)

    @staticmethod
    def create_scatter_plot(df: pd.DataFrame, lead_x: str = 'I', lead_y: str = 'II') -> str:
        """Create a scatter plot comparing two leads (with their Pearson correlation).

        If either lead is absent, the plot shows a placeholder message instead.
        Returns a base64 PNG string.
        """
        fig, ax = plt.subplots(figsize=(8, 8))

        if lead_x in df.columns and lead_y in df.columns:
            ax.scatter(df[lead_x], df[lead_y], alpha=0.5, s=10, c='steelblue')
            ax.set_xlabel(f'Lead {lead_x} Amplitude (mV)', fontsize=11)
            ax.set_ylabel(f'Lead {lead_y} Amplitude (mV)', fontsize=11)
            ax.set_title(f'Lead {lead_x} vs Lead {lead_y}', fontsize=13, fontweight='bold')
            ax.grid(True, alpha=0.3)

            # Annotate the Pearson correlation between the two leads.
            correlation = df[[lead_x, lead_y]].corr().iloc[0, 1]
            ax.text(0.05, 0.95, f'Correlation: {correlation:.3f}',
                    transform=ax.transAxes, fontsize=10, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        else:
            ax.text(0.5, 0.5, 'Selected leads not available',
                    ha='center', va='center', fontsize=12)

        plt.tight_layout()
        return ECGAnalyzer._fig_to_base64(fig)

    @staticmethod
    def create_rolling_average(df: pd.DataFrame, leads: List[str],
                               time_col: Optional[str] = None, window: int = 100) -> str:
        """Create a rolling-average (moving mean) plot; returns a base64 PNG string.

        ``min_periods=1`` keeps the curve defined from the first sample onward.
        """
        fig, ax = plt.subplots(figsize=(14, 6))
        x_data, x_label = ECGAnalyzer._resolve_x_axis(df, time_col)

        colors = plt.cm.tab10(np.linspace(0, 1, len(leads)))

        for idx, lead in enumerate(leads):
            if lead in df.columns:
                rolling_avg = df[lead].rolling(window=window, min_periods=1).mean()
                ax.plot(x_data, rolling_avg, label=f'Lead {lead} (MA-{window})',
                        linewidth=1.5, alpha=0.8, color=colors[idx])

        ax.set_xlabel(x_label, fontsize=11)
        ax.set_ylabel('Amplitude (mV)', fontsize=11)
        ax.set_title(f'Rolling Average (Window={window})', fontsize=13, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9, ncol=min(4, len(leads)))
        ax.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()

        return ECGAnalyzer._fig_to_base64(fig)

    @staticmethod
    def _wrap_image(img_base64: str) -> str:
        """Wrap a base64 PNG in the standard <img> container used by the report HTML."""
        return f'<div style="margin-bottom: 30px;"><img src="data:image/png;base64,{img_base64}" style="max-width:100%;"/></div>'

    @staticmethod
    def create_all_visualizations(df: pd.DataFrame, leads: List[str],
                                  viz_types: List[str]) -> str:
        """Build an HTML fragment with one embedded image per selected viz type.

        Unknown entries in *viz_types* are ignored; with no produced images a
        placeholder message is returned instead.
        """
        html_parts = []
        time_col = ECGAnalyzer.detect_time_column(df)

        for viz_type in viz_types:
            if viz_type == "Signal Waveform":
                html_parts.append(ECGAnalyzer._wrap_image(
                    ECGAnalyzer.create_signal_plot(df, leads, time_col)))

            elif viz_type == "Histogram":
                html_parts.append(ECGAnalyzer._wrap_image(
                    ECGAnalyzer.create_histogram(df, leads)))

            elif viz_type == "Scatter Plot":
                # Use the first two available leads, falling back to I/II.
                lead_x = leads[0] if len(leads) > 0 else 'I'
                lead_y = leads[1] if len(leads) > 1 else 'II'
                html_parts.append(ECGAnalyzer._wrap_image(
                    ECGAnalyzer.create_scatter_plot(df, lead_x, lead_y)))

            elif viz_type == "Rolling Average":
                html_parts.append(ECGAnalyzer._wrap_image(
                    ECGAnalyzer.create_rolling_average(df, leads, time_col)))

        if not html_parts:
            return '<div style="text-align:center; padding:40px;"><p>No visualizations selected</p></div>'

        return '<div style="text-align:center;">' + ''.join(html_parts) + '</div>'

    @staticmethod
    def _fig_to_base64(fig) -> str:
        """Convert a matplotlib figure to a base64 PNG string and close the figure.

        Closing prevents figure accumulation (memory growth) across repeated calls.
        """
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode('utf-8')
        plt.close(fig)
        return img_base64

    @staticmethod
    def generate_statistics(df: pd.DataFrame, leads: List[str]) -> Dict[str, Any]:
        """Return per-lead summary statistics for the leads present in *df*.

        Leads not found as columns are silently skipped. NaN samples are
        excluded before computing each statistic.
        """
        stats: Dict[str, Any] = {}
        for lead in leads:
            if lead in df.columns:
                lead_data = df[lead].dropna()
                stats[lead] = {
                    'mean': float(lead_data.mean()),
                    'std': float(lead_data.std()),
                    'min': float(lead_data.min()),
                    'max': float(lead_data.max()),
                    'median': float(lead_data.median()),
                    'q25': float(lead_data.quantile(0.25)),
                    'q75': float(lead_data.quantile(0.75))
                }
        return stats
|
ecg_visualization.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Standalone Dash dashboard for browsing a single 12-lead ECG recording.

Loads one CSV at import time and serves four stacked Plotly subplots
(waveform, amplitude histogram, lead-I-vs-lead-II scatter, rolling average),
with a dropdown to show either all leads or one selected lead.
"""
# Import necessary libraries
import dash
from dash import html, dcc
from dash.dependencies import Input, Output
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Load the dataset
# NOTE(review): other scripts in this commit read CSVs from 'data/' — confirm
# this 'dataset/' path is correct for the deployment environment.
file_path = 'dataset/JS00001_filtered.csv' # Update this path as necessary
ecg_data = pd.read_csv(file_path)  # assumes a 'time' column plus the 12 lead columns — TODO confirm

# Define leads
leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Initialize the Dash app
app = dash.Dash(__name__)
app.layout = html.Div(children=[
    html.H1(children='ECG Data Dashboard', style={'textAlign': 'center', 'color': '#007BFF'}),
    html.Div([
        html.Label('Select ECG Lead for Analysis:'),
        dcc.Dropdown(
            id='lead-dropdown',
            options=[{'label': 'All Leads', 'value': 'ALL'}] + [{'label': lead, 'value': lead} for lead in leads],
            value='ALL' # Default value to show all leads
        )
    ], style={'width': '30%', 'margin': '0 auto', 'padding': '20px'}),
    dcc.Graph(id='ecg-data-visualization'),
], style={'padding': '20px', 'maxWidth': '1200px', 'margin': '0 auto'})

# Define callback to update the figure based on the selected lead
@app.callback(
    Output('ecg-data-visualization', 'figure'),
    [Input('lead-dropdown', 'value')]
)
def update_figure(selected_lead):
    """Rebuild the 4-row figure for the chosen dropdown value ('ALL' or a lead name)."""
    # Initialize the figure with subplots
    fig = make_subplots(rows=4, cols=1,
                        subplot_titles=("ECG Signal Over Time", "Histogram of Signal Amplitudes",
                                        "Scatter Plot: Lead I vs Lead II", "Rolling Average"),
                        vertical_spacing=0.1,
                        specs=[[{"type": "scatter"}], [{"type": "histogram"}], [{"type": "scatter"}], [{"type": "scatter"}]])

    # Conditionally display either all leads or just the selected lead
    if selected_lead == 'ALL':
        # Show all leads
        # NOTE(review): this branch leaves row 4 (Rolling Average) empty —
        # only the single-lead branch below populates it; confirm intended.
        for lead in leads:
            fig.add_trace(go.Scatter(x=ecg_data['time'], y=ecg_data[lead], mode='lines', name=f'Lead {lead}'), row=1, col=1)
            fig.add_trace(go.Histogram(x=ecg_data[lead], name=f'Lead {lead}', opacity=0.75), row=2, col=1)
    else:
        # Show only the selected lead
        # NOTE(review): this writes a 'RollingAvg' column into the module-level
        # DataFrame from inside a callback — shared mutable state across requests.
        ecg_data['RollingAvg'] = ecg_data[selected_lead].rolling(window=100).mean()
        fig.add_trace(go.Scatter(x=ecg_data['time'], y=ecg_data[selected_lead], mode='lines', name=f'Lead {selected_lead}'), row=1, col=1)
        fig.add_trace(go.Histogram(x=ecg_data[selected_lead], name=f'Lead {selected_lead}', opacity=0.75), row=2, col=1)
        fig.add_trace(go.Scatter(x=ecg_data['time'], y=ecg_data['RollingAvg'], mode='lines', name=f'Rolling Average: Lead {selected_lead}'), row=4, col=1)

    # Common settings for all cases
    # Shared bingroup keeps overlaid histograms on identical bins.
    fig.update_traces(opacity=0.75, bingroup=1, row=2, col=1)
    fig.update_layout(barmode='overlay')
    # The Lead I vs Lead II scatter is added regardless of the dropdown value.
    fig.add_trace(go.Scatter(x=ecg_data['I'], y=ecg_data['II'], mode='markers', name='Lead I vs Lead II'), row=3, col=1)

    # Update the figure layout
    fig.update_layout(height=1600, title_text="Comprehensive ECG Data Analysis", showlegend=True)
    return fig

# Run the app
if __name__ == '__main__':
    # NOTE(review): newer Dash releases rename `run_server` to `app.run` —
    # confirm against the pinned Dash version in requirements.txt.
    app.run_server(debug=True)
|
examples/main.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Example driver: dedup -> featurize -> label-issue detection on a sample CSV."""
from deduplication import find_near_duplicates
from featurizer import custom_featurizer
from issues import find_issues
from pipeline import make_step, run_pipeline
import pandas as pd
from tqdm.auto import tqdm

# A single shared progress bar is threaded through every pipeline step.
progress_bar = tqdm(total=100, leave=True)

pipeline_steps = [
    make_step(find_near_duplicates, name="dedup")(progress=progress_bar),
    make_step(custom_featurizer, name="featurize")(
        label=None,  # optional; only used to drop NaN label rows
        nan_strategy="impute",
        on_pipeline_error="drop",
        progress=progress_bar,
    ),
    make_step(find_issues, name="find_label_issues")(label="HARDSHIP_INDEX", progress=progress_bar),
]

input_frame = pd.read_csv("./data/Lisette.csv")
pipeline_results = run_pipeline(pipeline_steps, df=input_frame)

progress_bar.close()
print(pipeline_results)
|
| 26 |
+
|
| 27 |
+
|
pipeline/__init__.py
ADDED
|
File without changes
|
pipeline/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (165 Bytes). View file
|
|
|
pipeline/__pycache__/deduplication.cpython-311.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
pipeline/__pycache__/featurizer.cpython-311.pyc
ADDED
|
Binary file (19.2 kB). View file
|
|
|
pipeline/__pycache__/issues.cpython-311.pyc
ADDED
|
Binary file (7.78 kB). View file
|
|
|
pipeline/__pycache__/pipeline.cpython-311.pyc
ADDED
|
Binary file (3.64 kB). View file
|
|
|
pipeline/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
pipeline/__pycache__/utils_cool.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
pipeline/deduplication.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from time import perf_counter
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from cleanlab import Datalab
|
| 9 |
+
from cleanlab.datalab.internal.issue_manager.duplicate import NearDuplicateIssueManager
|
| 10 |
+
from scipy.sparse import issparse
|
| 11 |
+
from scipy.special import comb
|
| 12 |
+
from sklearn.base import TransformerMixin
|
| 13 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 14 |
+
from tqdm.auto import tqdm
|
| 15 |
+
|
| 16 |
+
from .utils_cool import PhaseProgress, _ensure_dense32, choose_k
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# =========================
|
| 20 |
+
# Core near-duplicate finder
|
| 21 |
+
# =========================
|
| 22 |
+
def find_near_duplicates(
    df: pd.DataFrame,
    *,
    # Cleanlab params
    metric: str = "cosine",
    threshold: float = 0.13,
    k: Optional[int] = None,
    # Vectorizer / features
    vectorizer: Optional[TransformerMixin] = None,
    force_dense: bool = False,  # set True if your cleanlab version needs dense
    # Behavior
    verbose: bool = False,
    # Progress: either a tqdm object or a callable phase,p in [0,1]
    progress: Optional[Any] = None,
) -> Tuple[Optional[pd.DataFrame], Dict]:
    """
    Detect near-duplicates using Cleanlab's NearDuplicateIssueManager.

    Parameters
    ----------
    df : DataFrame
        Rows to analyze. If no vectorizer is passed, all columns are joined as strings for TF-IDF.
    metric : {"cosine", "euclidean", "manhattan"}
        Distance metric for kNN graph.
    threshold : float
        Near-duplicate radius is based on threshold × median NN distance (internal to Cleanlab).
    k : int or None
        Neighborhood size. If None, uses sqrt(N) clipped to [5, 50] and ≤ N-1.
    vectorizer : sklearn Transformer
        Any transformer with fit_transform/transform. If None, uses TF-IDF (float32).
    force_dense : bool
        If True, densify features before passing to Cleanlab.
    verbose : bool
        Print timing breakdown.
    progress : tqdm or Callable[[str, float], None]
        Phase-aware progress reporting.
        NOTE(review): only tqdm-like objects (having set_description/update)
        are actually wired up below — a bare callable is silently ignored.
        Confirm whether callable support should be added or the docs narrowed.

    Returns
    -------
    (output_df, stats)
        output_df : DataFrame after deduplication
        stats : dict of counts/timings/params
    """
    # Progress setup (one bar per call unless user provided one)
    local_bar = None
    pp = None
    if progress is None:
        local_bar = tqdm(total=100, leave=True)
        pp = PhaseProgress(local_bar, weights={"vectorize": .2, "find_issues": .7, "grouping": .1})
    else:
        if hasattr(progress, "set_description") and hasattr(progress, "update"):
            # treat given tqdm as the bar
            if not hasattr(progress, "_last_val"):
                # presumably PhaseProgress reads _last_val to track the bar's
                # absolute position — verify against utils_cool.PhaseProgress
                progress._last_val = 0
            pp = PhaseProgress(progress, weights={"vectorize": .2, "find_issues": .7, "grouping": .1})

    timings: Dict[str, float] = {}
    t0 = perf_counter()
    N = int(len(df))

    # --- Vectorize ---
    # `pp and pp.x(...)` short-circuits: every progress call is a no-op when pp is None.
    pp and pp.start("vectorize", extra={"N": N})
    t_vec0 = perf_counter()
    # Each row becomes one space-joined string over all columns (input for TF-IDF).
    text_series = df.astype(str).agg(" ".join, axis=1)

    if vectorizer is None:
        vectorizer = TfidfVectorizer(dtype=np.float32)

    # Prefer fit_transform (fits on this data); fall back to transform for
    # pre-fitted transformers; anything else is a usage error.
    if hasattr(vectorizer, "fit_transform"):
        X = vectorizer.fit_transform(text_series.tolist())
    elif hasattr(vectorizer, "transform"):
        X = vectorizer.transform(text_series.tolist())
    else:
        raise TypeError("`vectorizer` must implement fit_transform or transform.")

    if force_dense:
        X = _ensure_dense32(X)

    timings["vectorize"] = perf_counter() - t_vec0
    pp and pp.tick_abs("vectorize", 1.0)
    pp and pp.end("vectorize")

    # --- Cleanlab duplicate finder ---
    pp and pp.start("find_issues", extra={"metric": metric})
    t_cl0 = perf_counter()

    if k is None:
        k = choose_k(N)

    # Datalab only needs row identifiers here; features are passed directly
    # to the issue manager below.
    lab = Datalab(data={"__row__": list(range(N))})
    ndm = NearDuplicateIssueManager(datalab=lab, metric=metric, threshold=threshold, k=k)

    pp and pp.tick_abs("find_issues", 0.0, extra={"k": k})

    # If sparse features fail (some cleanlab builds require dense input),
    # retry once with a densified copy; otherwise propagate the error.
    try:
        ndm.find_issues(features=X)
    except Exception:
        if issparse(X):
            X_dense = _ensure_dense32(X)
            ndm.find_issues(features=X_dense)  # retry dense
        else:
            raise

    pp and pp.tick_abs("find_issues", 1.0)
    timings["cleanlab_find_issues"] = perf_counter() - t_cl0
    pp and pp.end("find_issues", extra={"k": k})

    # Each entry is one group of mutually-near-duplicate row indices.
    near_dup_sets: List[List[int]] = getattr(ndm, "near_duplicate_sets", []) or []

    # --- Representatives & output ---
    pp and pp.start("grouping")
    t_out0 = perf_counter()

    # Keep smallest index as representative; skip any empty groups defensively
    reps = [int(np.min(g)) for g in near_dup_sets if np.size(g) > 0]
    in_any = {int(i) for g in near_dup_sets for i in np.asarray(g).ravel()}
    # Keep: one representative per group, plus every row not in any group.
    keep_set = set(reps) | (set(range(N)) - in_any)

    out_df = None
    if N > 0:
        keep_mask = np.zeros(N, dtype=bool)
        if keep_set:
            keep_mask[list(keep_set)] = True
        out_df = df.iloc[np.where(keep_mask)[0]].copy()
    else:
        out_df = df.copy()

    # Stats
    n_groups = sum(1 for g in near_dup_sets if np.size(g) > 0)
    group_sizes = [int(np.size(g)) for g in near_dup_sets if np.size(g) > 0]
    n_flagged = sum(max(0, s - 1) for s in group_sizes)  # rows we'd drop if keeping 1 per group
    # Number of within-group pairs: C(s, 2) summed over groups.
    n_pairs = int(sum(comb(s, 2, exact=True) for s in group_sizes))

    timings["groups_and_output"] = perf_counter() - t_out0
    total_time = perf_counter() - t0

    stats = {
        "n_rows_before_dedup": N,
        "n_near_dupe_pairs": n_pairs,
        "n_groups": n_groups,
        "avg_group_size": float(np.mean(group_sizes)) if group_sizes else 0.0,
        "max_group_size": max(group_sizes) if group_sizes else 0,
        "n_rows_flagged_duplicates": n_flagged,
        "n_rows_after_dedup": int(len(out_df)) if out_df is not None else N,
        "metric": metric,
        "threshold": threshold,
        "k": int(k),
        "timings": timings,
        "total_time_sec": total_time,
    }

    if verbose:
        print(f"[timing] TOTAL: {total_time*1000:.1f} ms")
        for k_, v in timings.items():
            print(f"  - {k_}: {v*1000:.1f} ms")

    pp and pp.tick_abs("grouping", 1.0, extra={"groups": n_groups, "pairs": n_pairs})
    pp and pp.end("grouping")

    # Only close a bar we created ourselves; a caller-supplied bar stays open.
    if local_bar is not None:
        pp.close()

    # Return the standardized pair expected by your pipeline
    return out_df, stats
|
pipeline/featurizer.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 8 |
+
from sklearn.impute import SimpleImputer
|
| 9 |
+
from sklearn.pipeline import Pipeline
|
| 10 |
+
from sklearn.preprocessing import FunctionTransformer, StandardScaler
|
| 11 |
+
|
| 12 |
+
from .utils_cool import PhaseProgress
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _safe_apply_pipeline(
|
| 16 |
+
group_name: str,
|
| 17 |
+
df_in: pd.DataFrame,
|
| 18 |
+
cols: List[str],
|
| 19 |
+
pipeline: Pipeline,
|
| 20 |
+
*,
|
| 21 |
+
drop_on_error: bool,
|
| 22 |
+
warnings: List[str],
|
| 23 |
+
max_new_cols: int = 2000,
|
| 24 |
+
) -> pd.DataFrame:
|
| 25 |
+
"""
|
| 26 |
+
Apply a user pipeline to df_in[cols] safely.
|
| 27 |
+
- If success and returns DataFrame with same index: replace group columns with returned columns (prefixed).
|
| 28 |
+
- If success and returns ndarray/sparse: if width <= max_new_cols, expand with auto names; else drop group.
|
| 29 |
+
- If failure and drop_on_error: drop group columns and record a warning.
|
| 30 |
+
"""
|
| 31 |
+
Xsub = df_in[cols]
|
| 32 |
+
try:
|
| 33 |
+
out = pipeline.fit_transform(Xsub)
|
| 34 |
+
except Exception as e:
|
| 35 |
+
if drop_on_error:
|
| 36 |
+
warnings.append(
|
| 37 |
+
f"[custom_featurizer] dropped '{group_name}' columns ({len(cols)}): {cols[:6]}{'...' if len(cols)>6 else ''} "
|
| 38 |
+
f"reason={type(e).__name__}: {str(e)[:180]}"
|
| 39 |
+
)
|
| 40 |
+
return df_in.drop(columns=cols)
|
| 41 |
+
raise
|
| 42 |
+
|
| 43 |
+
# If pipeline returns a DataFrame -> use its columns (prefixed)
|
| 44 |
+
if isinstance(out, pd.DataFrame):
|
| 45 |
+
out_df = out.copy()
|
| 46 |
+
# ensure index alignment
|
| 47 |
+
out_df.index = df_in.index
|
| 48 |
+
# prefix to avoid collisions
|
| 49 |
+
out_df.columns = [f"{group_name}__{c}" for c in out_df.columns]
|
| 50 |
+
df_out = df_in.drop(columns=cols).join(out_df)
|
| 51 |
+
return df_out
|
| 52 |
+
|
| 53 |
+
# If ndarray / sparse matrix -> expand to columns cautiously
|
| 54 |
+
try:
|
| 55 |
+
import numpy as _np
|
| 56 |
+
import scipy.sparse as _sp
|
| 57 |
+
if _sp.issparse(out):
|
| 58 |
+
out = out.toarray()
|
| 59 |
+
out = _np.asarray(out)
|
| 60 |
+
n_rows, n_cols = out.shape[0], (out.shape[1] if out.ndim == 2 else 1)
|
| 61 |
+
if n_rows != len(df_in):
|
| 62 |
+
raise ValueError(f"Pipeline for '{group_name}' returned {n_rows} rows; expected {len(df_in)}.")
|
| 63 |
+
if n_cols > max_new_cols:
|
| 64 |
+
warnings.append(
|
| 65 |
+
f"[custom_featurizer] '{group_name}' produced {n_cols} columns (> {max_new_cols}); dropping this group to avoid explosion."
|
| 66 |
+
)
|
| 67 |
+
return df_in.drop(columns=cols)
|
| 68 |
+
if out.ndim == 1:
|
| 69 |
+
out = out.reshape(-1, 1)
|
| 70 |
+
n_cols = 1
|
| 71 |
+
new_cols = [f"{group_name}__f{i}" for i in range(n_cols)]
|
| 72 |
+
out_df = pd.DataFrame(out, index=df_in.index, columns=new_cols)
|
| 73 |
+
df_out = df_in.drop(columns=cols).join(out_df)
|
| 74 |
+
return df_out
|
| 75 |
+
except Exception as e:
|
| 76 |
+
if drop_on_error:
|
| 77 |
+
warnings.append(
|
| 78 |
+
f"[custom_featurizer] failed to materialize output for '{group_name}'; dropping group. "
|
| 79 |
+
f"reason={type(e).__name__}: {str(e)[:180]}"
|
| 80 |
+
)
|
| 81 |
+
return df_in.drop(columns=cols)
|
| 82 |
+
raise
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def custom_featurizer(
    df: pd.DataFrame,
    *,
    # optional label just for cleaning rows with NaN in label; not returned as y
    label: Optional[str] = None,

    # user-provided pipelines (override defaults for that group)
    numeric_pipeline: Optional[Pipeline] = None,
    low_card_pipeline: Optional[Pipeline] = None,
    text_pipeline: Optional[Pipeline] = None,

    # defaults (used ONLY if the corresponding pipeline is NOT provided)
    numeric_scale: bool = True,        # scale numeric after impute
    text_lowercase: bool = True,       # lower+strip text
    max_ohe_cardinality: int = 50,     # threshold to split low-card vs text

    # NaN handling for DEFAULTS ONLY (ignored if a custom pipeline is provided for that group)
    nan_strategy: str = "impute",      # "impute" or "drop"

    # failure policy for user pipelines
    on_pipeline_error: str = "drop",   # "drop" -> drop group; "raise" -> bubble error

    # control expansion when user pipelines return big matrices
    max_new_cols_per_group: int = 2000,

    # progress / logging
    progress: Optional[Any] = None,    # pass a tqdm here; None -> auto-create

    # logging
    verbose: bool = False,
) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Featurize a mixed DataFrame while keeping the output as a DataFrame for downstream steps.

    Overview
    --------
    Columns are split into three groups:
      • numeric (including boolean)
      • low-cardinality categoricals (nunique ≤ `max_ohe_cardinality`)
      • text/high-cardinality (nunique > `max_ohe_cardinality`)

    You may pass custom sklearn `Pipeline`s per group. If a user pipeline raises and
    `on_pipeline_error="drop"`, that entire group of columns is dropped and a short
    warning is recorded in `stats["warnings"]` (the step does not crash). If `"raise"`,
    the error is propagated.

    Defaults when pipelines are NOT provided
    ---------------------------------------
    • Numeric: `SimpleImputer(strategy="median")` when `nan_strategy="impute"`, then
      `StandardScaler(with_mean=False)` when `numeric_scale=True`. Columns are
      replaced in place (same names).
    • Low-cardinality categoricals: `SimpleImputer(strategy="most_frequent")` only;
      **no one-hot by default** — pass your own `low_card_pipeline` for encodings.
    • Text/high-cardinality: concat selected text cols → `TfidfVectorizer(dtype=float32,
      lowercase=text_lowercase)`; the original text columns are replaced by `txt__f*`
      columns. If the output exceeds `max_new_cols_per_group` columns, or TF-IDF fails
      (e.g. empty vocabulary on tiny data), the text group is dropped with a warning.

    NaN handling
    ------------
    `nan_strategy` applies only to groups using the **default** pipeline:
      - "impute": impute NaNs (as above)
      - "drop"  : drop rows containing NaNs in any default-handled feature column
        **before** transforming. For groups with a **custom** pipeline, NaN handling
        is your pipeline's responsibility.

    Parameters
    ----------
    df : pd.DataFrame
        Input table. If `label` is provided, rows with NaN in `label` are dropped first.
        The input frame is never mutated in place.
    label : Optional[str], default=None
        Name of the target column (only used to drop NaN labels). Not returned as `y`.
    numeric_pipeline, low_card_pipeline, text_pipeline : Optional[Pipeline], default=None
        Custom per-group pipelines; if provided, `nan_strategy` is ignored for that group.
    numeric_scale : bool, default=True
        Applies `StandardScaler(with_mean=False)` in the default numeric path.
    text_lowercase : bool, default=True
        Forwarded to the default `TfidfVectorizer(lowercase=...)`.
    max_ohe_cardinality : int, default=50
        Threshold to classify non-numeric columns as low-card (≤) vs text/high-card (>).
    nan_strategy : {"impute","drop"}, default="impute"
        Strategy for **default** pipelines only.
    on_pipeline_error : {"drop","raise"}, default="drop"
        If a **user** pipeline raises: "drop" → drop that group with a warning; "raise" → propagate.
    max_new_cols_per_group : int, default=2000
        Upper bound on columns a group may add (applies to array/sparse outputs).
    progress : Optional[tqdm], default=None
        Phase-aware progress bar (clean → split → numeric → low_card → text → finalize).
        If None, a local bar is created and closed automatically.
    verbose : bool, default=False
        When True, prints recorded warnings (also returned in `stats["warnings"]`).

    Returns
    -------
    output_df : pd.DataFrame
        Transformed DataFrame ready for the next pipeline step.
    stats : Dict[str, Any]
        {"warnings": [...], "cols": {"numeric": [...], "low_card": [...], "text": [...]},
         "n_rows_before": int, "n_rows_after": int}

    Raises
    ------
    KeyError
        If `label` is given but not a column of `df`.
    ValueError
        If `nan_strategy` is not "impute" or "drop".
    """
    # progress setup
    local_bar = None
    pp = None
    _weights = {"clean": .10, "split": .10, "numeric": .30, "low_card": .25, "text": .20, "finalize": .05}
    if progress is None:
        from tqdm.auto import tqdm  # local import to keep module lightweight
        local_bar = tqdm(total=100, leave=True)
        pp = PhaseProgress(local_bar, weights=_weights)
    elif hasattr(progress, "set_description") and hasattr(progress, "update"):
        if not hasattr(progress, "_last_val"):
            progress._last_val = 0
        pp = PhaseProgress(progress, weights=_weights)

    warnings: List[str] = []

    # 1) drop rows with NaN label if present
    n_before = len(df)
    pp and pp.start("clean", extra={"N": n_before})
    if label is not None and label not in df.columns:
        # ensure the bar is closed even on error
        if local_bar is not None:
            pp.close()
        raise KeyError(f"Target column '{label}' not in DataFrame.")
    if label is not None:
        df = df.dropna(subset=[label]).reset_index(drop=True)
    else:
        # BUG FIX: the default impute/scale paths below assign through
        # `df.loc[:, cols] = ...`; without this copy they would mutate the
        # caller's DataFrame in place when no label is given.
        df = df.copy()
    pp and pp.tick_abs("clean", 1.0, extra={"N": len(df)})
    pp and pp.end("clean")

    # 2) split columns
    pp and pp.start("split")
    Xf = df if label is None else df.drop(columns=[label])
    numeric_cols: List[str] = Xf.select_dtypes(include=[np.number, "boolean"]).columns.tolist()
    non_numeric = [c for c in Xf.columns if c not in numeric_cols]

    low_card_cols: List[str] = []
    text_cols: List[str] = []
    for c in non_numeric:
        nunq = Xf[c].nunique(dropna=True)
        (low_card_cols if nunq <= max_ohe_cardinality else text_cols).append(c)
    pp and pp.tick_abs("split", 1.0, extra={
        "num": len(numeric_cols), "low": len(low_card_cols), "txt": len(text_cols)
    })
    pp and pp.end("split")

    # 3) NaN strategy for default groups (custom pipelines assumed to handle NaNs)
    if nan_strategy not in {"impute", "drop"}:
        if local_bar is not None:
            pp.close()
        raise ValueError("nan_strategy must be 'impute' or 'drop'")
    if nan_strategy == "drop":
        cols_to_check = []
        if numeric_cols and numeric_pipeline is None:
            cols_to_check += numeric_cols
        if low_card_cols and low_card_pipeline is None:
            cols_to_check += low_card_cols
        if text_cols and text_pipeline is None:
            cols_to_check += text_cols
        if cols_to_check:
            mask = ~Xf[cols_to_check].isna().any(axis=1)
            df = df.loc[mask].reset_index(drop=True)
            Xf = df if label is None else df.drop(columns=[label])

    # 4) numeric
    pp and pp.start("numeric", extra={"cols": len(numeric_cols)})
    if numeric_cols:
        if numeric_pipeline is not None:
            df = _safe_apply_pipeline(
                "num", df, numeric_cols, numeric_pipeline,
                drop_on_error=(on_pipeline_error == "drop"),
                warnings=warnings,
                max_new_cols=max_new_cols_per_group,
            )
        else:
            if nan_strategy == "impute":
                # NOTE: SimpleImputer drops all-NaN columns; with an all-NaN
                # numeric column this assignment would raise a shape error.
                imputed = SimpleImputer(strategy="median").fit_transform(df[numeric_cols])
                df.loc[:, numeric_cols] = imputed
            if numeric_scale:
                scaler = StandardScaler(with_mean=False)
                scaled = scaler.fit_transform(df[numeric_cols].astype(float, copy=False))
                df.loc[:, numeric_cols] = np.asarray(scaled)
    pp and pp.tick_abs("numeric", 1.0)
    pp and pp.end("numeric", extra={"cols": len(numeric_cols)})

    # 5) low-card categoricals
    pp and pp.start("low_card", extra={"cols": len(low_card_cols)})
    if low_card_cols:
        if low_card_pipeline is not None:
            df = _safe_apply_pipeline(
                "cat", df, low_card_cols, low_card_pipeline,
                drop_on_error=(on_pipeline_error == "drop"),
                warnings=warnings,
                max_new_cols=max_new_cols_per_group,
            )
        else:
            if nan_strategy == "impute":
                df.loc[:, low_card_cols] = SimpleImputer(strategy="most_frequent").fit_transform(df[low_card_cols])
    pp and pp.tick_abs("low_card", 1.0)
    pp and pp.end("low_card", extra={"cols": len(low_card_cols)})

    # 6) text
    pp and pp.start("text", extra={"cols": len(text_cols)})
    if text_cols:
        if text_pipeline is None:
            def _concat_cols(Xframe: pd.DataFrame):
                # join all text columns into one space-separated document per row
                return Xframe.fillna("").astype(str).agg(" ".join, axis=1).values

            text_pipeline = Pipeline([
                ("concat", FunctionTransformer(_concat_cols, validate=False)),
                ("tfidf", TfidfVectorizer(
                    dtype=np.float32,
                    lowercase=bool(text_lowercase),
                )),
            ])

        df = _safe_apply_pipeline(
            "txt", df, text_cols, text_pipeline,
            drop_on_error=(on_pipeline_error == "drop"),
            warnings=warnings,
            max_new_cols=max_new_cols_per_group,
        )

    pp and pp.tick_abs("text", 1.0)
    pp and pp.end("text", extra={"cols": len(text_cols)})

    # 7) finalize
    pp and pp.start("finalize")
    stats: Dict[str, Any] = {
        "warnings": warnings,
        "cols": {"numeric": numeric_cols, "low_card": low_card_cols, "text": text_cols},
        "n_rows_before": n_before,
        "n_rows_after": len(df),
    }
    pp and pp.tick_abs("finalize", 1.0, extra={"warn": len(warnings)})
    pp and pp.end("finalize", extra={"warn": len(warnings)})

    if verbose and warnings:
        print("\n".join(warnings))
    if local_bar is not None:
        pp.close()

    return df, stats
|
pipeline/issues.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from scipy.sparse import issparse
|
| 8 |
+
from sklearn.linear_model import LogisticRegression, SGDRegressor
|
| 9 |
+
from tqdm.auto import tqdm
|
| 10 |
+
|
| 11 |
+
from .utils_cool import PhaseProgress, _ensure_dense32, _infer_task
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def find_issues(
    df: pd.DataFrame,
    *,
    label: str,
    task: Optional[str] = None,      # "classification" | "regression"; if None we infer
    model: Optional[Any] = None,     # sklearn estimator; default chosen by task
    progress: Optional[Any] = None,  # tqdm bar; None -> auto-create
    verbose: bool = False,
) -> Tuple[Optional[pd.DataFrame], Dict]:
    """
    Detect label issues using Cleanlab's CleanLearning.

    Parameters
    ----------
    df : DataFrame
        Input table containing features and a label column.
    label : str
        Name of the label column. Rows with NaN in label are dropped.
    task : {"classification","regression"}, optional
        If not provided, inferred from y.
    model : sklearn estimator, optional
        If not provided, defaults to LogisticRegression or SGDRegressor.
    progress : tqdm, optional
        Phase-aware progress bar. If None, a local bar is created and closed.
    verbose : bool, default False
        Print warnings/timings in addition to returning them in stats.

    Returns
    -------
    (df_out, stats)
        df_out : DataFrame with label-issue rows removed.
        stats : dict with minimal metadata and counts.

    Raises
    ------
    KeyError
        If `label` is not a column of `df`.
    RuntimeError
        If Cleanlab fails on both sparse and dense feature representations.
    """
    # ---- progress setup ----
    local_bar = None
    pp = None
    if progress is None:
        local_bar = tqdm(total=100, leave=True)
        pp = PhaseProgress(local_bar, weights={"clean": .15, "cleanlab": .75, "finalize": .10})
    elif hasattr(progress, "set_description") and hasattr(progress, "update"):
        if not hasattr(progress, "_last_val"):
            progress._last_val = 0
        pp = PhaseProgress(progress, weights={"clean": .15, "cleanlab": .75, "finalize": .10})

    warnings: List[str] = []
    n_before = len(df)

    # ---- clean: ensure label exists, drop NaNs in label ----
    pp and pp.start("clean", extra={"N": n_before})
    if label not in df.columns:
        if local_bar is not None:
            pp.close()
        raise KeyError(f"Label column '{label}' not found in DataFrame.")
    df_in = df.dropna(subset=[label]).reset_index(drop=True)
    # BUG FIX: exclude the label from the feature matrix. Previously X was the
    # whole frame (label column included), leaking the target into the features.
    X = df_in.drop(columns=[label])
    y_raw = df_in[label].to_numpy()
    pp and pp.tick_abs("clean", 1.0, extra={"N": len(X)})
    pp and pp.end("clean")

    # ---- task/model selection ----
    task_applied = _infer_task(y_raw, task)
    if model is None:
        if task_applied == "classification":
            model = LogisticRegression(solver="saga", n_jobs=-1)
        else:
            model = SGDRegressor()

    # y encoding for classification (Cleanlab expects numeric labels)
    if task_applied == "classification":
        classes, y = np.unique(y_raw, return_inverse=True)
    else:
        classes = None
        y = y_raw.astype(np.float64, copy=False)

    # ---- cleanlab: find label issues (with auto-dense fallback for tiny/sparse edge cases) ----
    pp and pp.start("cleanlab", extra={"task": task_applied})
    used_dense_fallback = False
    # BUG FIX: import _CL *before* the try block. Previously the import lived
    # inside the try, so a failure at/after import left `_CL` undefined and the
    # dense-fallback path raised NameError instead of retrying.
    if task_applied == "classification":
        from cleanlab.classification import CleanLearning as _CL
    else:
        from cleanlab.regression.learn import CleanLearning as _CL
    try:
        cl = _CL(model)
        issues = cl.find_label_issues(X, y)
    except Exception as e:
        if issparse(X):
            # retry dense (cleanlab/small-N often prefers dense)
            Xd = _ensure_dense32(X)
            try:
                cl = _CL(model)
                issues = cl.find_label_issues(Xd, y)
                used_dense_fallback = True
            except Exception as e2:
                if local_bar is not None:
                    pp.close()
                raise RuntimeError(f"Cleanlab failed on sparse and dense features: {e2}") from e
        else:
            if local_bar is not None:
                pp.close()
            raise

    # Parse outputs robustly
    if isinstance(issues, pd.DataFrame):
        is_issue = issues.get("is_label_issue", None)
        label_quality = issues.get("label_quality", None)
    else:
        is_issue = None
        label_quality = None

    n = len(issues) if hasattr(issues, "__len__") else len(y)
    n_issues = int(is_issue.sum()) if isinstance(is_issue, (pd.Series, np.ndarray)) else 0
    pct = round((n_issues / n) * 100.0, 3) if n else 0.0
    avg_quality = float(np.nanmean(label_quality.values)) if isinstance(label_quality, pd.Series) else float("nan")

    pp and pp.tick_abs("cleanlab", 1.0, extra={"issues": n_issues})
    pp and pp.end("cleanlab", extra={"issues": n_issues})

    # ---- finalize: drop issue rows ----
    pp and pp.start("finalize")
    df_out = df_in
    if isinstance(is_issue, (pd.Series, np.ndarray)):
        # BUG FIX: np.ndarray has no `.values`; normalize both Series and
        # ndarray through np.asarray before inverting the mask.
        mask_keep = ~np.asarray(is_issue, dtype=bool)
        df_out = df_in.loc[mask_keep].copy()

    stats: Dict[str, Any] = {
        "n_rows_before_cleanlab": int(len(df_in)),
        "n_label_issues": int(n_issues),
        "pct_label_issues": float(pct),
        "avg_label_quality": float(avg_quality),
        "n_rows_after_cleanlab": int(len(df_out)),
        "task_applied": task_applied,
        "model_name": type(model).__name__,
        "used_dense_fallback": bool(used_dense_fallback),
        "warnings": warnings,
    }
    pp and pp.tick_abs("finalize", 1.0)
    pp and pp.end("finalize")

    if verbose and warnings:
        print("\n".join(warnings))
    if local_bar is not None:
        pp.close()

    return df_out, stats
|
pipeline/pipeline.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from typing import Callable, Optional, Any, Dict, Tuple, Sequence, List
|
| 3 |
+
|
| 4 |
+
import inspect
|
| 5 |
+
from typing import Optional, Callable, Tuple, Dict, Any
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
def make_step(func: Callable, *, name: Optional[str] = None, use_original_df: bool = False):
    """
    Wrap `func` into a pipeline step builder.

    Assumes: func(...) -> (output_df, stats: dict)

    If the function has a `df` parameter, the step will:
      - by default use the previous step's output as df
      - if `use_original_df=True`, use the original df passed to `run_pipeline`
      - if you bind `df=` at build time, that takes precedence over both
    """
    signature = inspect.signature(func)
    step_name = func.__name__ if name is None else name

    def builder(**params):
        # fail fast on unknown/invalid kwargs, at build time rather than run time
        signature.bind_partial(**params)

        def _run(prev_df: pd.DataFrame, orig_df: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, pd.DataFrame, Dict[str, Any], str]:
            kwargs = dict(params)

            # an explicit df bound at build time wins over both frames
            if "df" in signature.parameters and "df" not in kwargs:
                if use_original_df and orig_df is not None:
                    kwargs["df"] = orig_df
                else:
                    kwargs["df"] = prev_df

            out_df, stats = func(**kwargs)
            if not isinstance(stats, dict):
                raise TypeError(f"{step_name}: expected stats to be dict, got {type(stats)}")

            # for logging: the input df actually used by the function (if any), else prev_df
            used_input = kwargs.get("df", prev_df)
            return used_input, out_df, stats, step_name

        _run.__name__ = step_name
        _run.__doc__ = func.__doc__
        _run.__signature__ = signature
        return _run

    builder.__name__ = f"{step_name}_step"
    builder.__doc__ = f"Step builder for `{func.__name__}`.\n\n" + (func.__doc__ or "")
    builder.__signature__ = signature
    return builder
|
| 49 |
+
|
| 50 |
+
def run_pipeline(steps: Sequence[Callable], df: pd.DataFrame):
    """
    Calls each step as step(prev_df, orig_df) and chains outputs.
    Returns final_df, logs.
    """
    logs: List[Dict[str, Any]] = []
    orig_df = df
    current = df
    for step in steps:
        _, produced, stats, step_name = step(current, orig_df)
        logs.append({"step": step_name, **stats})
        # a step returning None leaves the chained frame untouched
        if produced is not None:
            current = produced
    return current, logs
|
| 63 |
+
|
pipeline/utils_cool.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
import re
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from time import perf_counter
|
| 6 |
+
from typing import Any, Dict, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from scipy.sparse import issparse
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class PhaseProgress:
    """Map per-phase progress onto a single tqdm bar using per-phase weights.

    Each phase owns a share of the bar proportional to its weight; `tick_abs`
    maps within-phase progress (0..1) onto the bar, and `end` commits the
    phase's share so later phases continue from there.
    """

    bar: "tqdm"                # the tqdm instance being driven
    weights: Dict[str, float]  # phase name -> relative weight
    total: int = 100           # bar total (tqdm bar should be created with this)

    def __post_init__(self):
        self._norm = sum(self.weights.values()) or 1.0  # guard against all-zero weights
        self._done = 0.0       # committed fraction of the whole bar
        self._phase = None
        self._phase_t0 = None
        # tqdm.update is incremental, so track the absolute value ourselves
        if not hasattr(self.bar, "_last_val"):
            self.bar._last_val = 0

    def start(self, phase: str, extra: Optional[Dict] = None):
        self._phase = phase
        self._phase_t0 = perf_counter()
        self.bar.set_description_str(phase)
        if extra:
            self.bar.set_postfix(extra, refresh=False)

    def tick_abs(self, phase: str, p01: float, extra: Optional[Dict] = None):
        """Update absolute progress based on within-phase progress p01 ∈ [0,1]."""
        frac = float(p01)
        if frac < 0.0:
            frac = 0.0
        elif frac > 1.0:
            frac = 1.0
        share = self.weights.get(phase, 0.0) / self._norm
        target = int(round(self.total * (self._done + share * frac)))
        step = target - self.bar._last_val
        if step > 0:
            self.bar.update(step)
            self.bar._last_val = target
        self.bar.set_description_str(f"{phase} {int(100*frac)}%")
        if extra:
            self.bar.set_postfix(extra, refresh=False)

    def end(self, phase: str, extra: Optional[Dict] = None):
        # commit this phase's share and report its wall-clock duration
        self._done += self.weights.get(phase, 0.0) / self._norm
        started = self._phase_t0 if self._phase_t0 is not None else perf_counter()
        elapsed_ms = (perf_counter() - started) * 1000
        post = {} if extra is None else dict(extra)
        post["t"] = f"{elapsed_ms:.0f}ms"
        self.bar.set_postfix(post, refresh=False)

    def close(self):
        try:
            # snap the bar to 100% before closing so it never looks stuck
            remaining = self.total - self.bar._last_val
            if remaining > 0:
                self.bar.update(remaining)
        finally:
            self.bar.close()
|
| 63 |
+
|
| 64 |
+
def choose_k(N: int, k_min: int = 5, k_max: int = 50) -> int:
    """sqrt(N) clipped to [k_min, k_max] and ≤ N-1."""
    if N <= 1:
        return 1
    root = int(math.sqrt(N))
    clipped = max(k_min, min(root, k_max))
    return min(clipped, N - 1)
|
| 71 |
+
|
| 72 |
+
def _ensure_dense32(X) -> np.ndarray:
|
| 73 |
+
"""Convert to contiguous float32 ndarray (densify only if needed)."""
|
| 74 |
+
if issparse(X):
|
| 75 |
+
X = X.toarray()
|
| 76 |
+
return np.asarray(X, dtype=np.float32, order="C")
|
| 77 |
+
|
| 78 |
+
def decide_task_and_model(
    y: np.ndarray,
    series: pd.Series,
    *,
    is_categorical: bool = False,
    few_class_floor: int = 20,
    few_class_frac: float = 0.05,
):
    """Pick "classification" or "regression" for a label.

    Classification wins when the caller says so (`is_categorical`), the dtype
    is boolean or non-numeric, or a numeric label has only a handful of
    distinct values (≤ max(few_class_floor, ceil(few_class_frac * N))).
    """
    n_rows = len(y)

    # dtype checks
    bool_dtype = pd.api.types.is_bool_dtype(series)
    numeric_dtype = pd.api.types.is_numeric_dtype(series)

    # unique values (ignore NaNs)
    distinct = len(pd.unique(y[~pd.isnull(y)]))

    # numeric-but-few-classes heuristic
    few_classes_cutoff = max(few_class_floor, int(np.ceil(few_class_frac * max(n_rows, 1))))
    looks_discrete = numeric_dtype and distinct <= few_classes_cutoff

    if is_categorical or bool_dtype or not numeric_dtype or looks_discrete:
        return "classification"
    return "regression"
|
| 111 |
+
|
| 112 |
+
def _infer_task(y: np.ndarray, task: Optional[str]) -> str:
|
| 113 |
+
"""Decide task if not provided: numeric with many uniques -> regression, else classification."""
|
| 114 |
+
if task in {"classification", "regression"}:
|
| 115 |
+
return task
|
| 116 |
+
|
| 117 |
+
if np.issubdtype(y.dtype, np.number):
|
| 118 |
+
nunq = len(np.unique(y[~pd.isna(y)]))
|
| 119 |
+
is_categorical = nunq <= max(2, int(0.02 * max(1, len(y))))
|
| 120 |
+
else:
|
| 121 |
+
is_categorical = True
|
| 122 |
+
return "classification" if is_categorical else "regression"
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# --------- DataFrame payload helpers (for tool IO) ---------
|
| 126 |
+
|
| 127 |
+
def df_to_payload(df: pd.DataFrame) -> Dict[str, Any]:
    """Serialize a DataFrame into a JSON-friendly dict using pandas' 'split' orient."""
    split_dict = df.to_dict(orient="split")
    return {"orient": "split", "data": split_dict}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def df_from_payload(p: Dict[str, Any]) -> pd.DataFrame:
    """Rebuild a DataFrame from a payload produced by `df_to_payload`.

    Fix: the 'split' orient carries the original index, but it was previously
    discarded — `df_from_payload(df_to_payload(df))` lost the index. Restore it
    when present; a payload without "index" falls back to a default RangeIndex.
    """
    d = p["data"]
    return pd.DataFrame(d["data"], columns=d["columns"], index=d.get("index"))
|
| 134 |
+
|
| 135 |
+
# --------- Light heuristics for task/label guess ---------
|
| 136 |
+
|
| 137 |
+
def guess_task_and_label(df: pd.DataFrame) -> Dict[str, Any]:
    """Heuristically guess the label column and ML task for a DataFrame.

    Looks for a conventionally named label column (label/target/y/class/outcome,
    case-insensitive), then infers the task from its dtype and cardinality.

    Returns a dict with: columns, dtypes (as strings), label_guess, task_guess
    ("classification" / "regression" / "unsupervised"), issues (list of
    human-readable warnings), and shape.
    """
    cols = list(df.columns)
    label_candidates = [c for c in cols if c.lower() in {"label","target","y","class","outcome"}]
    label = label_candidates[0] if label_candidates else None

    task = None
    if label is None:
        # No recognizable label column -> nothing to supervise on.
        task = "unsupervised"
    elif pd.api.types.is_integer_dtype(df[label]) or pd.api.types.is_bool_dtype(df[label]):
        nuniq = df[label].nunique(dropna=True)
        # Few distinct integer values look like class ids; many look like a numeric target.
        task = "classification" if nuniq <= max(20, int(0.05*len(df))) else "regression"
    elif pd.api.types.is_float_dtype(df[label]):
        task = "regression"
    else:
        # Bug fix: non-numeric labels (strings/categories) previously fell
        # through to "unsupervised" despite a label column existing.
        task = "classification"

    issues = []
    if label and df[label].isna().any():
        issues.append(f"Missing values in label `{label}`")
    if label and df[label].nunique() == 1:
        issues.append(f"Label `{label}` has a single class")

    return {
        "columns": cols,
        "dtypes": {c: str(df[c].dtype) for c in cols},
        "label_guess": label,
        "task_guess": task,
        "issues": issues,
        "shape": df.shape,
    }
|
| 168 |
+
|
| 169 |
+
# --------- Signature extraction for asking params ---------
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_signature_dict(fn) -> Dict[str, Any]:
    """Describe a callable's parameters (excluding `df`) and docstring.

    Returns {"params": [...], "doc": str} where each param entry carries
    name, default, annotation (stringified, or None), and kind. NOTE: a
    required parameter and one whose default is literally None both report
    default=None and are indistinguishable to callers.
    """
    sig = inspect.signature(fn)
    doc = (fn.__doc__ or "").strip()
    # Use the public sentinel instead of the private inspect._empty.
    empty = inspect.Parameter.empty
    params = []
    for p in sig.parameters.values():
        if p.name == "df":
            # The DataFrame argument is supplied by the pipeline, not the user.
            continue
        default = None if (p.default is empty) else p.default
        annotation = None if (p.annotation is empty) else str(p.annotation)
        params.append({"name": p.name, "default": default, "annotation": annotation, "kind": str(p.kind)})
    return {"params": params, "doc": doc}
|
| 183 |
+
|
| 184 |
+
# --------- Parse free-text confirmation like "Run dedup threshold=0.93 metric=cosine" ---------
# Maps a canonical pipeline step name to the phrases users may type for it.
STEP_ALIASES = {
    "dedup": {"dedup","de-dup","duplicates","near-dup"},
    "featurize": {"featurize","features","featureize","engineering"},
    "find_label_issues": {"find_label_issues","label issues","cleanlab","label noise"},
}


def parse_user_choice(text: str) -> Tuple[Optional[str], Dict[str, Any]]:
    """Parse a free-text command into (step_name, params).

    The step is matched case-insensitively against STEP_ALIASES (first
    matching step in dict order wins). Params come from `key=value` pairs;
    values are coerced to int or float (including negatives) when they look
    numeric, to bool for the literals true/false, and stay strings otherwise.

    Returns (None, params) when no known step alias appears in the text.
    """
    t = text.lower()
    chosen = None
    for step, aliases in STEP_ALIASES.items():
        if any(a in t for a in aliases):
            chosen = step
            break

    params: Dict[str, Any] = {}
    for m in re.finditer(r"(\w+)\s*=\s*([\-\w\.]+)", text):
        k, v = m.group(1), m.group(2)
        # Bug fix: the old `.isdigit()` check never matched negative numbers,
        # so values like "-3" or "-0.5" stayed strings even though the value
        # regex captures a leading minus. Match the numeric shapes explicitly.
        if re.fullmatch(r"-?\d+", v):
            v = int(v)
        elif re.fullmatch(r"-?(?:\d+\.\d*|\.\d+)", v):
            v = float(v)
        elif v.lower() in {"true","false"}:
            v = (v.lower() == "true")
        params[k] = v
    return chosen, params
|
requirements.txt
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==23.2.1
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.12.15
|
| 4 |
+
aiosignal==1.4.0
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
anyio==4.11.0
|
| 7 |
+
async-timeout==4.0.3
|
| 8 |
+
attrs==25.3.0
|
| 9 |
+
cachetools==6.2.0
|
| 10 |
+
certifi==2025.8.3
|
| 11 |
+
charset-normalizer==3.4.3
|
| 12 |
+
cleanlab==2.7.1
|
| 13 |
+
click==8.1.8
|
| 14 |
+
contourpy==1.3.0
|
| 15 |
+
cycler==0.12.1
|
| 16 |
+
dataclasses-json==0.6.7
|
| 17 |
+
datasets==4.1.1
|
| 18 |
+
dill==0.4.0
|
| 19 |
+
dotenv==0.9.9
|
| 20 |
+
exceptiongroup==1.3.0
|
| 21 |
+
fastapi==0.118.0
|
| 22 |
+
ffmpy==0.6.1
|
| 23 |
+
filelock==3.19.1
|
| 24 |
+
filetype==1.2.0
|
| 25 |
+
fonttools==4.60.1
|
| 26 |
+
frozenlist==1.7.0
|
| 27 |
+
fsspec==2025.9.0
|
| 28 |
+
google-ai-generativelanguage==0.7.0
|
| 29 |
+
google-api-core==2.25.2
|
| 30 |
+
google-auth==2.41.1
|
| 31 |
+
googleapis-common-protos==1.70.0
|
| 32 |
+
gradio==4.44.1
|
| 33 |
+
gradio_client==1.3.0
|
| 34 |
+
grpcio==1.75.1
|
| 35 |
+
grpcio-status==1.75.1
|
| 36 |
+
h11==0.16.0
|
| 37 |
+
hf-xet==1.1.10
|
| 38 |
+
httpcore==1.0.9
|
| 39 |
+
httpx==0.28.1
|
| 40 |
+
httpx-sse==0.4.1
|
| 41 |
+
huggingface-hub==0.35.3
|
| 42 |
+
idna==3.10
|
| 43 |
+
importlib_resources==6.5.2
|
| 44 |
+
Jinja2==3.1.6
|
| 45 |
+
joblib==1.5.2
|
| 46 |
+
jsonpatch==1.33
|
| 47 |
+
jsonpointer==3.0.0
|
| 48 |
+
kiwisolver==1.4.7
|
| 49 |
+
langchain==0.3.27
|
| 50 |
+
langchain-community==0.3.30
|
| 51 |
+
langchain-core==0.3.78
|
| 52 |
+
langchain-google-genai==2.1.12
|
| 53 |
+
langchain-text-splitters==0.3.11
|
| 54 |
+
langgraph==0.6.8
|
| 55 |
+
langgraph-checkpoint==2.1.1
|
| 56 |
+
langgraph-prebuilt==0.6.4
|
| 57 |
+
langgraph-sdk==0.2.9
|
| 58 |
+
langsmith==0.4.32
|
| 59 |
+
markdown-it-py==3.0.0
|
| 60 |
+
MarkupSafe==2.1.5
|
| 61 |
+
marshmallow==3.26.1
|
| 62 |
+
matplotlib==3.9.4
|
| 63 |
+
mdurl==0.1.2
|
| 64 |
+
multidict==6.6.4
|
| 65 |
+
multiprocess==0.70.16
|
| 66 |
+
mypy_extensions==1.1.0
|
| 67 |
+
narwhals==2.6.0
|
| 68 |
+
numpy==1.26.4
|
| 69 |
+
orjson==3.11.3
|
| 70 |
+
ormsgpack==1.10.0
|
| 71 |
+
packaging==25.0
|
| 72 |
+
pandas==2.3.3
|
| 73 |
+
pandoc==2.4
|
| 74 |
+
pillow==10.4.0
|
| 75 |
+
plotly==6.3.1
|
| 76 |
+
plumbum==1.9.0
|
| 77 |
+
ply==3.11
|
| 78 |
+
propcache==0.4.0
|
| 79 |
+
proto-plus==1.26.1
|
| 80 |
+
protobuf==6.32.1
|
| 81 |
+
pyarrow==21.0.0
|
| 82 |
+
pyasn1==0.6.1
|
| 83 |
+
pyasn1_modules==0.4.2
|
| 84 |
+
pydantic==2.10.6
|
| 85 |
+
pydantic-settings==2.11.0
|
| 86 |
+
pydantic_core==2.27.2
|
| 87 |
+
pydub==0.25.1
|
| 88 |
+
Pygments==2.19.2
|
| 89 |
+
pypandoc==1.15
|
| 90 |
+
pyparsing==3.2.5
|
| 91 |
+
python-dateutil==2.9.0.post0
|
| 92 |
+
python-dotenv==1.1.1
|
| 93 |
+
python-multipart==0.0.20
|
| 94 |
+
pytz==2025.2
|
| 95 |
+
PyYAML==6.0.3
|
| 96 |
+
reportlab==4.4.4
|
| 97 |
+
requests==2.32.5
|
| 98 |
+
requests-toolbelt==1.0.0
|
| 99 |
+
rich==14.1.0
|
| 100 |
+
rsa==4.9.1
|
| 101 |
+
ruff==0.13.3
|
| 102 |
+
scikit-learn==1.6.1
|
| 103 |
+
scipy==1.13.1
|
| 104 |
+
semantic-version==2.10.0
|
| 105 |
+
shellingham==1.5.4
|
| 106 |
+
six==1.17.0
|
| 107 |
+
sniffio==1.3.1
|
| 108 |
+
SQLAlchemy==2.0.43
|
| 109 |
+
starlette==0.48.0
|
| 110 |
+
tenacity==9.1.2
|
| 111 |
+
termcolor==3.1.0
|
| 112 |
+
threadpoolctl==3.6.0
|
| 113 |
+
tomlkit==0.12.0
|
| 114 |
+
tqdm==4.67.1
|
| 115 |
+
typer==0.19.2
|
| 116 |
+
typing-inspect==0.9.0
|
| 117 |
+
typing-inspection==0.4.2
|
| 118 |
+
typing_extensions==4.15.0
|
| 119 |
+
tzdata==2025.2
|
| 120 |
+
urllib3==2.5.0
|
| 121 |
+
uvicorn==0.37.0
|
| 122 |
+
websockets==12.0
|
| 123 |
+
xxhash==3.6.0
|
| 124 |
+
yarl==1.20.1
|
| 125 |
+
zipp==3.23.0
|
| 126 |
+
zstandard==0.25.0
|