Suraj Prasai commited on
Commit
458c8e2
·
0 Parent(s):

added initial

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
__init__.py ADDED
File without changes
__pycache__/app.cpython-311.pyc ADDED
Binary file (23.3 kB). View file
 
__pycache__/test_app.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
agent/ChatbotHandler.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import pandas as pd
5
+ from langchain_core.messages import HumanMessage
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+
8
+ from agent.agent_graph import build_app
9
+ from pipeline.utils_cool import df_to_payload, parse_user_choice
10
+
11
+ from .runtime_ctx import get_df_summary
12
+ from dotenv import load_dotenv
13
+ load_dotenv()
14
+
15
class ChatbotHandler:
    """Glue between the Gradio chat UI and the LangGraph data-quality agent.

    Owns the compiled graph app plus a dict mirror of the agent state,
    prepares the first ("boot") reply when a dataset is uploaded, and
    translates each chat message into exactly one graph invocation.
    """

    # State keys copied back from the graph output after every invocation.
    _STATE_KEYS = (
        "df_payload", "results", "steps_taken", "confirmed_step",
        "confirmed_params", "last_task", "plan", "messages",
    )

    def __init__(self):
        # Mirror of the agent's working state, persisted between turns.
        self.ctx: Dict[str, Any] = {
            "graph_app": None,  # compiled LangGraph app
            "state": {
                "df_payload": None,
                "results": [],
                "steps_taken": 0,
                "confirmed_step": None,
                "confirmed_params": {},
                "last_task": None,
                "plan": None,
                "messages": [],
                "max_steps": 8,
            },
        }
        # The chat component keeps its own history; we only return reply
        # strings to it. The boot reply built on upload is parked here until
        # the first turn flushes it into an empty chat.
        self._boot_text: Optional[str] = None

        # LLM driving the graph.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-lite",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

    def _format_summary(self, s: Dict[str, Any]) -> str:
        """Render the stored dataset-summary dict as a short markdown blurb."""
        columns = s.get("columns") or []
        dtype_map = s.get("dtypes") or {}
        shape = s.get("shape") or (None, None)
        label = s.get("label_guess") or "None"
        task = s.get("task_guess") or "Unknown"
        problems = s.get("issues") or []

        # Concise but helpful: at most 8 dtype pairs, ellipsis when truncated.
        shown = [f"{name}: {dt}" for name, dt in list(dtype_map.items())[:8]]
        if len(dtype_map) > 8:
            shown.append("…")

        col_part = ", ".join(map(str, columns[:10]))
        col_more = "…" if len(columns) > 10 else ""

        out = ["### Dataset summary"]
        out.append(f"- Shape: {shape[0]} rows × {shape[1]} columns")
        out.append(f"- Columns: {col_part}{col_more}")
        out.append(f"- Dtypes: {', '.join(shown)}")
        out.append(f"- Label guess: {label}")
        out.append(f"- Task guess: {task}")
        if problems:
            more = "…" if len(problems) > 3 else ""
            out.append(f"- Potential issues: {'; '.join(problems[:3])}{more}")
        return "\n".join(out)

    # ---------------- Boot on upload: run inspect + SOTA + plan ----------------
    def update_context(self, file_path: Optional[str], data_type: Optional[str], df: Optional["pd.DataFrame"]):
        """Boot the agent on a freshly uploaded dataframe; return the first reply."""
        if df is None:
            return ""

        # (Re)build the graph and seed a fresh state for the new dataset.
        self.ctx["graph_app"] = build_app(self.llm)
        state = self.ctx["state"]
        state.update({
            "df_payload": df_to_payload(df),
            "results": [],
            "steps_taken": 0,
            "confirmed_step": None,
            "confirmed_params": {},
            "last_task": None,
            "plan": None,
            "messages": [HumanMessage(content="A new dataset was uploaded. Start the workflow.")],
            "max_steps": 8,
        })

        final = self.ctx["graph_app"].invoke(state)
        for key in self._STATE_KEYS:
            state[key] = final.get(key, state.get(key))

        # Build the boot text from the stored summary (authoritative + consistent).
        summary = get_df_summary() or {}
        text = self._format_summary(summary)

        # Decide whether to ask the user for confirmation.
        task = (summary.get("task_guess") or "").lower()
        label = summary.get("label_guess")
        unknown_task = task not in {"classification", "regression", "unsupervised"}
        missing_label = (task in {"classification", "regression"}) and (not label)

        if unknown_task or missing_label:
            prompt = ("\n\nPlease confirm the task"
                      + (" and label column" if missing_label else "")
                      + ". For example: `task=classification label=noisy_letter_grade`.")
        else:
            prompt = (f"\n\nIf that looks right, say `confirm task={task}"
                      + (f" label={label}`" if label else "`")
                      + " and I’ll fetch SOTA and propose a plan.")

        self._boot_text = text + prompt
        return self._boot_text

    # ---------------- One chat turn → graph turn ----------------
    def respond(self, message: str, history: List):
        """Handle one chat turn: run the graph, append (user, reply) to history."""
        if history is None:
            history = []
        text = (message or "").strip()
        if not text:
            return history, ""

        # Flush the prepared upload reply into an empty chat before the
        # user's first message is processed.
        if self._boot_text and len(history) == 0:
            history.append(("[system]", self._boot_text))
            self._boot_text = None

        # A booted graph is required beyond this point.
        if self.ctx.get("graph_app") is None:
            history.append((text, "Please upload a dataset first."))
            return history, ""

        state = self.ctx["state"]

        # Shortcut parsing for "run X a=b"-style commands before the graph runs.
        step, params = parse_user_choice(text)
        if step:
            state["confirmed_step"] = step
            state["confirmed_params"] = {**(state.get("confirmed_params") or {}), **params}

        # Assemble this turn's input state.
        turn = {
            "messages": (state.get("messages") or []) + [HumanMessage(content=text)],
            "df_payload": state.get("df_payload"),
            "results": state.get("results", []),
            "steps_taken": state.get("steps_taken", 0),
            "max_steps": max(8, state.get("steps_taken", 0) + 4),
            "confirmed_step": state.get("confirmed_step"),
            "confirmed_params": state.get("confirmed_params", {}),
            "last_task": state.get("last_task"),
            "plan": state.get("plan"),
        }

        final = self.ctx["graph_app"].invoke(turn)

        # Persist the graph's output back into the mirrored state.
        for key in self._STATE_KEYS:
            state[key] = final.get(key, turn.get(key, state.get(key)))

        reply = self._extract_ai_text(final.get("messages", [])) or "Done."
        history.append((text, reply))
        return history, ""

    # ---------------- helper: extract last AI string ----------------
    def _extract_ai_text(self, messages: List[Any]) -> str:
        """Walk the transcript backwards and return the newest assistant text."""

        def flatten(content: Any) -> str:
            # Content may be None, a plain string, or a list of parts.
            if content is None:
                return ""
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                pieces: List[str] = []
                for part in content:
                    if isinstance(part, dict):
                        pieces.append(str(part.get("text") or part.get("content") or part.get("data") or ""))
                    else:
                        pieces.append(str(part))
                return " ".join(x for x in pieces if x)
            return str(content)

        for entry in reversed(messages or []):
            role = getattr(entry, "type", None) or getattr(entry, "role", None)
            if role in ("ai", "assistant", "aimessage"):
                return flatten(getattr(entry, "content", None))
            if isinstance(entry, dict):
                tag = (entry.get("role") or entry.get("type") or "").lower()
                if tag in ("assistant", "ai", "aimessage"):
                    return flatten(entry.get("content"))
        return ""
agent/__init__.py ADDED
File without changes
agent/__pycache__/ChatbotHandler.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
agent/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (162 Bytes). View file
 
agent/__pycache__/agent_graph.cpython-311.pyc ADDED
Binary file (12.8 kB). View file
 
agent/__pycache__/runtime_ctx.cpython-311.pyc ADDED
Binary file (9.63 kB). View file
 
agent/__pycache__/tools.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
agent/agent_graph.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent/agent_graph.py
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ from typing import Any, Dict, List, Optional, TypedDict
6
+
7
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
8
+ from langgraph.graph import END, START, StateGraph
9
+ from langgraph.prebuilt import ToolNode
10
+
11
+ from .runtime_ctx import set_df_payload, set_df_summary, set_sota_bundled
12
+ from .tools import (
13
+ tool_describe_step,
14
+ tool_inspect_dataset,
15
+ tool_list_steps,
16
+ tool_list_versions,
17
+ tool_propose_plan,
18
+ tool_reset_to_version,
19
+ tool_run_step,
20
+ tool_sota_preprocessing,
21
+ )
22
+
23
+
24
+ def _to_text(content: Any, limit: int = 4000) -> str:
25
+ """Coerce any message content to a string that Gemini will accept."""
26
+ if content is None:
27
+ return ""
28
+ if isinstance(content, str):
29
+ return content
30
+ try:
31
+ s = json.dumps(content, default=str, ensure_ascii=False)
32
+ except Exception:
33
+ s = str(content)
34
+ # lightly truncate huge tool dumps
35
+ return (s[:limit] + " …") if len(s) > limit else s
36
+
37
+ def _sanitize_messages(msgs: list[Any]) -> list[Any]:
38
+ """Keep only system/human/assistant messages and ensure content is str."""
39
+ clean = []
40
+ for m in msgs or []:
41
+ # Drop raw ToolMessage or unknown roles (Gemini doesn't accept them)
42
+ role = getattr(m, "type", None) or getattr(m, "role", None) or ""
43
+ if isinstance(m, ToolMessage) or role == "tool":
44
+ # Optionally compress tool outputs into a short assistant line instead:
45
+ txt = _to_text(getattr(m, "content", None))
46
+ if txt:
47
+ clean.append(AIMessage(content=f"[Tool result] {txt}"))
48
+ continue
49
+
50
+ c = _to_text(getattr(m, "content", None))
51
+ if isinstance(m, SystemMessage):
52
+ clean.append(SystemMessage(content=c))
53
+ elif isinstance(m, HumanMessage):
54
+ clean.append(HumanMessage(content=c))
55
+ elif isinstance(m, AIMessage):
56
+ clean.append(AIMessage(content=c))
57
+ else:
58
+ # Unknown BaseMessage; best-effort map by role string
59
+ r = str(role).lower()
60
+ if r == "system":
61
+ clean.append(SystemMessage(content=c))
62
+ elif r in ("human", "user"):
63
+ clean.append(HumanMessage(content=c))
64
+ elif r in ("assistant", "ai", "aimessage"):
65
+ clean.append(AIMessage(content=c))
66
+ # else: ignore silently
67
+ return clean
68
+
69
+ TOOLS = [tool_inspect_dataset, tool_sota_preprocessing, tool_list_steps, tool_describe_step, tool_propose_plan, tool_run_step, tool_list_versions, tool_reset_to_version]
70
+
71
+ SYSTEM_PRIMER = (
72
+ "You are a data-quality assistant.\n"
73
+ "\n"
74
+ "Workflow:\n"
75
+ "1) Call inspect_dataset() to summarize columns/dtypes and GUESS task/label.\n"
76
+ " • If you are NOT SURE about the task (or the label for supervised tasks), ASK the user to confirm and END THE TURN.\n"
77
+ " • Do NOT call sota_preprocessing until the user explicitly confirms the task (and label if supervised).\n"
78
+ " Acceptable confirmations include messages like: "
79
+ " 'task=classification label=HARDSHIP_INDEX', 'Task: regression', or 'Unsupervised'.\n"
80
+ "2) After the user confirms, call sota_preprocessing(task, modality, ...) and PRESENT a brief 'SOTA Evidence' section (3–6 bullets with titles and links from the tool).\n"
81
+ "3) Call list_steps() and map SOTA insights to the available tools. Produce a plan (no execution yet); cite up to 2 SOTA sources per step.\n"
82
+ "4) Ask: 'Which step should we execute first?' Do NOT call run_step until the user explicitly picks.\n"
83
+ "5) After the user picks, call describe_step(name) and list ONLY real parameters from the tool. Ask for missing/optional params and confirm them.\n"
84
+ "6) Execute with run_step(name, params_json). Version controls inside params_json when relevant:\n"
85
+ " • source: 'current' | 'prev' | 'base' | '@-1' | '@-2' | <int>\n"
86
+ " • dry_run: true|false (preview without mutating)\n"
87
+ " • new_version: true|false (create new snapshot vs replace current)\n"
88
+ " Avoid loops: if the same step+params just ran, ask to change parameters or source.\n"
89
+ "7) Summarize results; optionally call list_versions() and offer reset_to_version(spec). If helpful, research again before proposing next steps.\n"
90
+ "\n"
91
+ "Rules:\n"
92
+ "- Return exactly one tool call at a time.\n"
93
+ "- Never call sota_preprocessing before explicit task confirmation.\n"
94
+ "- Never call run_step without an explicit user choice.\n"
95
+ "- When users ask about parameters, use describe_step (or list_steps) and answer ONLY from tool output.\n"
96
+ "- Reject parameters that are not in the tool signature.\n"
97
+ )
98
+
99
+
100
class AgentState(TypedDict):
    """Shared LangGraph state threaded through the agent and tool nodes."""

    messages: List[Any]                   # running conversation transcript
    df_payload: Optional[Dict[str, Any]]  # current dataset snapshot
    results: List[Dict[str, Any]]         # stats from executed steps
    steps_taken: int                      # number of steps run so far
    max_steps: int                        # hard cap before the graph ends
    confirmed_step: Optional[str]         # step explicitly picked by the user
    confirmed_params: Dict[str, Any]      # params confirmed by the user
    last_task: Optional[str]              # task guess from inspection
    plan: Optional[Dict[str, Any]]        # latest proposed plan payload
110
+
111
def make_agent_node(llm):
    """Build the LLM node: sanitize history, invoke the tool-bound model,
    and ALWAYS append exactly one AIMessage to the transcript."""
    bound = llm.bind_tools(TOOLS)

    def _node(state: AgentState) -> AgentState:
        # A small system note keeps the model aware of the live dataset size.
        data = (state.get("df_payload") or {}).get("data", {})
        n_rows = len(data.get("data", []) or [])
        n_cols = len(data.get("columns", []) or [])
        shape_msg = SystemMessage(content=f"Current dataset shape: {n_rows} rows × {n_cols} columns.")

        prompt = [SystemMessage(content=SYSTEM_PRIMER)]
        prompt.extend(_sanitize_messages(state.get("messages", [])))
        prompt.append(shape_msg)

        reply = bound.invoke(prompt)
        # Guard: normalize odd provider returns into an AIMessage object.
        if not isinstance(reply, AIMessage):
            reply = AIMessage(content=_to_text(getattr(reply, "content", reply)))

        state["messages"] = state["messages"] + [reply]
        return state

    return _node
135
+
136
def tools_exec_node():
    """
    Build the tool-execution node.

    Injects df_payload into the runtime context BEFORE any tool runs,
    gates run_step behind an explicit user confirmation, appends only the
    freshly produced ToolMessages, and folds structured tool payloads
    (summary / SOTA / plan / step_result) back into the graph state.
    """
    executor = ToolNode(TOOLS)

    def _node(state: AgentState) -> AgentState:
        # Tools read the dataset from runtime context, so set it first.
        set_df_payload(state.get("df_payload"))

        # No dataset at all: be friendly and stop, replying in the same
        # message class as the most recent entry.
        if state.get("df_payload") is None:
            state["messages"].append(type(state["messages"][-1])(content="I don't have a dataset yet. Please upload one."))
            return state

        # Hard gate: block run_step unless the user confirmed that step.
        last = state["messages"][-1]
        for call in (getattr(last, "tool_calls", None) or []):
            if call.get("name") == "run_step":
                target = (call.get("args") or {}).get("name")
                if target and target != state.get("confirmed_step"):
                    state["messages"].append(type(last)(content="I have a plan ready. Which step should we run first?"))
                    return state

        # Execute the tool(s) requested by the last assistant message.
        out = executor.invoke({"messages": state["messages"]})

        # Append ONLY new ToolMessages; never overwrite the conversation.
        fresh = [m for m in out["messages"] if isinstance(m, ToolMessage)]
        if not fresh:
            # Fallback: if the provider returned the whole list, take the tail.
            if len(out["messages"]) > len(state["messages"]):
                fresh = out["messages"][len(state["messages"]):]
            else:
                fresh = out["messages"]

        state["messages"] = state["messages"] + fresh

        # Parse the most recent tool payload (a dict in .content, when structured).
        payload = fresh[-1].content if fresh else None
        if isinstance(payload, dict):
            kind = payload.get("type")
            if kind == "dataset_summary":
                set_df_summary(payload)
                state["last_task"] = payload.get("task_guess")
            elif kind == "sota":
                set_sota_bundled(payload.get("bundled_results") or [])
            elif kind == "plan":
                state["plan"] = payload
            elif kind == "step_result":
                # A mutating step ran: commit the new dataset version and
                # clear the confirmation so the next run needs a fresh pick.
                state["df_payload"] = payload["df"]
                set_df_payload(state["df_payload"])
                state["results"].append({"name": payload["name"], "stats": payload["stats"]})
                state["steps_taken"] += 1
                state["confirmed_step"] = None
                state["confirmed_params"] = {}

        return state

    return _node
198
+
199
def should_continue(state: AgentState) -> str:
    """Route after the agent node: 'continue' to tools while the newest
    message carries tool calls and the step budget is not exhausted."""
    newest = state["messages"][-1]
    if state.get("steps_taken", 0) >= state.get("max_steps", 8):
        return "end"
    # Continue only if the last assistant message requested tool calls.
    return "continue" if getattr(newest, "tool_calls", None) else "end"
205
+
206
def build_app(llm):
    """Compile the two-node graph: agent ⇄ tools, looping until the agent
    stops emitting tool calls (or the step budget runs out)."""
    graph = StateGraph(AgentState)
    graph.add_node("agent", make_agent_node(llm))
    graph.add_node("tools", tools_exec_node())

    graph.add_edge(START, "agent")
    graph.add_conditional_edges("agent", should_continue, {"continue": "tools", "end": END})
    graph.add_edge("tools", "agent")

    return graph.compile()
agent/runtime_ctx.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent/runtime_ctx.py
2
+ from __future__ import annotations
3
+
4
+ from contextvars import ContextVar
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ # ---------------- Versioned dataset store ----------------
8
+ # Each new mutating step can create a new "version".
9
+ # You can address versions as: "current", "base", "prev", "@-1", "@-2", "@3" (0-based).
10
+ _VERSIONS_CV: ContextVar[Optional[List[Dict[str, Any]]]] = ContextVar("VERSIONS", default=None)
11
+ _CUR_INDEX_CV: ContextVar[int] = ContextVar("CUR_INDEX", default=-1)
12
+ _VERS_META_CV: ContextVar[Optional[List[Dict[str, Any]]]] = ContextVar("VERS_META", default=None)
13
+
14
+ # Legacy singletons (fallback across tasks)
15
+ _STORE: Dict[str, Any] = {
16
+ "versions": [], # list of df_payloads
17
+ "version_meta": [], # parallel list of metadata dicts
18
+ "cur_index": -1,
19
+ # kept for backward compat with old getters:
20
+ "df_payload": None, # alias of current
21
+ "base_df_payload": None, # alias of versions[0]
22
+ "sota_bundled": None,
23
+ "df_summary": None,
24
+ }
25
+
26
+ _SOTA_BUNDLED_CV: ContextVar[Optional[list]] = ContextVar("SOTA_BUNDLED", default=None)
27
+ _DF_SUMMARY_CV: ContextVar[Optional[Dict[str, Any]]] = ContextVar("DF_SUMMARY", default=None)
28
+
29
+
30
+ # -------- internal helpers --------
31
+ def _get_versions() -> List[Dict[str, Any]]:
32
+ return _VERSIONS_CV.get() or _STORE["versions"]
33
+
34
+ def _get_meta() -> List[Dict[str, Any]]:
35
+ return _VERS_META_CV.get() or _STORE["version_meta"]
36
+
37
+ def _set_versions(vers: List[Dict[str, Any]], meta: List[Dict[str, Any]], cur: int) -> None:
38
+ _VERSIONS_CV.set(vers)
39
+ _VERS_META_CV.set(meta)
40
+ _CUR_INDEX_CV.set(cur)
41
+ _STORE["versions"] = vers
42
+ _STORE["version_meta"] = meta
43
+ _STORE["cur_index"] = cur
44
+ # keep legacy aliases in sync
45
+ _STORE["df_payload"] = vers[cur] if (0 <= cur < len(vers)) else None
46
+ _STORE["base_df_payload"] = vers[0] if vers else None
47
+
48
+
49
+ # =========================
50
+ # Init / Set / Annotate
51
+ # =========================
52
+ def init_dataset(p: Optional[Dict[str, Any]]) -> None:
53
+ """Initialize version stack with a single BASE version."""
54
+ vers = [] if p is None else [p]
55
+ meta = [] if p is None else [dict(tag="base")]
56
+ cur = -1 if p is None else 0
57
+ _set_versions(vers, meta, cur)
58
+
59
+ def set_df_payload(p: Optional[Dict[str, Any]], *, new_version: bool = True) -> None:
60
+ """
61
+ Set CURRENT dataset.
62
+ - new_version=True: truncate any forward history and append p (like a new commit).
63
+ - new_version=False: replace the current version in place (no new snapshot).
64
+ """
65
+ vers = list(_get_versions())
66
+ meta = list(_get_meta())
67
+ cur = _CUR_INDEX_CV.get() if _CUR_INDEX_CV.get() is not None else _STORE["cur_index"]
68
+
69
+ if cur < 0 or not vers:
70
+ # not initialized yet
71
+ init_dataset(p)
72
+ return
73
+
74
+ if new_version:
75
+ # drop any versions after current (no branching for simplicity)
76
+ vers = vers[:cur + 1]
77
+ meta = meta[:cur + 1]
78
+ vers.append(p)
79
+ meta.append({})
80
+ cur = len(vers) - 1
81
+ else:
82
+ vers[cur] = p
83
+
84
+ _set_versions(vers, meta, cur)
85
+
86
+ def annotate_current(**kv) -> None:
87
+ """Attach metadata to the current version (e.g., step/params/stats)."""
88
+ vers = list(_get_versions())
89
+ meta = list(_get_meta())
90
+ cur = _CUR_INDEX_CV.get() if _CUR_INDEX_CV.get() is not None else _STORE["cur_index"]
91
+ if 0 <= cur < len(meta):
92
+ meta[cur] = {**meta[cur], **kv}
93
+ _set_versions(vers, meta, cur)
94
+
95
+
96
+ # =========================
97
+ # Getters / Navigation
98
+ # =========================
99
+ def _resolve_index(spec: Union[str, int, None]) -> int:
100
+ vers = _get_versions()
101
+ cur = _CUR_INDEX_CV.get() if _CUR_INDEX_CV.get() is not None else _STORE["cur_index"]
102
+ if spec is None or spec == "current":
103
+ return cur
104
+ if spec == "base":
105
+ return 0 if vers else -1
106
+ if spec == "prev":
107
+ return max(-1, cur - 1)
108
+ if isinstance(spec, int):
109
+ idx = spec if spec >= 0 else len(vers) + spec
110
+ return idx
111
+ # strings like "@-1", "@3"
112
+ if isinstance(spec, str) and spec.startswith("@"):
113
+ try:
114
+ n = int(spec[1:])
115
+ except Exception:
116
+ return cur
117
+ idx = n if n >= 0 else len(vers) + n
118
+ return idx
119
+ return cur
120
+
121
+ def get_df_payload(version: Union[str, int, None] = None) -> Optional[Dict[str, Any]]:
122
+ """Return dataset payload for the requested version (default: current)."""
123
+ vers = _get_versions()
124
+ idx = _resolve_index(version)
125
+ if 0 <= idx < len(vers):
126
+ return vers[idx]
127
+ # legacy fallback
128
+ return _STORE["df_payload"]
129
+
130
+ def get_base_df_payload() -> Optional[Dict[str, Any]]:
131
+ return get_df_payload("base")
132
+
133
+ def get_prev_df_payload() -> Optional[Dict[str, Any]]:
134
+ return get_df_payload("prev")
135
+
136
+ def list_versions() -> Dict[str, Any]:
137
+ """Lightweight overview for debugging/UI."""
138
+ vers = _get_versions()
139
+ meta = _get_meta()
140
+ cur = _CUR_INDEX_CV.get() if _CUR_INDEX_CV.get() is not None else _STORE["cur_index"]
141
+ return {
142
+ "count": len(vers),
143
+ "current_index": cur,
144
+ "has_base": bool(vers),
145
+ "meta": meta, # [{tag:..., step:..., params:...}, ...]
146
+ }
147
+
148
+ def reset_current_to(version: Union[str, int]) -> None:
149
+ """Move the current pointer to a prior version (no deletion)."""
150
+ vers = _get_versions()
151
+ meta = _get_meta()
152
+ idx = _resolve_index(version)
153
+ if 0 <= idx < len(vers):
154
+ _set_versions(vers, meta, idx)
155
+
156
+ def reset_current_to_base() -> None:
157
+ reset_current_to("base")
158
+
159
+
160
+ # =========================
161
+ # SOTA / Summary passthrough
162
+ # =========================
163
+ def set_sota_bundled(b: Optional[list]) -> None:
164
+ _SOTA_BUNDLED_CV.set(b)
165
+ _STORE["sota_bundled"] = b
166
+
167
+ def get_sota_bundled() -> Optional[list]:
168
+ return _SOTA_BUNDLED_CV.get() or _STORE["sota_bundled"]
169
+
170
+ def set_df_summary(s: Optional[Dict[str, Any]]) -> None:
171
+ _DF_SUMMARY_CV.set(s)
172
+ _STORE["df_summary"] = s
173
+
174
+ def get_df_summary() -> Optional[Dict[str, Any]]:
175
+ return _DF_SUMMARY_CV.get() or _STORE["df_summary"]
agent/simple_chat.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional
3
+ from langchain_google_genai import ChatGoogleGenerativeAI
4
+ from langchain_core.messages import HumanMessage
5
+ from dotenv import load_dotenv
6
+
7
+ load_dotenv()
8
+
9
+
10
def simple_chat(
    question: str,
    context: str,
    image_paths: Optional[List[str]] = None
) -> str:
    """
    Simple chat function that answers questions based on context and optional images.

    Args:
        question: User's question
        context: Context information (e.g., dataset summary, analysis results)
        image_paths: Optional list of image file paths to include

    Returns:
        AI response as a string, or an "Error: ..." string on any failure
        (deliberate best-effort contract: this function never raises).
    """
    try:
        # Initialize LLM
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-exp",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
        )

        # Build the prompt
        prompt = f"""You are a helpful data analysis assistant.

Context:
{context}

User Question: {question}

Please provide a clear, concise answer based on the context provided."""

        # Handle images if provided
        if image_paths:
            # Hoisted out of the loop: one import and one MIME table for
            # all images instead of re-importing/rebuilding per image.
            import base64

            mime_by_ext = {
                '.png': 'image/png',
                '.jpg': 'image/jpeg',
                '.jpeg': 'image/jpeg',
                '.gif': 'image/gif',
                '.webp': 'image/webp'
            }

            content = [{"type": "text", "text": prompt}]
            for img_path in image_paths:
                if os.path.exists(img_path):
                    with open(img_path, "rb") as f:
                        img_data = base64.b64encode(f.read()).decode()

                    # Determine image type from the file extension
                    ext = os.path.splitext(img_path)[1].lower()
                    mime_type = mime_by_ext.get(ext, 'image/png')

                    content.append({
                        "type": "image_url",
                        "image_url": f"data:{mime_type};base64,{img_data}"
                    })

            message = HumanMessage(content=content)
        else:
            message = HumanMessage(content=prompt)

        # Get response
        response = llm.invoke([message])

        # Extract text from response
        if hasattr(response, 'content'):
            return str(response.content)
        return str(response)

    except Exception as e:
        # Deliberate best-effort: surface any failure as a string, never raise.
        return f"Error: {str(e)}"


# Example usage:
if __name__ == "__main__":
    # Simple text-only example
    context = """
    Dataset: Customer Sales Data
    - 1000 rows, 15 columns
    - Label: purchase_made (binary)
    - Task: Classification
    - Missing values: 5% in age column
    """

    question = "What's the main task for this dataset?"
    response = simple_chat(question, context)
    print(response)

    # With images
    question2 = "What do you see in the visualization?"
    response2 = simple_chat(question2, context, image_paths=["/path/to/plot.png"])
    print(response2)
agent/tools.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ from langchain.tools import tool
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+
9
+ from pipeline.deduplication import find_near_duplicates
10
+ from pipeline.featurizer import custom_featurizer
11
+ from pipeline.issues import find_issues
12
+ from pipeline.utils_cool import (
13
+ df_from_payload,
14
+ df_to_payload,
15
+ get_signature_dict,
16
+ guess_task_and_label,
17
+ )
18
+
19
+ from .runtime_ctx import (
20
+ get_df_payload, # now supports version spec (None|'current'|'prev'|'base'|'@-1'|int)
21
+ get_df_summary,
22
+ get_sota_bundled,
23
+ set_df_payload, # commit new dataset version (or replace)
24
+ set_df_summary,
25
+ set_sota_bundled,
26
+ )
27
+ from .runtime_ctx import (
28
+ list_versions as _list_versions_state,
29
+ )
30
+ from .runtime_ctx import (
31
+ reset_current_to as _reset_current_to,
32
+ )
33
+
34
# Registry of runnable steps (names used by the agent/UI).
# Maps each public step name to the function that implements it.
STEP_FUNCS = dict(
    dedup=find_near_duplicates,
    featurize=custom_featurizer,
    find_label_issues=find_issues,
)
40
+
41
+
42
@tool("inspect_dataset", return_direct=True)
def tool_inspect_dataset() -> Dict[str, Any]:
    """
    Summarize the CURRENT dataset (no arguments required).

    Behavior:
      • Reads the dataset from the runtime context (set by the graph).
      • Returns a compact summary of columns, dtypes, shape, and a guessed label/task.

    Returns:
      {
        "type": "dataset_summary",
        "columns": [...],
        "dtypes": {col: dtype, ...},
        "shape": (rows, cols),
        "label_guess": "<name or None>",
        "task_guess": "classification|regression|unsupervised",
        "issues": [ ... ]   # e.g., missing labels, single-class, etc.
      }
    """
    # Pull the current dataset version from the runtime context.
    payload = get_df_payload()  # default: current version
    if payload is None:
        raise RuntimeError("inspect_dataset: no dataset available in runtime context.")

    frame = df_from_payload(payload)
    summary = guess_task_and_label(frame)

    # Persist the summary so downstream tools see a fresh view.
    set_df_summary(summary)

    result: Dict[str, Any] = {"type": "dataset_summary"}
    result.update(summary)
    return result
70
+
71
+
72
@tool("sota_preprocessing", return_direct=True)
def tool_sota_preprocessing(
    task: Optional[str] = None,
    modality: Optional[str] = None,
    domain: Optional[str] = None,
    target: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Search state-of-the-art preprocessing best practices (modality-aware).

    Args:
      task: e.g., "classification", "regression", "segmentation", "NER", "ASR", "forecasting".
            If omitted, inferred from the dataset summary if available.
      modality: one of {"tabular","text","image","audio","video","time_series","graph","multimodal"}.
      domain: optional domain context (e.g., "clinical", "finance").
      target: optional target structure (e.g., "segmentation masks", "bounding boxes").

    Returns:
      {
        "type": "sota",
        "task": ...,
        "modality": ...,
        "domain": ...,
        "target": ...,
        "queries": [...],
        "bundled_results": [{ "query": q, "results": <tavily-results> }, ...],
        "results": <first tavily-results batch>
      }
    """
    # Fall back to the dataset summary's guessed task when none is given.
    df_summary = get_df_summary() or {}
    if not task:
        task = df_summary.get("task_guess") or "classification"

    yr = "2024 2025"  # bias search toward recent material
    m = (modality or "").lower().strip()

    # Modality-specific vocabulary appended to one of the queries below.
    modality_terms = {
        "tabular": ["imputation", "encoding", "scaling", "outliers", "leakage prevention"],
        "text": ["tokenization", "normalization", "subword", "BPE", "SentencePiece", "stopwords", "lemmatization", "augmentation"],
        "image": ["normalization", "resizing", "color space", "augmentation", "RandAugment", "MixUp", "CutMix"],
        "audio": ["resampling", "log-mel spectrogram", "MFCC", "pre-emphasis", "SpecAugment", "denoising"],
        "time_series": ["resampling", "windowing", "detrending", "imputation", "outlier detection", "scaling"],
        "video": ["frame sampling", "temporal augmentation", "clip normalization", "optical flow"],
        "graph": ["feature normalization", "self-loops", "adjacency normalization", "sparsification"],
        "multimodal": ["alignment", "synchronization", "fusion", "tokenization"],
    }
    m_terms = modality_terms.get(m, [])

    # Build candidate queries (general first, then increasingly specific).
    queries: List[str] = []
    queries.append(f"state of the art preprocessing {task} {yr}")
    queries.append(f"best practices data preprocessing {task} {yr}")
    if m:
        queries.append(f"{m} {task} preprocessing best practices {yr}")
    if domain:
        queries.append(f"{domain} {m or ''} {task} preprocessing best practices {yr}".strip())
    if target:
        queries.append(f"{m or ''} {task} {target} preprocessing pipeline {yr}".strip())
    if m_terms:
        queries.append(f"{m} {task} preprocessing {' '.join(m_terms)} {yr}")

    # Deduplicate while preserving order (explicit loop instead of the
    # `seen.add` side-effect-in-comprehension trick, which is easy to misread).
    seen: set = set()
    unique_queries: List[str] = []
    for q in (q.strip() for q in queries):
        if q and q not in seen:
            seen.add(q)
            unique_queries.append(q)
    queries = unique_queries

    # BUG FIX: TavilySearchResults takes `max_results`, not `k` — the original
    # `TavilySearchResults(k=6)` did not configure the result count.
    tavily = TavilySearchResults(max_results=6)
    bundled: List[Dict[str, Any]] = [{"query": q, "results": tavily.invoke({"query": q})} for q in queries]
    flat_first = bundled[0]["results"] if (bundled and "results" in bundled[0]) else []

    # Persist the raw search bundles so propose_plan can ground its plan.
    set_sota_bundled(bundled)

    return {
        "type": "sota",
        "task": task,
        "modality": m or "unknown",
        "domain": domain,
        "target": target,
        "queries": queries,
        "bundled_results": bundled,
        "results": flat_first,
    }
154
+
155
@tool("describe_step", return_direct=True)
def tool_describe_step(name: str) -> Dict[str, Any]:
    """
    Return the exact docstring + parameter schema for a single step by name.
    This prevents the model from inventing params.
    """
    # Validate against the registry before introspecting.
    if name not in STEP_FUNCS:
        raise ValueError(f"Unknown step '{name}'. Available: {list(STEP_FUNCS)}")

    step_fn = STEP_FUNCS[name]
    # get_signature_dict introspects defaults/annotations for the function.
    description = get_signature_dict(step_fn)

    payload: Dict[str, Any] = {"type": "step_description", "name": name}
    payload.update(description)
    return payload
166
+
167
@tool("list_steps", return_direct=True)
def tool_list_steps() -> Dict[str, Any]:
    """
    List available pipeline steps (name, docstring, and signature).

    Returns:
      {
        "type": "steps",
        "steps": [
          {
            "name": "dedup" | "featurize" | "find_label_issues",
            "doc": "<docstring>",
            "params": [{"name": "...", "default": ..., "annotation": "...", "kind": "..."}]
          }, ...
        ]
      }
    """
    # Build one descriptor per registered step, preserving registry order.
    step_descriptions: List[Dict[str, Any]] = []
    for step_name, step_fn in STEP_FUNCS.items():
        entry: Dict[str, Any] = {"name": step_name}
        entry.update(get_signature_dict(step_fn))
        step_descriptions.append(entry)

    return {"type": "steps", "steps": step_descriptions}
188
+
189
+
190
@tool("propose_plan", return_direct=True)
def tool_propose_plan(
    task: Optional[str] = None,
    modality: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Propose an ordered preprocessing plan grounded in SOTA + dataset summary.
    (Planning only — does not execute steps.)
    """
    df_summary = get_df_summary() or {}
    bundled = get_sota_bundled() or []

    if not task:
        task = df_summary.get("task_guess") or "classification"
    label_guess = df_summary.get("label_guess")

    # Keyword vocabulary used to score search results per step.
    KEYWORDS = {
        "dedup": {"duplicate", "near-duplicate", "near duplicate", "dupe", "dedup", "similarity", "knn", "kNN"},
        "featurize": {
            "impute", "imputation", "encoding", "one-hot", "scale", "scaling",
            "normalize", "normalization", "standardize", "tfidf", "tokenization",
            "lemmatization", "augmentation"
        },
        "find_label_issues": {"label noise", "noisy labels", "cleanlab", "confident learning", "label issues", "weak labels"},
    }

    def _score(text: str, keys: set[str]) -> int:
        # Count how many keywords occur in the (lowercased) text.
        t = (text or "").lower()
        # BUG FIX: the text is lowercased but the keys were not, so mixed-case
        # entries like "kNN" could never match. Lowercase the keys too; the set
        # comprehension also dedupes casing variants so a hit counts once.
        return sum(1 for k in {k.lower() for k in keys} if k in t)

    hits = {"dedup": 0, "featurize": 0, "find_label_issues": 0}
    evidence: Dict[str, List[Dict[str, str]]] = {"dedup": [], "featurize": [], "find_label_issues": []}

    # Score each search result against each step's keyword set, keeping up to
    # 5 evidence snippets per step for display.
    for pack in bundled:
        q = pack.get("query", "")
        for item in (pack.get("results") or []):
            title = item.get("title", "")
            content = item.get("content", "")
            url = item.get("url", "")
            for step, keys in KEYWORDS.items():
                s = _score(f"{q} {title} {content}", keys)
                if s > 0:
                    hits[step] += s
                    if len(evidence[step]) < 5:
                        evidence[step].append({"query": q, "title": title, "url": url})

    options: List[Dict[str, Any]] = []

    # Dedup first: either evidence-backed or a reasonable default for most modalities.
    if hits["dedup"] > 0 or modality in {None, "tabular", "text", "image", "time_series"}:
        options.append(
            {
                "reason": "SOTA emphasizes handling near-duplicates early" if hits["dedup"] else "Practical first step to prevent leakage/skew",
                "step": "dedup",
                "params": {"threshold": 0.95, "metric": "cosine"},
                "evidence": evidence["dedup"][:3],
            }
        )

    # Featurization is always proposed.
    options.append(
        {
            "reason": "SOTA emphasizes robust imputation/encoding/scaling" if hits["featurize"] else "Prepare features based on modality",
            "step": "featurize",
            "params": {"nan_strategy": "impute"},
            "evidence": evidence["featurize"][:3],
        }
    )

    # Label-issue detection only makes sense with a label (or strong evidence).
    if (task == "classification" and label_guess) or hits["find_label_issues"] > 0:
        options.append(
            {
                "reason": "SOTA recommends checking noisy labels" if hits["find_label_issues"] else "Check label quality before training",
                "step": "find_label_issues",
                "params": {"label": label_guess or "<CONFIRM>"},
                "evidence": evidence["find_label_issues"][:3],
            }
        )

    if not options:
        options = [{"reason": "Generic best practice", "step": "featurize", "params": {"nan_strategy": "impute"}, "evidence": []}]

    return {
        "type": "plan",
        "task": task,
        "modality": modality,
        "label_guess": label_guess,
        "options": options,
        "keyword_hits": hits,
    }
278
+
279
+
280
@tool("run_step", return_direct=True)
def tool_run_step(name: str, params_json: str = "") -> Dict[str, Any]:
    """
    Execute a single pipeline step on the CURRENT dataset (no df argument).
    Returns ONLY a compact summary; the updated df is stored in runtime context.
    """
    payload = get_df_payload()
    if payload is None:
        raise RuntimeError("run_step: no dataset available in runtime context.")
    if name not in STEP_FUNCS:
        raise ValueError(f"Unknown step '{name}'. Available: {list(STEP_FUNCS)}")

    # Decode optional JSON parameters; must be an object (kwargs dict).
    params = json.loads(params_json) if params_json else {}
    if not isinstance(params, dict):
        raise ValueError("params_json must decode to a JSON object")

    frame = df_from_payload(payload)
    result_frame, stats = STEP_FUNCS[name](df=frame, **params)
    # Steps may return None to signal "unchanged" — keep the input frame then.
    next_frame = frame if result_frame is None else result_frame

    # ✅ update runtime dataset, but DO NOT send it back in the tool message
    set_df_payload(df_to_payload(next_frame))

    # Build a tiny, safe summary for the model.
    shape_before = (len(frame), len(frame.columns))
    shape_after = (len(next_frame), len(next_frame.columns))

    # Only forward a whitelisted subset of step stats to keep the payload small.
    stat_keys = [
        "n_rows_before_dedup", "n_near_dupe_pairs", "n_groups",
        "n_rows_flagged_duplicates", "n_rows_after_dedup",
        "metric", "threshold", "k", "total_time_sec",
    ]
    compact_stats: Dict[str, Any] = {}
    for key in stat_keys:
        if key in stats:
            compact_stats[key] = stats.get(key)

    return {
        "type": "step_result",
        "name": name,
        "params_used": params,
        "shape_before": shape_before,
        "shape_after": shape_after,
        "stats": compact_stats,  # small dict only
        "note": "Dataset updated in runtime context; use list_versions/reset_to_version if needed."
    }
321
+
322
# ---------------------------
# Optional helpers for version control from chat/agent
# ---------------------------
@tool("list_versions", return_direct=True)
def tool_list_versions() -> Dict[str, Any]:
    """
    Return a lightweight view of the version stack:
      { count, current_index, has_base, meta: [{...}, ...] }
    """
    # Delegate to the runtime-context helper and tag the payload type.
    state = _list_versions_state()
    versions: Dict[str, Any] = {"type": "versions"}
    versions.update(state)
    return versions
332
+
333
+
334
@tool("reset_to_version", return_direct=True)
def tool_reset_to_version(spec: str) -> Dict[str, Any]:
    """
    Move CURRENT pointer to a prior version without deleting history.
    spec can be: "base" | "prev" | "@-1" | "@-2" | "3"
    """
    # Accept an absolute index ("3") or a symbolic spec ("base"/"prev"/"@-k").
    try:
        if spec.isdigit():
            _reset_current_to(int(spec))
        else:
            _reset_current_to(spec)
    except Exception as e:
        # FIX: chain the original exception so the underlying cause (bad index,
        # unknown spec, ...) is preserved in the traceback instead of being lost.
        raise RuntimeError(f"reset_to_version: {e}") from e

    # Return the (possibly moved) current dataset plus the version metadata.
    df_payload = get_df_payload()
    return {
        "type": "reset",
        "current": _list_versions_state(),
        "df": df_payload,
    }
app.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from typing import Optional, Tuple, List, Dict, Any
5
+ import io
6
+ import base64
7
+ from tqdm.auto import tqdm
8
+ from dataclasses import dataclass
9
+ import gradio as gr
10
+ import json
11
+
12
+ from pipeline.deduplication import find_near_duplicates
13
+ from pipeline.featurizer import custom_featurizer
14
+ from pipeline.issues import find_issues
15
+ from pipeline.pipeline import make_step, run_pipeline
16
+
17
+ from ecg_analyzer import ECGAnalyzer
18
+ from agent.simple_chat import simple_chat
19
+ # ============================================================================
20
+ # ANALYSIS TASK CONFIGURATION
21
+ # ============================================================================
22
+
23
@dataclass
class TaskConfig:
    """Configuration for each analysis task"""
    # Display name of the task (e.g. "Near-Duplicate Detection").
    name: str
    # Data family the task belongs to ("EHR Data" or "ECG Data").
    data_type: str
    # Whether the task exposes extra parameter widgets in the UI.
    requires_params: bool
    # Specs for the Gradio parameter widgets; each dict carries
    # "type", "label" and "elem_id" keys (see TaskRegistry.get_config).
    param_components: List[Dict[str, Any]]
    # Names of the result tabs this task populates
    # (subset of: "original", "processed", "summary", "visualization").
    output_tabs: List[str]
31
+
32
class TaskRegistry:
    """Registry mapping tasks to their configurations"""

    @staticmethod
    def get_config(data_type: str, task_name: str) -> Optional[TaskConfig]:
        """Get configuration for a specific task"""
        # Static lookup table: data type -> task name -> TaskConfig.
        # Returns None for unknown (data_type, task_name) combinations.
        configs = {
            "EHR Data": {
                "Near-Duplicate Detection": TaskConfig(
                    name="Near-Duplicate Detection",
                    data_type="EHR Data",
                    requires_params=True,
                    param_components=[
                        {"type": "dropdown", "label": "Label Column", "elem_id": "ndd_label"}
                    ],
                    output_tabs=["original", "processed", "summary"]
                ),
                "Find Mislabeled Data": TaskConfig(
                    name="Find Mislabeled Data",
                    data_type="EHR Data",
                    requires_params=True,
                    param_components=[
                        {"type": "dropdown", "label": "Label Column", "elem_id": "mislabel_label"}
                    ],
                    output_tabs=["original", "summary"]
                )
            },
            "ECG Data": {
                "ECG Visualization": TaskConfig(
                    name="ECG Visualization",
                    data_type="ECG Data",
                    requires_params=True,
                    param_components=[
                        {"type": "checkboxgroup", "label": "Select Leads", "elem_id": "ecg_leads"},
                        {"type": "checkboxgroup", "label": "Visualization Types", "elem_id": "ecg_viz_types"}
                    ],
                    output_tabs=["visualization", "summary"]
                ),
                "Statistical Summary": TaskConfig(
                    name="Statistical Summary",
                    data_type="ECG Data",
                    requires_params=True,
                    param_components=[
                        {"type": "checkboxgroup", "label": "Select Leads", "elem_id": "ecg_stats_leads"}
                    ],
                    output_tabs=["summary", "visualization"]
                )
            }
        }
        return configs.get(data_type, {}).get(task_name)

    @staticmethod
    def get_tasks_for_data_type(data_type: str) -> List[str]:
        """Get available tasks for a data type"""
        # Returns [] for unknown data types.
        tasks = {
            "EHR Data": ["Near-Duplicate Detection", "Find Mislabeled Data"],
            "ECG Data": ["ECG Visualization", "Statistical Summary"]
        }
        return tasks.get(data_type, [])
91
+
92
+
93
+ # ============================================================================
94
+ # ANALYSIS EXECUTION
95
+ # ============================================================================
96
+
97
class AnalysisExecutor:
    """Executes analysis tasks and returns results"""
    # Each execute_* method returns a (status_message, results_dict) pair.
    # The results dict keys line up with the output tabs declared in
    # TaskRegistry ("original", "processed", "summary", "visualization").
    # Errors never propagate: every method catches Exception and reports
    # it via the status string so the UI stays responsive.

    @staticmethod
    def execute_near_duplicate_detection(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Execute near-duplicate detection pipeline"""
        try:
            # Label column is required by the featurize/find_issues steps.
            if not label:
                return "⚠ Label column required", {"original": df, "processed": None, "summary": None}

            # One shared tqdm bar is threaded through all pipeline steps.
            bar = tqdm(total=100, leave=False, desc="Pipeline Progress")
            steps = [
                make_step(find_near_duplicates, name="dedup")(progress=bar),
                make_step(custom_featurizer, name="featurize")(
                    label=label, nan_strategy="impute", on_pipeline_error="drop", progress=bar
                ),
                make_step(find_issues, name="find_label_issues")(label=label, progress=bar),
            ]
            results_df, summary_list = run_pipeline(steps, df=df)
            bar.close()

            return "✓ Near-duplicate detection completed", {
                "original": df, "processed": results_df, "summary": summary_list
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {"original": df, "processed": None, "summary": None}

    @staticmethod
    def execute_find_mislabeled(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Execute mislabeled data detection"""
        # NOTE(review): this is currently a stub — "suspicious_samples" is
        # hard-coded to 0 and no detection is actually performed.
        try:
            if not label:
                return "⚠ Label column required", {"original": df, "summary": None}

            summary = {
                "task": "Find Mislabeled Data", "label_column": label, "total_samples": len(df),
                "suspicious_samples": 0, "message": "Mislabeled detection analysis completed"
            }
            return "✓ Mislabeled data analysis completed", {"original": df, "summary": summary}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"original": df, "summary": None}

    @staticmethod
    def execute_ecg_visualization(df: pd.DataFrame, leads: List[str] = None, viz_types: List[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Execute ECG visualization using ECGAnalyzer"""
        try:
            # Detect available leads
            available_leads = ECGAnalyzer.detect_leads(df)

            # Use provided leads or default to all available
            if not leads:
                leads = available_leads if available_leads else []

            if not leads:
                return "⚠ No ECG leads found in data", {"visualization": None, "summary": None}

            # Default visualization types
            if not viz_types:
                viz_types = ["Signal Waveform", "Histogram"]

            # Create visualizations (ECGAnalyzer returns HTML markup).
            viz_html = ECGAnalyzer.create_all_visualizations(df, leads, viz_types)

            # Generate statistics
            stats = ECGAnalyzer.generate_statistics(df, leads)

            summary = {
                "task": "ECG Visualization",
                "samples": len(df),
                "leads_analyzed": leads,
                "visualizations": viz_types,
                "statistics": stats
            }

            return "✓ ECG visualization created", {"visualization": viz_html, "summary": summary}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"visualization": None, "summary": None}

    @staticmethod
    def execute_statistical_summary(df: pd.DataFrame, leads: List[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Execute statistical summary using ECGAnalyzer"""
        try:
            # Detect available leads
            available_leads = ECGAnalyzer.detect_leads(df)

            # Use provided leads, else detected leads, else any numeric column
            # (this method is also used for non-ECG data).
            if not leads:
                leads = available_leads if available_leads else list(df.select_dtypes(include=[np.number]).columns)

            if not leads:
                return "⚠ No numeric columns found", {"summary": None, "visualization": None}

            # Generate statistics
            stats = ECGAnalyzer.generate_statistics(df, leads)

            # Create HTML table for statistics; rendered in the
            # "visualization" tab rather than the raw summary dict.
            html_rows = []
            html_rows.append("<table class='preview-table' style='margin: 20px auto; max-width: 900px;'>")
            html_rows.append("<thead><tr><th>Lead</th><th>Mean</th><th>Std</th><th>Min</th><th>Q25</th><th>Median</th><th>Q75</th><th>Max</th></tr></thead>")
            html_rows.append("<tbody>")

            # One table row per lead; assumes each per-lead stats dict carries
            # mean/std/min/q25/median/q75/max keys (produced by ECGAnalyzer).
            for lead, lead_stats in stats.items():
                html_rows.append(f"<tr>")
                html_rows.append(f"<td><strong>{lead}</strong></td>")
                html_rows.append(f"<td>{lead_stats['mean']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['std']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['min']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['q25']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['median']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['q75']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['max']:.4f}</td>")
                html_rows.append(f"</tr>")

            html_rows.append("</tbody></table>")
            summary_html = f"<div style='overflow-x:auto;'><h3 style='text-align:center;'>Statistical Summary</h3>{''.join(html_rows)}</div>"

            summary = {
                "task": "Statistical Summary",
                "rows": len(df),
                "leads_analyzed": leads,
                "statistics": stats
            }

            return "✓ Statistical summary generated", {"summary": summary, "visualization": summary_html}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"summary": None, "visualization": None}
223
+
224
+
225
+ # ============================================================================
226
+ # UI MANAGER - Handles all UI state and updates
227
+ # ============================================================================
228
+
229
class UIManager:
    """Manages UI state and dynamic updates"""
    # Event handlers below return tuples of gr.update() objects whose ORDER
    # and ARITY must match the component lists wired up in create_interface —
    # do not reorder entries without updating the Gradio event bindings.

    def __init__(self):
        # Most recently loaded DataFrame (None until a CSV is uploaded).
        self.current_df = None
        # Currently selected data family; UI defaults to EHR.
        self.current_data_type = "EHR Data"
        # Context dict fed to the chatbot (file, type, df, summary, visualization).
        self.chatbot_context = {}
        # Short-command -> canonical-name mapping for chat-driven control.
        # NOTE(review): command_map is defined but not consulted anywhere in
        # the visible code — presumably reserved for future command parsing.
        self.command_map = {
            "data_type": {"ehr": "EHR Data", "ecg": "ECG Data"},
            "task": {
                "deduplication": "Near-Duplicate Detection", "mislabeled": "Find Mislabeled Data",
                "visualize_ecg": "ECG Visualization", "stats": "Statistical Summary"
            }
        }

    def load_csv(self, file) -> Tuple[str, Optional[pd.DataFrame]]:
        """Load CSV file"""
        # Returns (status_message, DataFrame-or-None); also caches the frame
        # on self.current_df on success.
        if file is None: return "⚠ No file uploaded", None
        try:
            df = pd.read_csv(file.name)
            self.current_df = df
            return f"✓ Loaded {len(df)} rows, {len(df.columns)} columns", df
        except Exception as e:
            return f"✗ Error: {str(e)}", None

    def on_file_upload(self, file, data_type: str):
        """Handle file upload - returns updates for all components"""
        status, df = self.load_csv(file)
        if df is None:
            # Failed load: clear the preview, empty the task list, and hide
            # every parameter group and output tab.
            return (
                status, gr.update(value=None), gr.update(choices=[], value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                gr.update(choices=[]), gr.update(choices=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            )

        # Seed the chatbot context with the freshly loaded dataset.
        self.chatbot_context = {"file": file.name, "type": data_type, "df": df}
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)
        col_choices = list(df.columns)

        # Detect ECG leads if ECG data
        ecg_leads = ECGAnalyzer.detect_leads(df) if data_type == "ECG Data" else []
        viz_types = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

        # Preview is capped at 200 rows to keep the UI responsive.
        return (
            status, gr.update(value=df.head(200)), gr.update(choices=available_tasks, value=[], interactive=True),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(choices=col_choices, value=None), gr.update(choices=col_choices, value=None),
            gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
            gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform"]),
            gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
        )

    def on_data_type_change(self, data_type: str, file):
        """Handle data type change"""
        self.current_data_type = data_type
        if file and self.current_df is not None:
            # A dataset is already loaded: refresh context and lead choices.
            self.chatbot_context["type"] = data_type

            # Update ECG lead choices if switching to ECG data
            ecg_leads = ECGAnalyzer.detect_leads(self.current_df) if data_type == "ECG Data" else []
            viz_types = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

            return (
                gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
                gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform"]),
                f"Data type changed to: {data_type}"
            )

        # No dataset loaded yet: only swap the task list; leave the ECG
        # widgets untouched (bare gr.update() is a no-op).
        return (
            gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[]),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(), gr.update(), gr.update(), gr.update(),
            f"Data type changed to: {data_type}"
        )

    def on_tasks_change(self, selected_tasks: List[str]):
        """Handle task selection change - show/hide parameter groups"""
        # One visibility flag per parameter group, in fixed order:
        # near-dup, mislabeled, ECG viz, ECG stats.
        show_ndd = "Near-Duplicate Detection" in selected_tasks
        show_mislabel = "Find Mislabeled Data" in selected_tasks
        show_ecg_viz = "ECG Visualization" in selected_tasks
        # Statistical Summary params only apply to ECG data.
        show_ecg_stats = "Statistical Summary" in selected_tasks and self.current_data_type == "ECG Data"
        return (
            gr.update(visible=show_ndd),
            gr.update(visible=show_mislabel),
            gr.update(visible=show_ecg_viz),
            gr.update(visible=show_ecg_stats)
        )

    def process_analysis(self, file, data_type: str, selected_tasks: List[str],
                         ndd_label: str, mislabel_label: str,
                         ecg_viz_leads: List[str], ecg_viz_types: List[str],
                         ecg_stats_leads: List[str]):
        """Process analysis tasks based on UI inputs."""
        # Re-reads the CSV from disk (rather than self.current_df) so the
        # analysis always runs against the uploaded file's latest contents.
        status, df = self.load_csv(file)
        if df is None:
            return (status, None, None, None, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False))

        # Collect the per-task widget values into one params dict.
        params = {
            "ndd_label": ndd_label,
            "mislabel_label": mislabel_label,
            "ecg_viz_leads": ecg_viz_leads,
            "ecg_viz_types": ecg_viz_types,
            "ecg_stats_leads": ecg_stats_leads
        }
        return self._run_analysis(df, data_type, selected_tasks, params)

    def _run_analysis(self, df: pd.DataFrame, data_type: str, selected_tasks: List[str], params: Dict[str, Any]):
        """Centralized analysis executor, callable from UI or chatbot."""
        if not selected_tasks:
            return (
                "⚠ No tasks selected", df.head(200), None, None, None,
                gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
            )

        # Accumulators: which tabs to show, merged results, and per-task status.
        all_tabs = set(); all_results = {"original": df.head(200), "processed": None, "summary": [], "visualization": ""}
        status_messages = []; executor = AnalysisExecutor()

        for task_name in selected_tasks:
            config = TaskRegistry.get_config(data_type, task_name)
            if not config:
                status_messages.append(f"✗ Unknown task: {task_name}"); continue
            all_tabs.update(config.output_tabs)

            # Dispatch each known task to its executor with the matching params.
            if task_name == "Near-Duplicate Detection":
                status_msg, results = executor.execute_near_duplicate_detection(df, params.get("ndd_label"))
            elif task_name == "Find Mislabeled Data":
                status_msg, results = executor.execute_find_mislabeled(df, params.get("mislabel_label"))
            elif task_name == "ECG Visualization":
                status_msg, results = executor.execute_ecg_visualization(
                    df,
                    params.get("ecg_viz_leads"),
                    params.get("ecg_viz_types")
                )
            elif task_name == "Statistical Summary":
                if data_type == "ECG Data":
                    status_msg, results = executor.execute_statistical_summary(df, params.get("ecg_stats_leads"))
                else:
                    status_msg, results = executor.execute_statistical_summary(df)
            else:
                status_msg, results = "✗ Task not implemented", {}

            # Merge this task's outputs: last "processed" wins, visualization
            # HTML is concatenated, summaries are collected per task.
            status_messages.append(f"{task_name}: {status_msg}")
            if results.get("processed") is not None: all_results["processed"] = results["processed"]
            if results.get("visualization"): all_results["visualization"] += results["visualization"]
            if results.get("summary") is not None: all_results["summary"].append({"task": task_name, "data": results["summary"]})

        # Expose the latest results to the chatbot.
        self.chatbot_context["summary"] = all_results["summary"] or None
        self.chatbot_context["visualization"] = all_results["visualization"] or None

        return (
            "\n".join(status_messages), all_results["original"], all_results["processed"],
            all_results["summary"] or None, all_results["visualization"] or None,
            gr.update(visible="original" in all_tabs), gr.update(visible="processed" in all_tabs),
            gr.update(visible="summary" in all_tabs), gr.update(visible="visualization" in all_tabs)
        )

    def chatbot_respond(self, message: str, history: List):
        """Handle chatbot messages, parsing for commands or responding to queries."""
        history = history or []; df = self.chatbot_context.get("df")

        # Only the most recent task summary is forwarded as chat context.
        summary = json.dumps(self.chatbot_context.get("summary")[-1]) if self.chatbot_context.get("summary") else ""

        # visualization = self.chatbot_context.get("visualization")
        # Visualization context is currently disabled (HTML is too large/noisy
        # to feed to the LLM); kept as an empty string placeholder.
        visualization = ''

        # Debug tracing of the chat exchange.
        print("history:", history)
        print("# ============================================================================\n ")
        print("message:", message)
        print("# ============================================================================\n ")
        print("summary:", summary)
        print("# ============================================================================\n ")
        print("visualization:", visualization)
        print("# ============================================================================\n ")

        # No-op updates so the chat handler can share the analysis outputs'
        # event signature without changing them.
        ui_updates = tuple([gr.update()] * 9)  # status, 4 outputs, 4 tabs

        command = message
        context = summary + visualization if summary or visualization else ""
        response = simple_chat(command, context)

        history.append((message, response))
        return (history, "") + ui_updates
416
+
417
+
418
+ # ============================================================================
419
+ # GRADIO INTERFACE
420
+ # ============================================================================
421
+
422
def create_interface():
    """Build the Gradio interface.

    Lays out three columns (task controls, tabbed results, chat assistant)
    and wires the UIManager callbacks to the components. Returns the
    gr.Blocks app, ready for .launch().
    """
    ui_manager = UIManager()
    # CSS pins the app to the viewport height and makes panels flex/scroll.
    custom_css = """
    * { box-sizing: border-box; } html, body { margin: 0; padding: 0; height: 100vh; overflow: hidden; }
    .gradio-container { height: 100vh !important; max-width: 100% !important; padding: 0 !important; }
    #app-container { height: 100vh; display: flex; flex-direction: column; padding: 0.75rem; gap: 0.75rem; }
    #main-row { flex: 1; min-height: 0; display: flex; gap: 0.75rem; }
    #left-panel { display: flex; flex-direction: column; height: 100%; background: #f9fafb; border-radius: 10px; padding: 0.75rem; gap: 0.5rem; }
    #task-section { flex: 1; min-height: 0; overflow-y: auto; display: flex; flex-direction: column; gap: 0.5rem; }
    #middle-panel, #chat-panel { display: flex; flex-direction: column; height: 100%; }
    #tabs-container { flex: 1; min-height: 0; display: flex; flex-direction: column; }
    #tabs-container .tabitem { flex: 1; min-height: 0; overflow: auto; }
    #chat-history { flex: 1; min-height: 0; overflow-y: auto; margin-bottom: 0.5rem; }
    #chat-input-row { flex-shrink: 0; display: flex; gap: 0.5rem; }
    .preview-table { border-collapse: collapse; width: 100%; font-size: 0.875rem; }
    .preview-table th { background-color: #3498db; color: white; padding: 8px; text-align: left; position: sticky; top: 0; }
    """

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Medical Data Analysis Platform") as demo:
        with gr.Column(elem_id="app-container"):
            gr.Markdown("# 🏥 Medical Data Analysis Platform")
            # --- Top bar: file upload + data-type selector -------------------
            with gr.Row():
                file_input = gr.File(label="Upload CSV", file_types=[".csv"], scale=2)
                data_type = gr.Dropdown(choices=["EHR Data", "ECG Data"], value="EHR Data", label="Data Type", scale=1)

            with gr.Row(elem_id="main-row"):
                # --- Left panel: task selection and per-task parameters ------
                with gr.Column(scale=2, elem_id="left-panel"):
                    with gr.Group(elem_id="task-section"):
                        gr.Markdown("#### Analysis Tasks")
                        task_selector = gr.CheckboxGroup(choices=TaskRegistry.get_tasks_for_data_type("EHR Data"), label=None)
                        # Parameter groups start hidden; on_tasks_change toggles them.
                        with gr.Group(visible=False) as ndd_param_group:
                            gr.Markdown("**Near-Duplicate Detection Parameters**")
                            ndd_label_dropdown = gr.Dropdown(choices=[], label="Label Column")
                        with gr.Group(visible=False) as mislabel_param_group:
                            gr.Markdown("**Find Mislabeled Data Parameters**")
                            mislabel_label_dropdown = gr.Dropdown(choices=[], label="Label Column")
                        with gr.Group(visible=False) as ecg_viz_param_group:
                            gr.Markdown("**ECG Visualization Parameters**")
                            ecg_viz_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                            ecg_viz_types = gr.CheckboxGroup(
                                choices=["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"],
                                label="Visualization Types",
                                value=["Signal Waveform", "Histogram"]
                            )
                        with gr.Group(visible=False) as ecg_stats_param_group:
                            gr.Markdown("**Statistical Summary Parameters**")
                            ecg_stats_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                    process_btn = gr.Button("▶ Process", variant="primary")
                    status_output = gr.Textbox(label="Status", interactive=False, lines=2)

                # --- Middle panel: tabbed result views (hidden until used) ---
                with gr.Column(scale=7, elem_id="middle-panel"):
                    with gr.Tabs(elem_id="tabs-container"):
                        with gr.TabItem("Original Data", visible=False) as tab_original:
                            original_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Processed Data", visible=False) as tab_processed:
                            processed_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Summary", visible=False) as tab_summary:
                            summary_output = gr.JSON()
                        with gr.TabItem("Visualization", visible=False) as tab_viz:
                            viz_output = gr.HTML()

                # --- Right panel: chat assistant -----------------------------
                with gr.Column(scale=3, elem_id="chat-panel"):
                    gr.Markdown("### 💬 AI Assistant")
                    chatbot = gr.Chatbot(elem_id="chat-history", height="100%")
                    with gr.Row(elem_id="chat-input-row"):
                        msg_input = gr.Textbox(placeholder="Ask or send a JSON command...", scale=4, container=False)
                        send_btn = gr.Button("Send", scale=1)

            # Shared output list: status, 4 result panes, 4 tab toggles — the
            # order must match the tuples returned by the UIManager callbacks.
            analysis_outputs = [
                status_output, original_df_output, processed_df_output, summary_output, viz_output,
                tab_original, tab_processed, tab_summary, tab_viz
            ]

            # NOTE(review): ecg_viz_types appears twice in this outputs list,
            # so the later of the two matching updates returned by
            # on_file_upload overwrites the earlier one — presumably a
            # separate stats viz-type component was intended. Confirm.
            file_input.change(
                fn=ui_manager.on_file_upload, inputs=[file_input, data_type],
                outputs=[status_output, original_df_output, task_selector,
                        ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group,
                        ndd_label_dropdown, mislabel_label_dropdown,
                        ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types,
                        tab_original, tab_processed, tab_summary, tab_viz]
            )
            # NOTE(review): same duplicated ecg_viz_types output here.
            data_type.change(
                fn=ui_manager.on_data_type_change, inputs=[data_type, file_input],
                outputs=[task_selector, ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group,
                        ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types, status_output]
            )
            task_selector.change(
                fn=ui_manager.on_tasks_change, inputs=[task_selector],
                outputs=[ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group]
            )
            process_btn.click(
                fn=ui_manager.process_analysis,
                inputs=[file_input, data_type, task_selector, ndd_label_dropdown, mislabel_label_dropdown,
                        ecg_viz_leads, ecg_viz_types, ecg_stats_leads],
                outputs=analysis_outputs
            )

            # Both the Send button and Enter in the textbox submit the chat.
            chat_submit_args = {"fn": ui_manager.chatbot_respond, "inputs": [msg_input, chatbot], "outputs": [chatbot, msg_input] + analysis_outputs}
            send_btn.click(**chat_submit_args)
            msg_input.submit(**chat_submit_args)

    return demo
525
+
526
+ # ============================================================================
527
+ # LAUNCH
528
+ # ============================================================================
529
+
530
if __name__ == "__main__":
    # Bind to all interfaces on port 7890 without a public share link.
    create_interface().launch(share=False, server_name="0.0.0.0", server_port=7890)
app_studio.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from typing import Optional, Tuple, List, Dict, Any
5
+ import io
6
+ import base64
7
+ from tqdm.auto import tqdm
8
+ from dataclasses import dataclass
9
+ import gradio as gr
10
+ import json
11
+
12
+ from pipeline.deduplication import find_near_duplicates
13
+ from pipeline.featurizer import custom_featurizer
14
+ from pipeline.issues import find_issues
15
+ from pipeline.pipeline import make_step, run_pipeline
16
+
17
+ from ecg_analyzer import ECGAnalyzer
18
+
19
+ # ============================================================================
20
+ # ANALYSIS TASK CONFIGURATION
21
+ # ============================================================================
22
+
23
@dataclass
class TaskConfig:
    """Configuration for each analysis task"""
    name: str                               # Display name, e.g. "ECG Visualization"
    data_type: str                          # Data category the task applies to ("EHR Data" / "ECG Data")
    requires_params: bool                   # Whether the task exposes user-set parameters in the UI
    param_components: List[Dict[str, Any]]  # Declarative specs of the parameter widgets (type/label/elem_id)
    output_tabs: List[str]                  # Result tabs to reveal: "original", "processed", "summary", "visualization"
31
+
32
+ class TaskRegistry:
33
+ """Registry mapping tasks to their configurations"""
34
+
35
+ @staticmethod
36
+ def get_config(data_type: str, task_name: str) -> Optional[TaskConfig]:
37
+ """Get configuration for a specific task"""
38
+ configs = {
39
+ "EHR Data": {
40
+ "Near-Duplicate Detection": TaskConfig(
41
+ name="Near-Duplicate Detection",
42
+ data_type="EHR Data",
43
+ requires_params=True,
44
+ param_components=[
45
+ {"type": "dropdown", "label": "Label Column", "elem_id": "ndd_label"}
46
+ ],
47
+ output_tabs=["original", "processed", "summary"]
48
+ ),
49
+ "Find Mislabeled Data": TaskConfig(
50
+ name="Find Mislabeled Data",
51
+ data_type="EHR Data",
52
+ requires_params=True,
53
+ param_components=[
54
+ {"type": "dropdown", "label": "Label Column", "elem_id": "mislabel_label"}
55
+ ],
56
+ output_tabs=["original", "summary"]
57
+ )
58
+ },
59
+ "ECG Data": {
60
+ "ECG Visualization": TaskConfig(
61
+ name="ECG Visualization",
62
+ data_type="ECG Data",
63
+ requires_params=True,
64
+ param_components=[
65
+ {"type": "checkboxgroup", "label": "Select Leads", "elem_id": "ecg_leads"},
66
+ {"type": "checkboxgroup", "label": "Visualization Types", "elem_id": "ecg_viz_types"}
67
+ ],
68
+ output_tabs=["visualization", "summary"]
69
+ ),
70
+ "Statistical Summary": TaskConfig(
71
+ name="Statistical Summary",
72
+ data_type="ECG Data",
73
+ requires_params=True,
74
+ param_components=[
75
+ {"type": "checkboxgroup", "label": "Select Leads", "elem_id": "ecg_stats_leads"}
76
+ ],
77
+ output_tabs=["summary", "visualization"]
78
+ )
79
+ }
80
+ }
81
+ return configs.get(data_type, {}).get(task_name)
82
+
83
+ @staticmethod
84
+ def get_tasks_for_data_type(data_type: str) -> List[str]:
85
+ """Get available tasks for a data type"""
86
+ tasks = {
87
+ "EHR Data": ["Near-Duplicate Detection", "Find Mislabeled Data"],
88
+ "ECG Data": ["ECG Visualization", "Statistical Summary"]
89
+ }
90
+ return tasks.get(data_type, [])
91
+
92
+
93
+ # ============================================================================
94
+ # ANALYSIS EXECUTION
95
+ # ============================================================================
96
+
97
class AnalysisExecutor:
    """Executes analysis tasks and returns results.

    Every method returns a (status_message, results_dict) pair. The dict
    keys ("original", "processed", "summary", "visualization") feed the
    matching UI tabs. Exceptions are caught and reported via the status
    string rather than raised.
    """

    @staticmethod
    def execute_near_duplicate_detection(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Execute near-duplicate detection pipeline"""
        try:
            if not label:
                return "⚠ Label column required", {"original": df, "processed": None, "summary": None}

            # One progress bar shared by all pipeline steps.
            bar = tqdm(total=100, leave=False, desc="Pipeline Progress")
            # Step order matters: deduplicate first, then featurize, then
            # run label-issue detection on the featurized frame.
            steps = [
                make_step(find_near_duplicates, name="dedup")(progress=bar),
                make_step(custom_featurizer, name="featurize")(
                    label=label, nan_strategy="impute", on_pipeline_error="drop", progress=bar
                ),
                make_step(find_issues, name="find_label_issues")(label=label, progress=bar),
            ]
            results_df, summary_list = run_pipeline(steps, df=df)
            bar.close()

            return "✓ Near-duplicate detection completed", {
                "original": df, "processed": results_df, "summary": summary_list
            }
        except Exception as e:
            # NOTE(review): the progress bar is not closed on this path —
            # consider a finally block. Confirm tqdm cleanup behavior.
            return f"✗ Error: {str(e)}", {"original": df, "processed": None, "summary": None}

    @staticmethod
    def execute_find_mislabeled(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Execute mislabeled data detection"""
        try:
            if not label:
                return "⚠ Label column required", {"original": df, "summary": None}

            # Placeholder implementation: no detection is actually run —
            # suspicious_samples is hard-coded to 0.
            summary = {
                "task": "Find Mislabeled Data", "label_column": label, "total_samples": len(df),
                "suspicious_samples": 0, "message": "Mislabeled detection analysis completed"
            }
            return "✓ Mislabeled data analysis completed", {"original": df, "summary": summary}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"original": df, "summary": None}

    @staticmethod
    def execute_ecg_visualization(df: pd.DataFrame, leads: List[str] = None, viz_types: List[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Execute ECG visualization using ECGAnalyzer.

        leads/viz_types default to all detected leads and a basic chart set
        when not supplied by the caller.
        """
        try:
            # Detect available leads
            available_leads = ECGAnalyzer.detect_leads(df)

            # Use provided leads or default to all available
            if not leads:
                leads = available_leads if available_leads else []

            if not leads:
                return "⚠ No ECG leads found in data", {"visualization": None, "summary": None}

            # Default visualization types
            if not viz_types:
                viz_types = ["Signal Waveform", "Histogram"]

            # Create visualizations
            viz_html = ECGAnalyzer.create_all_visualizations(df, leads, viz_types)

            # Generate statistics
            stats = ECGAnalyzer.generate_statistics(df, leads)

            summary = {
                "task": "ECG Visualization",
                "samples": len(df),
                "leads_analyzed": leads,
                "visualizations": viz_types,
                "statistics": stats
            }

            return "✓ ECG visualization created", {"visualization": viz_html, "summary": summary}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"visualization": None, "summary": None}

    @staticmethod
    def execute_statistical_summary(df: pd.DataFrame, leads: List[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Execute statistical summary using ECGAnalyzer.

        Falls back to all numeric columns when no ECG leads are detected,
        so this also works for plain EHR data. The HTML table is returned
        under "visualization"; the raw stats dict under "summary".
        """
        try:
            # Detect available leads
            available_leads = ECGAnalyzer.detect_leads(df)

            # Use provided leads or default to all available
            if not leads:
                leads = available_leads if available_leads else list(df.select_dtypes(include=[np.number]).columns)

            if not leads:
                return "⚠ No numeric columns found", {"summary": None, "visualization": None}

            # Generate statistics
            stats = ECGAnalyzer.generate_statistics(df, leads)

            # Create HTML table for statistics
            html_rows = []
            html_rows.append("<table class='preview-table' style='margin: 20px auto; max-width: 900px;'>")
            html_rows.append("<thead><tr><th>Lead</th><th>Mean</th><th>Std</th><th>Min</th><th>Q25</th><th>Median</th><th>Q75</th><th>Max</th></tr></thead>")
            html_rows.append("<tbody>")

            # One table row per lead; stats dict is assumed to expose the
            # keys used below (mean/std/min/q25/median/q75/max).
            for lead, lead_stats in stats.items():
                html_rows.append(f"<tr>")
                html_rows.append(f"<td><strong>{lead}</strong></td>")
                html_rows.append(f"<td>{lead_stats['mean']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['std']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['min']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['q25']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['median']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['q75']:.4f}</td>")
                html_rows.append(f"<td>{lead_stats['max']:.4f}</td>")
                html_rows.append(f"</tr>")

            html_rows.append("</tbody></table>")
            summary_html = f"<div style='overflow-x:auto;'><h3 style='text-align:center;'>Statistical Summary</h3>{''.join(html_rows)}</div>"

            summary = {
                "task": "Statistical Summary",
                "rows": len(df),
                "leads_analyzed": leads,
                "statistics": stats
            }

            return "✓ Statistical summary generated", {"summary": summary, "visualization": summary_html}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"summary": None, "visualization": None}
223
+
224
+
225
+ # ============================================================================
226
+ # UI MANAGER - Handles all UI state and updates
227
+ # ============================================================================
228
+
229
class UIManager:
    """Manages UI state and dynamic updates.

    Callback return tuples are positional and must match the `outputs`
    lists wired in create_interface() exactly, element for element.
    """

    def __init__(self):
        # Most recently loaded DataFrame (set by load_csv).
        self.current_df = None
        # Currently selected data category; mirrors the UI dropdown.
        self.current_data_type = "EHR Data"
        # Context shared with the chatbot: file path, data type, DataFrame.
        self.chatbot_context = {}
        # Maps short chatbot command tokens to canonical UI names.
        self.command_map = {
            "data_type": {"ehr": "EHR Data", "ecg": "ECG Data"},
            "task": {
                "deduplication": "Near-Duplicate Detection", "mislabeled": "Find Mislabeled Data",
                "visualize_ecg": "ECG Visualization", "stats": "Statistical Summary"
            }
        }

    def load_csv(self, file) -> Tuple[str, Optional[pd.DataFrame]]:
        """Load CSV file.

        `file` is a Gradio file object exposing the temp path via .name.
        Returns (status message, DataFrame-or-None) and caches the frame
        on self.current_df on success.
        """
        if file is None: return "⚠ No file uploaded", None
        try:
            df = pd.read_csv(file.name)
            self.current_df = df
            return f"✓ Loaded {len(df)} rows, {len(df.columns)} columns", df
        except Exception as e:
            return f"✗ Error: {str(e)}", None

    def on_file_upload(self, file, data_type: str):
        """Handle file upload - returns updates for all components.

        Returns a 17-tuple: status, original-data preview, task selector,
        4 parameter-group visibilities, 2 label dropdowns, 4 ECG
        lead/viz-type checkbox updates, then 4 tab visibilities.
        """
        status, df = self.load_csv(file)
        if df is None:
            # Reset every wired component when loading fails.
            return (
                status, gr.update(value=None), gr.update(choices=[], value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                gr.update(choices=[]), gr.update(choices=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            )

        self.chatbot_context = {"file": file.name, "type": data_type, "df": df}
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)
        col_choices = list(df.columns)

        # Detect ECG leads if ECG data
        ecg_leads = ECGAnalyzer.detect_leads(df) if data_type == "ECG Data" else []
        viz_types = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

        # Only the first 200 rows are previewed to keep the UI responsive.
        # NOTE(review): the wiring lists ecg_viz_types twice as an output,
        # so the last viz-type update below overwrites the earlier one —
        # confirm whether a separate stats viz-type component was intended.
        return (
            status, gr.update(value=df.head(200)), gr.update(choices=available_tasks, value=[], interactive=True),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(choices=col_choices, value=None), gr.update(choices=col_choices, value=None),
            gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
            gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform"]),
            gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
        )

    def on_data_type_change(self, data_type: str, file):
        """Handle data type change.

        Returns a 10-tuple: task selector, 4 parameter-group visibilities,
        4 ECG checkbox updates, status text. The first branch refreshes
        lead choices when a file is already loaded; otherwise components
        are left untouched via no-op gr.update().
        """
        self.current_data_type = data_type
        if file and self.current_df is not None:
            self.chatbot_context["type"] = data_type

            # Update ECG lead choices if switching to ECG data
            ecg_leads = ECGAnalyzer.detect_leads(self.current_df) if data_type == "ECG Data" else []
            viz_types = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

            return (
                gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[]),
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
                gr.update(choices=ecg_leads, value=ecg_leads), gr.update(choices=viz_types, value=["Signal Waveform"]),
                f"Data type changed to: {data_type}"
            )

        return (
            gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[]),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
            gr.update(), gr.update(), gr.update(), gr.update(),
            f"Data type changed to: {data_type}"
        )

    def on_tasks_change(self, selected_tasks: List[str]):
        """Handle task selection change - show/hide parameter groups"""
        show_ndd = "Near-Duplicate Detection" in selected_tasks
        show_mislabel = "Find Mislabeled Data" in selected_tasks
        show_ecg_viz = "ECG Visualization" in selected_tasks
        # Stats parameters only make sense for ECG data (lead selection).
        show_ecg_stats = "Statistical Summary" in selected_tasks and self.current_data_type == "ECG Data"
        return (
            gr.update(visible=show_ndd),
            gr.update(visible=show_mislabel),
            gr.update(visible=show_ecg_viz),
            gr.update(visible=show_ecg_stats)
        )

    def process_analysis(self, file, data_type: str, selected_tasks: List[str],
                        ndd_label: str, mislabel_label: str,
                        ecg_viz_leads: List[str], ecg_viz_types: List[str],
                        ecg_stats_leads: List[str]):
        """Process analysis tasks based on UI inputs.

        Re-reads the CSV (rather than trusting cached state), packs the
        widget values into a params dict, and delegates to _run_analysis.
        """
        status, df = self.load_csv(file)
        if df is None:
            # 9-tuple matching analysis_outputs, with all tabs hidden.
            return (status, None, None, None, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False))

        params = {
            "ndd_label": ndd_label,
            "mislabel_label": mislabel_label,
            "ecg_viz_leads": ecg_viz_leads,
            "ecg_viz_types": ecg_viz_types,
            "ecg_stats_leads": ecg_stats_leads
        }
        return self._run_analysis(df, data_type, selected_tasks, params)

    def _run_analysis(self, df: pd.DataFrame, data_type: str, selected_tasks: List[str], params: Dict[str, Any]):
        """Centralized analysis executor, callable from UI or chatbot.

        Returns the 9-tuple consumed by analysis_outputs: status text, the
        four result panes, then four tab-visibility updates driven by each
        task's declared output_tabs.
        """
        if not selected_tasks:
            return (
                "⚠ No tasks selected", df.head(200), None, None, None,
                gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
            )

        # Accumulators: tabs to reveal, merged results, per-task status lines.
        all_tabs = set(); all_results = {"original": df.head(200), "processed": None, "summary": [], "visualization": ""}
        status_messages = []; executor = AnalysisExecutor()

        for task_name in selected_tasks:
            config = TaskRegistry.get_config(data_type, task_name)
            if not config:
                status_messages.append(f"✗ Unknown task: {task_name}"); continue
            all_tabs.update(config.output_tabs)

            # Dispatch to the matching executor method.
            if task_name == "Near-Duplicate Detection":
                status_msg, results = executor.execute_near_duplicate_detection(df, params.get("ndd_label"))
            elif task_name == "Find Mislabeled Data":
                status_msg, results = executor.execute_find_mislabeled(df, params.get("mislabel_label"))
            elif task_name == "ECG Visualization":
                status_msg, results = executor.execute_ecg_visualization(
                    df,
                    params.get("ecg_viz_leads"),
                    params.get("ecg_viz_types")
                )
            elif task_name == "Statistical Summary":
                # ECG data restricts the summary to the selected leads.
                if data_type == "ECG Data":
                    status_msg, results = executor.execute_statistical_summary(df, params.get("ecg_stats_leads"))
                else:
                    status_msg, results = executor.execute_statistical_summary(df)
            else:
                status_msg, results = "✗ Task not implemented", {}

            status_messages.append(f"{task_name}: {status_msg}")
            # Merge rules: last processed frame wins, visualization HTML
            # concatenates, summaries accumulate one entry per task.
            if results.get("processed") is not None: all_results["processed"] = results["processed"]
            if results.get("visualization"): all_results["visualization"] += results["visualization"]
            if results.get("summary") is not None: all_results["summary"].append({"task": task_name, "data": results["summary"]})

        return (
            "\n".join(status_messages), all_results["original"], all_results["processed"],
            all_results["summary"] or None, all_results["visualization"] or None,
            gr.update(visible="original" in all_tabs), gr.update(visible="processed" in all_tabs),
            gr.update(visible="summary" in all_tabs), gr.update(visible="visualization" in all_tabs)
        )

    def chatbot_respond(self, message: str, history: List):
        """Handle chatbot messages, parsing for commands or responding to queries.

        JSON messages shaped like {"task": ..., "data_type": ...,
        "parameters": {...}} trigger an analysis run; anything else gets a
        keyword-matched canned reply. Always returns (history, cleared
        textbox) followed by the 9 analysis output updates.
        """
        history = history or []; df = self.chatbot_context.get("df")
        ui_updates = tuple([gr.update()] * 9)  # status, 4 outputs, 4 tabs

        try:
            command = json.loads(message)
            if isinstance(command, dict) and "task" in command:
                if df is None:
                    history.append((message, "Please upload a data file before running a command."))
                    return (history, "") + ui_updates

                # Resolve shorthand tokens; fall back to the current data
                # type when none is specified in the command.
                data_type = self.command_map["data_type"].get(command.get("data_type", "").lower(), self.current_data_type)
                task_name = self.command_map["task"].get(command.get("task", "").lower())

                if not task_name:
                    response = f"Unknown task: '{command['task']}'. Valid: {list(self.command_map['task'].keys())}"
                    history.append((message, response))
                    return (history, "") + ui_updates

                params = command.get("parameters", {})
                # NOTE(review): only the label parameter is forwarded here;
                # ECG lead / viz-type parameters cannot be set via chat yet.
                analysis_params = {"ndd_label": params.get("label"), "mislabel_label": params.get("label")}

                analysis_updates = self._run_analysis(df, data_type, [task_name], analysis_params)
                history.append((message, f"Command executed: Running '{task_name}'."))
                return (history, "") + analysis_updates
        except (json.JSONDecodeError, TypeError):
            pass  # Not a JSON command, proceed with standard logic

        # Plain-text queries: simple keyword matching.
        if "column" in message.lower():
            response = (f"Dataset has {len(df.columns)} columns: {', '.join(map(str, df.columns))}" if df is not None else "Please upload a file first.")
        elif "row" in message.lower():
            response = f"Dataset has {len(df)} rows." if df is not None else "Please upload a file first."
        elif "help" in message.lower():
            response = "Ask about 'columns' or 'rows'. To run a task, send JSON, e.g., `{\"task\": \"stats\"}` or `{\"task\": \"deduplication\", \"parameters\": {\"label\": \"your_column\"}}`"
        else:
            response = "I can help with data queries or run tasks via JSON commands. Try asking 'help'."

        history.append((message, response))
        return (history, "") + ui_updates
427
+
428
+
429
+ # ============================================================================
430
+ # GRADIO INTERFACE
431
+ # ============================================================================
432
+
433
def create_interface():
    """Build the Gradio interface.

    Lays out three columns (task controls, tabbed results, chat assistant)
    and wires the UIManager callbacks to the components. Returns the
    gr.Blocks app, ready for .launch().
    """
    ui_manager = UIManager()
    # CSS pins the app to the viewport height and makes panels flex/scroll.
    custom_css = """
    * { box-sizing: border-box; } html, body { margin: 0; padding: 0; height: 100vh; overflow: hidden; }
    .gradio-container { height: 100vh !important; max-width: 100% !important; padding: 0 !important; }
    #app-container { height: 100vh; display: flex; flex-direction: column; padding: 0.75rem; gap: 0.75rem; }
    #main-row { flex: 1; min-height: 0; display: flex; gap: 0.75rem; }
    #left-panel { display: flex; flex-direction: column; height: 100%; background: #f9fafb; border-radius: 10px; padding: 0.75rem; gap: 0.5rem; }
    #task-section { flex: 1; min-height: 0; overflow-y: auto; display: flex; flex-direction: column; gap: 0.5rem; }
    #middle-panel, #chat-panel { display: flex; flex-direction: column; height: 100%; }
    #tabs-container { flex: 1; min-height: 0; display: flex; flex-direction: column; }
    #tabs-container .tabitem { flex: 1; min-height: 0; overflow: auto; }
    #chat-history { flex: 1; min-height: 0; overflow-y: auto; margin-bottom: 0.5rem; }
    #chat-input-row { flex-shrink: 0; display: flex; gap: 0.5rem; }
    .preview-table { border-collapse: collapse; width: 100%; font-size: 0.875rem; }
    .preview-table th { background-color: #3498db; color: white; padding: 8px; text-align: left; position: sticky; top: 0; }
    """

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Medical Data Analysis Platform") as demo:
        with gr.Column(elem_id="app-container"):
            gr.Markdown("# 🏥 Medical Data Analysis Platform")
            # --- Top bar: file upload + data-type selector -------------------
            with gr.Row():
                file_input = gr.File(label="Upload CSV", file_types=[".csv"], scale=2)
                data_type = gr.Dropdown(choices=["EHR Data", "ECG Data"], value="EHR Data", label="Data Type", scale=1)

            with gr.Row(elem_id="main-row"):
                # --- Left panel: task selection and per-task parameters ------
                with gr.Column(scale=2, elem_id="left-panel"):
                    with gr.Group(elem_id="task-section"):
                        gr.Markdown("#### Analysis Tasks")
                        task_selector = gr.CheckboxGroup(choices=TaskRegistry.get_tasks_for_data_type("EHR Data"), label=None)
                        # Parameter groups start hidden; on_tasks_change toggles them.
                        with gr.Group(visible=False) as ndd_param_group:
                            gr.Markdown("**Near-Duplicate Detection Parameters**")
                            ndd_label_dropdown = gr.Dropdown(choices=[], label="Label Column")
                        with gr.Group(visible=False) as mislabel_param_group:
                            gr.Markdown("**Find Mislabeled Data Parameters**")
                            mislabel_label_dropdown = gr.Dropdown(choices=[], label="Label Column")
                        with gr.Group(visible=False) as ecg_viz_param_group:
                            gr.Markdown("**ECG Visualization Parameters**")
                            ecg_viz_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                            ecg_viz_types = gr.CheckboxGroup(
                                choices=["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"],
                                label="Visualization Types",
                                value=["Signal Waveform", "Histogram"]
                            )
                        with gr.Group(visible=False) as ecg_stats_param_group:
                            gr.Markdown("**Statistical Summary Parameters**")
                            ecg_stats_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                    process_btn = gr.Button("▶ Process", variant="primary")
                    status_output = gr.Textbox(label="Status", interactive=False, lines=2)

                # --- Middle panel: tabbed result views (hidden until used) ---
                with gr.Column(scale=7, elem_id="middle-panel"):
                    with gr.Tabs(elem_id="tabs-container"):
                        with gr.TabItem("Original Data", visible=False) as tab_original:
                            original_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Processed Data", visible=False) as tab_processed:
                            processed_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Summary", visible=False) as tab_summary:
                            summary_output = gr.JSON()
                        with gr.TabItem("Visualization", visible=False) as tab_viz:
                            viz_output = gr.HTML()

                # --- Right panel: chat assistant -----------------------------
                with gr.Column(scale=3, elem_id="chat-panel"):
                    gr.Markdown("### 💬 AI Assistant")
                    chatbot = gr.Chatbot(elem_id="chat-history", height="100%")
                    with gr.Row(elem_id="chat-input-row"):
                        msg_input = gr.Textbox(placeholder="Ask or send a JSON command...", scale=4, container=False)
                        send_btn = gr.Button("Send", scale=1)

            # Shared output list: status, 4 result panes, 4 tab toggles — the
            # order must match the tuples returned by the UIManager callbacks.
            analysis_outputs = [
                status_output, original_df_output, processed_df_output, summary_output, viz_output,
                tab_original, tab_processed, tab_summary, tab_viz
            ]

            # NOTE(review): ecg_viz_types appears twice in this outputs list,
            # so the later of the two matching updates returned by
            # on_file_upload overwrites the earlier one — presumably a
            # separate stats viz-type component was intended. Confirm.
            file_input.change(
                fn=ui_manager.on_file_upload, inputs=[file_input, data_type],
                outputs=[status_output, original_df_output, task_selector,
                        ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group,
                        ndd_label_dropdown, mislabel_label_dropdown,
                        ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types,
                        tab_original, tab_processed, tab_summary, tab_viz]
            )
            # NOTE(review): same duplicated ecg_viz_types output here.
            data_type.change(
                fn=ui_manager.on_data_type_change, inputs=[data_type, file_input],
                outputs=[task_selector, ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group,
                        ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types, status_output]
            )
            task_selector.change(
                fn=ui_manager.on_tasks_change, inputs=[task_selector],
                outputs=[ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group]
            )
            process_btn.click(
                fn=ui_manager.process_analysis,
                inputs=[file_input, data_type, task_selector, ndd_label_dropdown, mislabel_label_dropdown,
                        ecg_viz_leads, ecg_viz_types, ecg_stats_leads],
                outputs=analysis_outputs
            )

            # Both the Send button and Enter in the textbox submit the chat.
            chat_submit_args = {"fn": ui_manager.chatbot_respond, "inputs": [msg_input, chatbot], "outputs": [chatbot, msg_input] + analysis_outputs}
            send_btn.click(**chat_submit_args)
            msg_input.submit(**chat_submit_args)

    return demo
536
+
537
+ # ============================================================================
538
+ # LAUNCH
539
+ # ============================================================================
540
+
541
+ if __name__ == "__main__":
542
+ demo = create_interface()
543
+ demo.launch(share=False, server_name="0.0.0.0", server_port=7890)
app_test.py ADDED
@@ -0,0 +1,867 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ADDING NEW ANALYSIS TASKS:
3
+ ==========================
4
+
5
+ 1. Add task configuration in TaskRegistry.get_config():
6
+ "Your Data Type": {
7
+ "Your Task Name": TaskConfig(
8
+ name="Your Task Name",
9
+ data_type="Your Data Type",
10
+ requires_params=True, # Set to True if needs parameters
11
+ param_components=[...], # Define parameter types
12
+ output_tabs=["original", "summary", ...] # Which tabs to show
13
+ )
14
+ }
15
+
16
+ 2. Add execution method in AnalysisExecutor:
17
+ @staticmethod
18
+ def execute_your_task(df, param1, param2):
19
+ # Your analysis logic
20
+ return "✓ Done", {
21
+ "original": df,
22
+ "summary": summary_data,
23
+ ...
24
+ }
25
+
26
+ 3. In create_interface(), add parameter group in the LEFT panel:
27
+ with gr.Group(visible=False) as your_task_param_group:
28
+ gr.Markdown("**Your Task Parameters**")
29
+ your_param1 = gr.Slider(minimum=0, maximum=1, label="Threshold")
30
+ your_param2 = gr.Dropdown(choices=["A", "B"], label="Option")
31
+
32
+ 4. In UIManager.on_tasks_change(), add show/hide handling for your task's parameter group.
33
+ """
34
+
35
+ import pandas as pd
36
+ import matplotlib.pyplot as plt
37
+ import numpy as np
38
+ from typing import Optional, Tuple, List, Dict, Any
39
+ import io
40
+ import base64
41
+ from tqdm.auto import tqdm
42
+ from dataclasses import dataclass
43
+ import gradio as gr
44
+ # Direct imports (your existing modules)
45
+
46
+
47
+ from pipeline.deduplication import find_near_duplicates
48
+ from pipeline.featurizer import custom_featurizer
49
+ from pipeline.issues import find_issues
50
+ from pipeline.pipeline import make_step, run_pipeline
51
+
52
+
53
+
54
+ from pipeline import make_step, run_pipeline
55
+
56
+ # ============================================================================
57
+ # ANALYSIS TASK CONFIGURATION
58
+ # ============================================================================
59
+
60
@dataclass
class TaskConfig:
    """Configuration for each analysis task.

    Instances are produced by TaskRegistry.get_config() and consumed by
    UIManager.process_analysis(), which uses ``output_tabs`` to decide which
    result tabs to reveal after a run.
    """
    name: str  # Display/registry name of the task (also used as the dict key)
    data_type: str  # Data type the task applies to ("EHR Data" or "ECG Data")
    requires_params: bool  # Does it need additional parameters?
    param_components: List[Dict[str, Any]]  # List of parameter UI components (type/label/elem_id descriptors)
    output_tabs: List[str]  # Which tabs to show: "original", "processed", "summary", "visualization"
68
+
69
class TaskRegistry:
    """Registry mapping tasks to their configurations.

    The data-type -> task-name -> TaskConfig table is built lazily on first
    use and cached in ``_config_table`` (the original rebuilt the whole dict
    on every ``get_config`` call).
    """

    # Cached configuration table; populated by _configs() on first access.
    _config_table: Optional[Dict[str, Dict[str, "TaskConfig"]]] = None

    @classmethod
    def _configs(cls) -> Dict[str, Dict[str, "TaskConfig"]]:
        """Build (once) and return the full task-configuration table."""
        if cls._config_table is None:
            cls._config_table = {
                "EHR Data": {
                    "Near-Duplicate Detection": TaskConfig(
                        name="Near-Duplicate Detection",
                        data_type="EHR Data",
                        requires_params=True,
                        param_components=[
                            {"type": "dropdown", "label": "Label Column", "elem_id": "ndd_label"}
                        ],
                        output_tabs=["original", "processed", "summary"]
                    ),
                    "Find Mislabeled Data": TaskConfig(
                        name="Find Mislabeled Data",
                        data_type="EHR Data",
                        requires_params=True,
                        param_components=[
                            {"type": "dropdown", "label": "Label Column", "elem_id": "mislabel_label"}
                        ],
                        output_tabs=["original", "summary"]
                    )
                },
                "ECG Data": {
                    "ECG Visualization": TaskConfig(
                        name="ECG Visualization",
                        data_type="ECG Data",
                        requires_params=False,
                        param_components=[],
                        output_tabs=["visualization", "summary"]
                    ),
                    "Statistical Summary": TaskConfig(
                        name="Statistical Summary",
                        data_type="ECG Data",
                        requires_params=False,
                        param_components=[],
                        output_tabs=["summary"]
                    )
                }
            }
        return cls._config_table

    @staticmethod
    def get_config(data_type: str, task_name: str) -> Optional["TaskConfig"]:
        """Get configuration for a specific task.

        Returns None for unknown data types or task names; callers such as
        UIManager.process_analysis() rely on that (the original annotation
        claimed a non-optional TaskConfig, which was incorrect).
        """
        return TaskRegistry._configs().get(data_type, {}).get(task_name)

    @staticmethod
    def get_tasks_for_data_type(data_type: str) -> List[str]:
        """Get available task names for a data type ([] for unknown types)."""
        tasks = {
            "EHR Data": ["Near-Duplicate Detection", "Find Mislabeled Data"],
            "ECG Data": ["ECG Visualization", "Statistical Summary"]
        }
        return tasks.get(data_type, [])
123
+
124
+
125
+ # ============================================================================
126
+ # ANALYSIS EXECUTION
127
+ # ============================================================================
128
+
129
class AnalysisExecutor:
    """Executes analysis tasks and returns results.

    Every ``execute_*`` method returns a ``(status_message, results)`` pair,
    where ``results`` maps output-tab keys ("original", "processed",
    "summary", "visualization") to the data shown in that tab.  Failures are
    reported through the status string; these methods never raise.
    """

    @staticmethod
    def execute_near_duplicate_detection(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Run the dedup -> featurize -> label-issue pipeline on an EHR dataframe.

        Args:
            df: Input dataframe.
            label: Name of the label column; required by the pipeline steps.

        Returns:
            Status string plus a dict with "original", "processed" and
            "summary" entries ("processed"/"summary" are None on failure).
        """
        try:
            if not label:
                return "⚠ Label column required", {
                    "original": df,
                    "processed": None,
                    "summary": None
                }

            bar = tqdm(total=100, leave=False, desc="Pipeline Progress")
            try:
                steps = [
                    make_step(find_near_duplicates, name="dedup")(progress=bar),
                    make_step(custom_featurizer, name="featurize")(
                        label=label, nan_strategy="impute", on_pipeline_error="drop", progress=bar
                    ),
                    make_step(find_issues, name="find_label_issues")(label=label, progress=bar),
                ]
                results_df, summary_list = run_pipeline(steps, df=df)
            finally:
                # BUGFIX: close the bar even when the pipeline raises, so a
                # failed run does not leak a live tqdm instance (the original
                # only closed it on the success path).
                bar.close()

            return "✓ Near-duplicate detection completed", {
                "original": df,
                "processed": results_df,
                "summary": summary_list
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {
                "original": df,
                "processed": None,
                "summary": None
            }

    @staticmethod
    def execute_find_mislabeled(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Execute mislabeled data detection.

        Currently a placeholder: it validates the label argument and returns a
        fixed summary without inspecting the data.
        """
        try:
            if not label:
                return "⚠ Label column required", {
                    "original": df,
                    "summary": None
                }

            # Placeholder for actual mislabeled detection logic
            summary = {
                "task": "Find Mislabeled Data",
                "label_column": label,
                "total_samples": len(df),
                "suspicious_samples": 0,
                "message": "Mislabeled detection analysis completed"
            }

            return "✓ Mislabeled data analysis completed", {
                "original": df,
                "summary": summary
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {
                "original": df,
                "summary": None
            }

    @staticmethod
    def execute_ecg_visualization(df: pd.DataFrame) -> Tuple[str, Dict[str, Any]]:
        """Render the dataframe as an inline base64 <img> for a gr.HTML tab.

        The first column is treated as the x-axis and every remaining column
        as one signal channel; a single-column frame is plotted against its
        index.
        """
        try:
            fig, ax = plt.subplots(figsize=(12, 6))
            try:
                if len(df.columns) > 1:
                    # Column 0 is the time axis; each further column is a channel.
                    for col in df.columns[1:]:
                        ax.plot(df.iloc[:, 0], df[col], label=str(col), linewidth=0.8)
                    ax.set_xlabel('Time (ms)')
                    ax.set_ylabel('Amplitude (mV)')
                    ax.set_title('ECG Signal Visualization')
                    ax.legend()
                else:
                    ax.plot(df.iloc[:, 0], linewidth=0.8)
                    ax.set_xlabel('Sample Index')
                    ax.set_ylabel('Amplitude')
                    ax.set_title('ECG Signal')
                ax.grid(True, alpha=0.3)
                plt.tight_layout()

                # Convert plot to base64
                buf = io.BytesIO()
                fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
                buf.seek(0)
                img_base64 = base64.b64encode(buf.read()).decode('utf-8')
            finally:
                # BUGFIX: always release the figure; the original leaked it
                # whenever plotting or saving raised.
                plt.close(fig)

            viz_html = f'<div style="text-align:center;"><img src="data:image/png;base64,{img_base64}" style="max-width:100%;"/></div>'

            summary = {
                "task": "ECG Visualization",
                "samples": len(df),
                "channels": len(df.columns) - 1 if len(df.columns) > 1 else 1
            }

            return "✓ ECG visualization created", {
                "visualization": viz_html,
                "summary": summary
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {
                "visualization": None,
                "summary": None
            }

    @staticmethod
    def execute_statistical_summary(df: pd.DataFrame) -> Tuple[str, Dict[str, Any]]:
        """Build pandas describe() output as an HTML table plus a summary dict."""
        try:
            stats = df.describe().to_html(classes='preview-table')
            summary_html = f"<h3>Statistical Summary</h3><div style='overflow-x:auto;'>{stats}</div>"

            summary = {
                "task": "Statistical Summary",
                "rows": len(df),
                "columns": len(df.columns),
                "numeric_columns": len(df.select_dtypes(include=[np.number]).columns),
                "html": summary_html
            }

            return "✓ Statistical summary generated", {
                "summary": summary,
                "visualization": summary_html
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {
                "summary": None,
                "visualization": None
            }
264
+
265
+
266
+ # ============================================================================
267
+ # UI MANAGER - Handles all UI state and updates
268
+ # ============================================================================
269
+
270
class UIManager:
    """Manages UI state and dynamic updates.

    Every ``on_*``/``process_*`` method here is a Gradio event callback and
    returns a positional tuple that must match, element for element, the
    ``outputs`` list wired to it in ``create_interface`` — keep both sides in
    sync when editing.
    """

    def __init__(self):
        self.current_df = None  # last successfully loaded DataFrame (None until upload)
        self.current_data_type = "EHR Data"  # mirrors the data-type dropdown default
        self.chatbot_context = {}  # {"file", "type", "df"} once a file has been loaded

    def load_csv(self, file) -> Tuple[str, Optional[pd.DataFrame]]:
        """Load a CSV upload; returns (status message, DataFrame or None). Never raises."""
        if file is None:
            return "⚠ No file uploaded", None
        try:
            # `file` is a Gradio file wrapper; `.name` is the temp-file path.
            # NOTE(review): on Gradio versions that pass a plain path string,
            # `.name` would fail — confirm against the installed version.
            df = pd.read_csv(file.name)
            self.current_df = df
            return f"✓ Loaded {len(df)} rows, {len(df.columns)} columns", df
        except Exception as e:
            return f"✗ Error: {str(e)}", None

    def on_file_upload(self, file, data_type: str):
        """Handle file upload - returns updates for all components.

        Tuple order: status, original-data table, task checkbox group, the two
        parameter groups, the two label dropdowns, then the four result tabs.
        """
        status, df = self.load_csv(file)

        if df is None:
            # Failed load: clear the preview, empty the task list, hide params and tabs.
            return (
                status,  # status
                gr.update(value=None),  # original_df
                gr.update(choices=[], value=[], interactive=True),  # task_checkboxes
                gr.update(visible=False),  # ndd_param_group
                gr.update(visible=False),  # mislabel_param_group
                gr.update(choices=[]),  # ndd_label_dropdown
                gr.update(choices=[]),  # mislabel_label_dropdown
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),  # tabs
            )

        self.chatbot_context = {"file": file.name, "type": data_type, "df": df}
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)
        col_choices = list(df.columns)

        # Successful load: show a 200-row preview, offer this data type's tasks,
        # and seed both label dropdowns with the dataframe's columns.
        return (
            status,
            gr.update(value=df.head(200)),
            gr.update(choices=available_tasks, value=[], interactive=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(choices=col_choices, value=None),
            gr.update(choices=col_choices, value=None),
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
        )

    def on_data_type_change(self, data_type: str, file):
        """Handle data type change: reset task selection, hide parameter groups."""
        self.current_data_type = data_type
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)

        if file and self.current_df is not None:
            self.chatbot_context["type"] = data_type

        return (
            gr.update(choices=available_tasks, value=[]),  # Reset task selection
            gr.update(visible=False),  # Hide ndd params
            gr.update(visible=False),  # Hide mislabel params
            f"Data type changed to: {data_type}"
        )

    def on_tasks_change(self, selected_tasks: List[str], data_type: str, df_columns: List[str]):
        """Handle task selection change - show/hide parameter groups for all selected tasks.

        NOTE(review): `data_type` and `df_columns` are accepted but never read;
        the wiring in create_interface passes the ndd dropdown's *value* as
        `df_columns` — confirm intent before relying on either parameter.
        """
        if not selected_tasks:
            # Hide all parameter groups
            return (
                gr.update(visible=False),  # ndd_param_group
                gr.update(visible=False),  # mislabel_param_group
            )

        # Check which tasks need which parameters
        show_ndd_params = "Near-Duplicate Detection" in selected_tasks
        show_mislabel_params = "Find Mislabeled Data" in selected_tasks

        return (
            gr.update(visible=show_ndd_params),
            gr.update(visible=show_mislabel_params),
        )

    def process_analysis(self, file, data_type: str, selected_tasks: List[str], ndd_label: str, mislabel_label: str):
        """Process selected analysis tasks - handles multiple tasks.

        Re-reads the CSV from disk, runs every selected task in turn, merges
        their results, and returns the 9-tuple the Process button expects:
        status, original table, processed table, summary, visualization, then
        the four tab-visibility updates (union of all tasks' output_tabs).
        """
        status, df = self.load_csv(file)

        if df is None:
            return (
                status,
                None, None, None, None,
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
            )

        if not selected_tasks:
            return (
                "⚠ No tasks selected",
                df.head(200), None, None, None,
                gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
            )

        # Track which tabs to show (union of all selected tasks)
        all_tabs = set()
        all_results = {
            "original": df.head(200),
            "processed": None,
            "summary": [],       # list of {"task": name, "data": summary} entries
            "visualization": ""  # HTML fragments are concatenated across tasks
        }

        status_messages = []
        executor = AnalysisExecutor()

        # Process each selected task
        for task_name in selected_tasks:
            config = TaskRegistry.get_config(data_type, task_name)

            if not config:
                status_messages.append(f"✗ Unknown task: {task_name}")
                continue

            # Add this task's tabs to the set
            all_tabs.update(config.output_tabs)

            # Execute the task with appropriate parameters
            if task_name == "Near-Duplicate Detection":
                status_msg, results = executor.execute_near_duplicate_detection(df, ndd_label)
                status_messages.append(f"{task_name}: {status_msg}")
                if results.get("processed") is not None:
                    all_results["processed"] = results["processed"]
                if results.get("summary") is not None:
                    all_results["summary"].append({"task": task_name, "data": results["summary"]})

            elif task_name == "Find Mislabeled Data":
                status_msg, results = executor.execute_find_mislabeled(df, mislabel_label)
                status_messages.append(f"{task_name}: {status_msg}")
                if results.get("summary") is not None:
                    all_results["summary"].append({"task": task_name, "data": results["summary"]})

            elif task_name == "ECG Visualization":
                status_msg, results = executor.execute_ecg_visualization(df)
                status_messages.append(f"{task_name}: {status_msg}")
                if results.get("visualization"):
                    all_results["visualization"] += results["visualization"]
                if results.get("summary") is not None:
                    all_results["summary"].append({"task": task_name, "data": results["summary"]})

            elif task_name == "Statistical Summary":
                status_msg, results = executor.execute_statistical_summary(df)
                status_messages.append(f"{task_name}: {status_msg}")
                if results.get("visualization"):
                    all_results["visualization"] += results["visualization"]
                if results.get("summary") is not None:
                    all_results["summary"].append({"task": task_name, "data": results["summary"]})

        # Format final outputs based on all_tabs
        show_original = gr.update(visible="original" in all_tabs)
        show_processed = gr.update(visible="processed" in all_tabs)
        show_summary = gr.update(visible="summary" in all_tabs)
        show_viz = gr.update(visible="visualization" in all_tabs)

        # Combine status messages
        final_status = "\n".join(status_messages)

        return (
            final_status,
            all_results["original"],
            all_results["processed"],
            all_results["summary"] if all_results["summary"] else None,
            all_results["visualization"] if all_results["visualization"] else None,
            show_original,
            show_processed,
            show_summary,
            show_viz
        )

    def _format_results(self, status: str, output_tabs: List[str], results: Dict[str, Any]):
        """Format results based on output tabs configuration.

        NOTE(review): not called from process_analysis in this file —
        presumably retained for single-task formatting (the trailing how-to
        docstring references it); confirm before deleting.
        """
        # Default all outputs to None
        original_out = None
        processed_out = None
        summary_out = None
        viz_out = None

        # Default all tabs hidden
        show_original = gr.update(visible=False)
        show_processed = gr.update(visible=False)
        show_summary = gr.update(visible=False)
        show_viz = gr.update(visible=False)

        # Populate based on output_tabs
        if "original" in output_tabs:
            original_out = results.get("original")
            show_original = gr.update(visible=True)

        if "processed" in output_tabs:
            processed_out = results.get("processed")
            show_processed = gr.update(visible=True)

        if "summary" in output_tabs:
            summary_data = results.get("summary")
            if results.get("summary_html"):  # For statistical summary
                summary_out = results.get("summary_html")
            else:
                summary_out = summary_data
            show_summary = gr.update(visible=True)

        if "visualization" in output_tabs:
            viz_out = results.get("visualization")
            show_viz = gr.update(visible=True)

        return (
            status,
            original_out,
            processed_out,
            summary_out,
            viz_out,
            show_original,
            show_processed,
            show_summary,
            show_viz
        )

    def chatbot_respond(self, message: str, history: List):
        """Chatbot response.

        Keyword-matches on the message ("column"/"row"/"help") against the
        stored chatbot context; appends a (user, assistant) tuple in the
        legacy gr.Chatbot history format and clears the input box.
        """
        if history is None:
            history = []
        if not message or message.strip() == "":
            return history, ""

        df = self.chatbot_context.get("df")

        if "column" in message.lower():
            response = (f"Dataset has {len(df.columns)} columns: {', '.join(map(str, df.columns))}"
                        if df is not None else "Please upload a file first.")
        elif "row" in message.lower():
            response = f"Dataset has {len(df)} rows." if df is not None else "Please upload a file first."
        elif "help" in message.lower():
            response = "Ask about 'columns', 'rows', or your data analysis."
        else:
            response = f"You asked: '{message}'. Analyzing: {self.chatbot_context.get('type', 'No data')}."

        history.append((message, response))
        return history, ""
514
+
515
+
516
+ # ============================================================================
517
+ # GRADIO INTERFACE
518
+ # ============================================================================
519
+
520
def create_interface():
    """Build the Gradio interface.

    Layout: a header row (file upload + data type), then a three-column main
    row — task selection/parameters (left), result tabs (middle), chat
    assistant (right).  All callbacks are methods on a single UIManager
    instance; output tuples and ``outputs`` lists must stay aligned.
    """

    ui_manager = UIManager()

    # Flexbox CSS keeping the whole app inside one viewport (no page scroll);
    # only the task list, tab contents and chat history scroll internally.
    custom_css = """
    * { box-sizing: border-box; }
    html, body { margin: 0; padding: 0; height: 100vh; overflow: hidden; }
    .gradio-container { height: 100vh !important; max-width: 100% !important; padding: 0 !important; }

    #app-container {
        height: 100vh;
        display: flex;
        flex-direction: column;
        padding: 0.75rem;
        gap: 0.75rem;
    }

    #main-row {
        flex: 1;
        min-height: 0;
        display: flex;
        gap: 0.75rem;
    }

    #left-panel {
        display: flex;
        flex-direction: column;
        height: 100%;
        background: #f9fafb;
        border-radius: 10px;
        padding: 0.75rem;
        gap: 0.5rem;
    }

    #task-section {
        flex: 1;
        min-height: 0;
        overflow-y: auto;
        display: flex;
        flex-direction: column;
        gap: 0.5rem;
    }

    #action-section {
        flex-shrink: 0;
    }

    #middle-panel {
        display: flex;
        flex-direction: column;
        height: 100%;
        min-width: 0;
    }

    #tabs-container {
        flex: 1;
        min-height: 0;
        display: flex;
        flex-direction: column;
    }

    #tabs-container .tabs {
        height: 100%;
        display: flex;
        flex-direction: column;
    }

    #tabs-container .tab-nav {
        flex-shrink: 0;
    }

    #tabs-container .tabitem {
        flex: 1;
        min-height: 0;
        overflow: auto;
    }

    #chat-panel {
        display: flex;
        flex-direction: column;
        height: 100%;
        background: #f9fafb;
        border-radius: 10px;
        padding: 0.75rem;
    }

    #chat-header {
        flex-shrink: 0;
        margin-bottom: 0.5rem;
    }

    #chat-history {
        flex: 1;
        min-height: 0;
        overflow-y: auto;
        margin-bottom: 0.5rem;
    }

    #chat-input-row {
        flex-shrink: 0;
        display: flex;
        gap: 0.5rem;
        align-items: flex-end;
    }

    #chat-input-row .textbox {
        flex: 1;
    }

    #chat-input-row button {
        flex-shrink: 0;
    }

    .preview-table {
        border-collapse: collapse;
        width: 100%;
        font-size: 0.875rem;
    }
    .preview-table th {
        background-color: #3498db;
        color: white;
        padding: 8px;
        text-align: left;
        position: sticky;
        top: 0;
    }
    .preview-table td {
        padding: 6px;
        border-bottom: 1px solid #ddd;
    }

    .compact-header { font-size: 0.95rem; margin: 0; }
    """

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Medical Data Analysis Platform") as demo:

        with gr.Column(elem_id="app-container"):
            # Header
            gr.Markdown("# 🏥 Medical Data Analysis Platform", elem_classes=["compact-header"])

            # Top controls
            with gr.Row():
                file_input = gr.File(label="Upload CSV", file_types=[".csv"], scale=2)
                data_type = gr.Dropdown(
                    choices=["EHR Data", "ECG Data"],
                    value="EHR Data",
                    label="Data Type",
                    scale=1
                )

            # Main row with 3 columns
            with gr.Row(elem_id="main-row"):

                # LEFT: Analysis tasks
                with gr.Column(scale=2, elem_id="left-panel"):
                    with gr.Group(elem_id="task-section"):
                        gr.Markdown("#### Analysis Tasks")
                        # Starts with EHR tasks because the data_type dropdown
                        # defaults to "EHR Data"; refreshed on type change.
                        task_selector = gr.CheckboxGroup(
                            choices=TaskRegistry.get_tasks_for_data_type("EHR Data"),
                            label=None,
                            elem_id="task-checkboxes"
                        )

                        # Parameter groups (conditionally visible) - one for each task that needs params
                        with gr.Group(visible=False) as ndd_param_group:
                            gr.Markdown("**Near-Duplicate Detection Parameters**")
                            ndd_label_dropdown = gr.Dropdown(
                                choices=[],
                                label="Label Column",
                                elem_id="ndd-label-dropdown"
                            )

                        with gr.Group(visible=False) as mislabel_param_group:
                            gr.Markdown("**Find Mislabeled Data Parameters**")
                            mislabel_label_dropdown = gr.Dropdown(
                                choices=[],
                                label="Label Column",
                                elem_id="mislabel-label-dropdown"
                            )

                    with gr.Group(elem_id="action-section"):
                        process_btn = gr.Button("▶ Process", variant="primary", size="lg")
                        status_output = gr.Textbox(label="Status", interactive=False, lines=2)

                # MIDDLE: Results display
                with gr.Column(scale=7, elem_id="middle-panel"):
                    # All tabs start hidden; process_analysis reveals the
                    # union of the selected tasks' output_tabs.
                    with gr.Tabs(elem_id="tabs-container") as result_tabs:
                        with gr.TabItem("Original Data", visible=False) as tab_original:
                            original_df_output = gr.Dataframe(interactive=False)

                        with gr.TabItem("Processed Data", visible=False) as tab_processed:
                            processed_df_output = gr.Dataframe(interactive=False)

                        with gr.TabItem("Summary", visible=False) as tab_summary:
                            summary_output = gr.JSON()

                        with gr.TabItem("Visualization", visible=False) as tab_viz:
                            viz_output = gr.HTML()

                # RIGHT: Chat
                with gr.Column(scale=3, elem_id="chat-panel"):
                    gr.Markdown("### 💬 AI Assistant", elem_id="chat-header")
                    chatbot = gr.Chatbot(elem_id="chat-history", height=None, label=None)
                    with gr.Row(elem_id="chat-input-row"):
                        msg_input = gr.Textbox(
                            placeholder="Ask about your data...",
                            label="",
                            scale=4,
                            container=False,
                            lines=1
                        )
                        send_btn = gr.Button("Send", scale=1, size="sm")

        # ============ EVENT HANDLERS ============
        # Each outputs list must match the corresponding callback's return
        # tuple positionally — see the UIManager method docstrings.

        # File upload
        file_input.change(
            fn=ui_manager.on_file_upload,
            inputs=[file_input, data_type],
            outputs=[
                status_output,
                original_df_output,
                task_selector,
                ndd_param_group,
                mislabel_param_group,
                ndd_label_dropdown,
                mislabel_label_dropdown,
                tab_original, tab_processed, tab_summary, tab_viz,
            ]
        )

        # Data type change
        data_type.change(
            fn=ui_manager.on_data_type_change,
            inputs=[data_type, file_input],
            outputs=[task_selector, ndd_param_group, mislabel_param_group, status_output]
        )

        # Task selection change
        # NOTE(review): the third input feeds on_tasks_change's `df_columns`
        # parameter with the ndd dropdown's selected value (a string, not the
        # column list); the parameter is currently unused — confirm intent.
        task_selector.change(
            fn=ui_manager.on_tasks_change,
            inputs=[task_selector, data_type, ndd_label_dropdown],
            outputs=[ndd_param_group, mislabel_param_group]
        )

        # Process button
        process_btn.click(
            fn=ui_manager.process_analysis,
            inputs=[file_input, data_type, task_selector, ndd_label_dropdown, mislabel_label_dropdown],
            outputs=[
                status_output,
                original_df_output,
                processed_df_output,
                summary_output,
                viz_output,
                tab_original,
                tab_processed,
                tab_summary,
                tab_viz
            ]
        )

        # Chat: button click and Enter key share the same handler.
        send_btn.click(
            fn=ui_manager.chatbot_respond,
            inputs=[msg_input, chatbot],
            outputs=[chatbot, msg_input]
        )

        msg_input.submit(
            fn=ui_manager.chatbot_respond,
            inputs=[msg_input, chatbot],
            outputs=[chatbot, msg_input]
        )

    return demo
797
+
798
+
799
+ # ============================================================================
800
+ # LAUNCH
801
+ # ============================================================================
802
+
803
if __name__ == "__main__":
    # Build the Gradio app and serve it on all interfaces, port 7860.
    application = create_interface()
    application.launch(share=False, server_name="0.0.0.0", server_port=7860)
806
+
807
+
808
+ # ============================================================================
809
+ # HOW TO EXTEND
810
+ # ============================================================================
811
+
812
+ """
813
+ ADDING NEW ANALYSIS TASKS:
814
+ ==========================
815
+
816
+ 1. Add task configuration in TaskRegistry:
817
+ - Define name, data_type, requires_params, param_components, output_tabs
818
+
819
+ 2. Add execution method in AnalysisExecutor:
820
+ - Create execute_your_task() method
821
+ - Return (status_message, results_dict)
822
+
823
+ 3. Add condition in UIManager.process_analysis():
824
+ - elif task_name == "Your Task Name":
825
+ - status_msg, results = executor.execute_your_task(df, params)
826
+ - return self._format_results(status_msg, config.output_tabs, results)
827
+
828
+
829
+ EXAMPLE - Adding a new ECG task with parameters:
830
+ ================================================
831
+
832
+ In TaskRegistry.get_config():
833
+ "ECG Data": {
834
+ ...
835
+ "ECG Quality Check": TaskConfig(
836
+ name="ECG Quality Check",
837
+ data_type="ECG Data",
838
+ requires_params=True,
839
+ param_components=[
840
+ {"type": "slider", "label": "Quality Threshold", "elem_id": "quality_thresh"}
841
+ ],
842
+ output_tabs=["original", "summary", "visualization"]
843
+ )
844
+ }
845
+
846
+ In AnalysisExecutor:
847
+ @staticmethod
848
+ def execute_ecg_quality_check(df: pd.DataFrame, threshold: float) -> Tuple[str, Dict[str, Any]]:
849
+ # Your analysis logic here
850
+ return "✓ Quality check completed", {
851
+ "original": df,
852
+ "summary": quality_summary,
853
+ "visualization": viz_html
854
+ }
855
+
856
+ In UIManager.process_analysis():
857
+ elif task_name == "ECG Quality Check":
858
+ status_msg, results = executor.execute_ecg_quality_check(df, float(param_value))
859
+ return self._format_results(status_msg, config.output_tabs, results)
860
+
861
+
862
+ The architecture is now:
863
+ - Simple and clean
864
+ - Easy to extend with new tasks
865
+ - Dynamic UI based on task configuration
866
+ - Fixed chat scrolling issue
867
+ """
backend_client.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # backend_client.py
2
+ from __future__ import annotations
3
+ import requests
4
+ from typing import Optional, List, Tuple
5
+
6
class LangGraphClient:
    """
    Minimal client for a LangGraph app served at http://localhost:8010.
    Edit ``self.endpoints`` if your server exposes different routes or
    payload shapes.
    """

    def __init__(self, base_url: str = "http://localhost:8010"):
        # Drop any trailing slash so endpoint joins stay clean.
        self.base_url = base_url.rstrip("/")
        # Route table — the chat route expects JSON {session_id, message, history?}.
        self.endpoints = {
            "chat": f"{self.base_url}/chat"
        }

    def send_message(
        self,
        message: str,
        session_id: Optional[str] = "default",
        history: Optional[List[Tuple[str, str]]] = None,
        timeout: int = 30
    ) -> str:
        """
        Send a message to the backend and return the assistant's reply text.
        Network/HTTP failures are reported as a "[Backend error] ..." string
        rather than raised.
        """
        body = {
            "session_id": session_id,
            "message": message,
            "history": history or []  # [[user, assistant], ...] if your server uses it
        }
        try:
            response = requests.post(self.endpoints["chat"], json=body, timeout=timeout)
            response.raise_for_status()
            parsed = response.json()
            # Accept a few common response shapes; fall back to empty string.
            answer = parsed.get("reply") or parsed.get("message") or parsed.get("text") or ""
            return str(answer)
        except requests.RequestException as exc:
            return f"[Backend error] {exc}"
+ return f"[Backend error] {e}"
data/Aubrie.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/JS00001_filtered.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/Lisette.csv ADDED
The diff for this file is too large to render. See raw diff
 
ecg_analyzer.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ECG Analysis Module
3
+ Provides modular ECG visualization and analysis functions
4
+ """
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import io
9
+ import base64
10
+ from typing import List, Optional, Tuple, Dict, Any
11
+
12
class ECGAnalyzer:
    """Handles ECG data analysis and visualization.

    Stateless helper namespace: every public method is a ``@staticmethod`` that
    takes a pandas DataFrame of ECG samples (one column per lead, optionally a
    time column) and returns either a base64-encoded PNG plot (via matplotlib)
    or a plain statistics dict.
    """

    # Standard 12-lead ECG leads (column names expected in the input DataFrame)
    STANDARD_LEADS = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    @staticmethod
    def detect_leads(df: pd.DataFrame) -> List[str]:
        """Return the standard 12-lead names present in `df`, in standard order."""
        return [lead for lead in ECGAnalyzer.STANDARD_LEADS if lead in df.columns]

    @staticmethod
    def detect_time_column(df: pd.DataFrame) -> Optional[str]:
        """Return the first recognized time-axis column name, or None to use the index."""
        time_candidates = ['time', 'Time', 'TIME', 'timestamp', 'sample']
        for col in time_candidates:
            if col in df.columns:
                return col
        # If no explicit time column, callers fall back to the row index.
        return None

    @staticmethod
    def _x_axis(df: pd.DataFrame, time_col: Optional[str]) -> Tuple[Any, str]:
        """Pick the x-axis data and label: the time column when available, else the index.

        Shared by the waveform and rolling-average plots so both stay consistent.
        """
        if time_col and time_col in df.columns:
            # Label assumes the time column holds milliseconds — TODO confirm
            # against the data source.
            return df[time_col], ('Time (ms)' if 'time' in time_col.lower() else time_col)
        return df.index, 'Sample Index'

    @staticmethod
    def _wrap_image(img_base64: str) -> str:
        """Wrap a base64 PNG in the standard <div><img> HTML block used by reports."""
        return f'<div style="margin-bottom: 30px;"><img src="data:image/png;base64,{img_base64}" style="max-width:100%;"/></div>'

    @staticmethod
    def create_signal_plot(df: pd.DataFrame, leads: List[str], time_col: Optional[str] = None) -> str:
        """Create the ECG signal waveform plot (all requested leads overlaid).

        Returns the figure as a base64-encoded PNG string.
        """
        fig, ax = plt.subplots(figsize=(14, 6))
        x_data, x_label = ECGAnalyzer._x_axis(df, time_col)

        # One distinct tab10 color per lead.
        colors = plt.cm.tab10(np.linspace(0, 1, len(leads)))

        for idx, lead in enumerate(leads):
            if lead in df.columns:
                ax.plot(x_data, df[lead], label=f'Lead {lead}',
                        linewidth=1.2, alpha=0.8, color=colors[idx])

        ax.set_xlabel(x_label, fontsize=11)
        ax.set_ylabel('Amplitude (mV)', fontsize=11)
        ax.set_title('ECG Signal Waveform', fontsize=13, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9, ncol=min(4, len(leads)))
        ax.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()

        return ECGAnalyzer._fig_to_base64(fig)

    @staticmethod
    def create_histogram(df: pd.DataFrame, leads: List[str]) -> str:
        """Create a histogram of signal amplitudes for the selected leads."""
        fig, ax = plt.subplots(figsize=(14, 6))

        colors = plt.cm.tab10(np.linspace(0, 1, len(leads)))

        for idx, lead in enumerate(leads):
            if lead in df.columns:
                # NaNs are dropped so they don't distort the bins.
                ax.hist(df[lead].dropna(), bins=50, alpha=0.6,
                        label=f'Lead {lead}', color=colors[idx], edgecolor='black')

        ax.set_xlabel('Amplitude (mV)', fontsize=11)
        ax.set_ylabel('Frequency', fontsize=11)
        ax.set_title('Distribution of Signal Amplitudes', fontsize=13, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9)
        ax.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()

        return ECGAnalyzer._fig_to_base64(fig)

    @staticmethod
    def create_scatter_plot(df: pd.DataFrame, lead_x: str = 'I', lead_y: str = 'II') -> str:
        """Create a scatter plot comparing two leads, annotated with their Pearson correlation.

        If either lead is missing from `df`, a placeholder message is rendered instead.
        """
        fig, ax = plt.subplots(figsize=(8, 8))

        if lead_x in df.columns and lead_y in df.columns:
            ax.scatter(df[lead_x], df[lead_y], alpha=0.5, s=10, c='steelblue')
            ax.set_xlabel(f'Lead {lead_x} Amplitude (mV)', fontsize=11)
            ax.set_ylabel(f'Lead {lead_y} Amplitude (mV)', fontsize=11)
            ax.set_title(f'Lead {lead_x} vs Lead {lead_y}', fontsize=13, fontweight='bold')
            ax.grid(True, alpha=0.3)

            # Add correlation coefficient as an in-axes annotation.
            correlation = df[[lead_x, lead_y]].corr().iloc[0, 1]
            ax.text(0.05, 0.95, f'Correlation: {correlation:.3f}',
                    transform=ax.transAxes, fontsize=10, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        else:
            ax.text(0.5, 0.5, 'Selected leads not available',
                    ha='center', va='center', fontsize=12)

        plt.tight_layout()
        return ECGAnalyzer._fig_to_base64(fig)

    @staticmethod
    def create_rolling_average(df: pd.DataFrame, leads: List[str],
                               time_col: Optional[str] = None, window: int = 100) -> str:
        """Create a rolling-average (moving-average) plot for the selected leads.

        `min_periods=1` means the early samples are averaged over a shorter
        window instead of being dropped.
        """
        fig, ax = plt.subplots(figsize=(14, 6))
        x_data, x_label = ECGAnalyzer._x_axis(df, time_col)

        colors = plt.cm.tab10(np.linspace(0, 1, len(leads)))

        for idx, lead in enumerate(leads):
            if lead in df.columns:
                rolling_avg = df[lead].rolling(window=window, min_periods=1).mean()
                ax.plot(x_data, rolling_avg, label=f'Lead {lead} (MA-{window})',
                        linewidth=1.5, alpha=0.8, color=colors[idx])

        ax.set_xlabel(x_label, fontsize=11)
        ax.set_ylabel('Amplitude (mV)', fontsize=11)
        ax.set_title(f'Rolling Average (Window={window})', fontsize=13, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9, ncol=min(4, len(leads)))
        ax.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()

        return ECGAnalyzer._fig_to_base64(fig)

    @staticmethod
    def create_all_visualizations(df: pd.DataFrame, leads: List[str],
                                  viz_types: List[str]) -> str:
        """Render every requested visualization and concatenate them as one HTML string.

        Parameters
        ----------
        df : DataFrame of ECG samples.
        leads : lead columns to plot.
        viz_types : any of "Signal Waveform", "Histogram", "Scatter Plot",
            "Rolling Average"; unknown entries are ignored.
        """
        html_parts = []
        time_col = ECGAnalyzer.detect_time_column(df)

        for viz_type in viz_types:
            if viz_type == "Signal Waveform":
                html_parts.append(ECGAnalyzer._wrap_image(
                    ECGAnalyzer.create_signal_plot(df, leads, time_col)))

            elif viz_type == "Histogram":
                html_parts.append(ECGAnalyzer._wrap_image(
                    ECGAnalyzer.create_histogram(df, leads)))

            elif viz_type == "Scatter Plot":
                # Use the first two available leads; fall back to the
                # conventional I/II pair (create_scatter_plot shows a
                # placeholder if they are missing).
                lead_x = leads[0] if len(leads) > 0 else 'I'
                lead_y = leads[1] if len(leads) > 1 else 'II'
                html_parts.append(ECGAnalyzer._wrap_image(
                    ECGAnalyzer.create_scatter_plot(df, lead_x, lead_y)))

            elif viz_type == "Rolling Average":
                html_parts.append(ECGAnalyzer._wrap_image(
                    ECGAnalyzer.create_rolling_average(df, leads, time_col)))

        if not html_parts:
            return '<div style="text-align:center; padding:40px;"><p>No visualizations selected</p></div>'

        return '<div style="text-align:center;">' + ''.join(html_parts) + '</div>'

    @staticmethod
    def _fig_to_base64(fig) -> str:
        """Convert a matplotlib figure to a base64 PNG string and close the figure."""
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode('utf-8')
        # Close explicitly so repeated calls don't accumulate open figures.
        plt.close(fig)
        return img_base64

    @staticmethod
    def generate_statistics(df: pd.DataFrame, leads: List[str]) -> Dict[str, Any]:
        """Return per-lead summary statistics (mean/std/min/max/median/quartiles).

        Leads not present in `df` are silently skipped; NaNs are excluded.
        """
        stats: Dict[str, Any] = {}
        for lead in leads:
            if lead in df.columns:
                lead_data = df[lead].dropna()
                stats[lead] = {
                    'mean': float(lead_data.mean()),
                    'std': float(lead_data.std()),
                    'min': float(lead_data.min()),
                    'max': float(lead_data.max()),
                    'median': float(lead_data.median()),
                    'q25': float(lead_data.quantile(0.25)),
                    'q75': float(lead_data.quantile(0.75))
                }
        return stats
ecg_visualization.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary libraries
2
+ import dash
3
+ from dash import html, dcc
4
+ from dash.dependencies import Input, Output
5
+ import pandas as pd
6
+ import plotly.graph_objects as go
7
+ from plotly.subplots import make_subplots
8
# Load the dataset
# NOTE(review): path is relative to the working directory; other modules in
# this commit read from "data/" rather than "dataset/" — confirm which exists.
file_path = 'dataset/JS00001_filtered.csv'  # Update this path as necessary
ecg_data = pd.read_csv(file_path)

# Define leads
# Standard 12-lead ECG channel names; these must match the CSV column headers.
leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Initialize the Dash app
app = dash.Dash(__name__)
# Page layout: title, a lead-selector dropdown, and a single Graph component
# that the update_figure callback below populates.
app.layout = html.Div(children=[
    html.H1(children='ECG Data Dashboard', style={'textAlign': 'center', 'color': '#007BFF'}),
    html.Div([
        html.Label('Select ECG Lead for Analysis:'),
        dcc.Dropdown(
            id='lead-dropdown',
            # 'ALL' is a synthetic option in addition to the 12 real leads.
            options=[{'label': 'All Leads', 'value': 'ALL'}] + [{'label': lead, 'value': lead} for lead in leads],
            value='ALL'  # Default value to show all leads
        )
    ], style={'width': '30%', 'margin': '0 auto', 'padding': '20px'}),
    dcc.Graph(id='ecg-data-visualization'),
], style={'padding': '20px', 'maxWidth': '1200px', 'margin': '0 auto'})
29
+
30
# Define callback to update the figure based on the selected lead
@app.callback(
    Output('ecg-data-visualization', 'figure'),
    [Input('lead-dropdown', 'value')]
)
def update_figure(selected_lead):
    """Rebuild the 4-row analysis figure whenever the lead dropdown changes.

    Parameters
    ----------
    selected_lead : str
        Either 'ALL' or one of the 12 standard lead names from the dropdown.

    Returns
    -------
    plotly Figure with rows: waveform, histogram, I-vs-II scatter, rolling average.
    (The rolling-average row is only populated for a single selected lead.)
    """
    # Initialize the figure with subplots
    fig = make_subplots(rows=4, cols=1,
                        subplot_titles=("ECG Signal Over Time", "Histogram of Signal Amplitudes",
                                        "Scatter Plot: Lead I vs Lead II", "Rolling Average"),
                        vertical_spacing=0.1,
                        specs=[[{"type": "scatter"}], [{"type": "histogram"}], [{"type": "scatter"}], [{"type": "scatter"}]])

    # Conditionally display either all leads or just the selected lead
    if selected_lead == 'ALL':
        # Show all leads (waveform + overlaid histograms only)
        for lead in leads:
            fig.add_trace(go.Scatter(x=ecg_data['time'], y=ecg_data[lead], mode='lines', name=f'Lead {lead}'), row=1, col=1)
            fig.add_trace(go.Histogram(x=ecg_data[lead], name=f'Lead {lead}', opacity=0.75), row=2, col=1)
    else:
        # Show only the selected lead.
        # FIX: compute the rolling average into a local Series instead of the
        # original `ecg_data['RollingAvg'] = ...`, which mutated the shared
        # module-level DataFrame on every callback (state leaking across
        # requests and a stray column left in the data).
        rolling_avg = ecg_data[selected_lead].rolling(window=100).mean()
        fig.add_trace(go.Scatter(x=ecg_data['time'], y=ecg_data[selected_lead], mode='lines', name=f'Lead {selected_lead}'), row=1, col=1)
        fig.add_trace(go.Histogram(x=ecg_data[selected_lead], name=f'Lead {selected_lead}', opacity=0.75), row=2, col=1)
        fig.add_trace(go.Scatter(x=ecg_data['time'], y=rolling_avg, mode='lines', name=f'Rolling Average: Lead {selected_lead}'), row=4, col=1)

    # Common settings for all cases
    fig.update_traces(opacity=0.75, bingroup=1, row=2, col=1)
    fig.update_layout(barmode='overlay')
    fig.add_trace(go.Scatter(x=ecg_data['I'], y=ecg_data['II'], mode='markers', name='Lead I vs Lead II'), row=3, col=1)

    # Update the figure layout
    fig.update_layout(height=1600, title_text="Comprehensive ECG Data Analysis", showlegend=True)
    return fig
64
+
65
+ # Run the app
66
+ if __name__ == '__main__':
67
+ app.run_server(debug=True)
examples/main.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from deduplication import find_near_duplicates
2
+ from featurizer import custom_featurizer
3
+ from issues import find_issues
4
+ from pipeline import make_step, run_pipeline
5
+ import pandas as pd
6
+ from tqdm.auto import tqdm
7
+
8
# Single shared progress bar: each pipeline step advances the same 0-100 bar
# (the steps wrap it in a phase-aware helper internally).
bar = tqdm(total=100, leave=True)

# Pipeline definition: dedup -> featurize -> label-issue detection.
# `make_step(fn, name=...)` appears to return a factory whose call binds the
# step's keyword arguments — confirm against pipeline.make_step.
steps = [
    make_step(find_near_duplicates, name="dedup")(progress=bar),
    make_step(custom_featurizer, name="featurize")(
        label=None,  # optional; only used to drop NaN label rows
        nan_strategy="impute",
        on_pipeline_error="drop",
        progress=bar
    ),
    # "HARDSHIP_INDEX" is assumed to be a column of data/Lisette.csv — verify.
    make_step(find_issues, name="find_label_issues")(label="HARDSHIP_INDEX", progress=bar)
]

# Load the demo dataset and run all steps in order.
df = pd.read_csv("./data/Lisette.csv")
results = run_pipeline(steps, df=df)

bar.close()
print(results)
pipeline/__init__.py ADDED
File without changes
pipeline/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (165 Bytes). View file
 
pipeline/__pycache__/deduplication.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
pipeline/__pycache__/featurizer.cpython-311.pyc ADDED
Binary file (19.2 kB). View file
 
pipeline/__pycache__/issues.cpython-311.pyc ADDED
Binary file (7.78 kB). View file
 
pipeline/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (3.64 kB). View file
 
pipeline/__pycache__/utils.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
pipeline/__pycache__/utils_cool.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
pipeline/deduplication.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from time import perf_counter
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from cleanlab import Datalab
9
+ from cleanlab.datalab.internal.issue_manager.duplicate import NearDuplicateIssueManager
10
+ from scipy.sparse import issparse
11
+ from scipy.special import comb
12
+ from sklearn.base import TransformerMixin
13
+ from sklearn.feature_extraction.text import TfidfVectorizer
14
+ from tqdm.auto import tqdm
15
+
16
+ from .utils_cool import PhaseProgress, _ensure_dense32, choose_k
17
+
18
+
19
# =========================
# Core near-duplicate finder
# =========================
def find_near_duplicates(
    df: pd.DataFrame,
    *,
    # Cleanlab params
    metric: str = "cosine",
    threshold: float = 0.13,
    k: Optional[int] = None,
    # Vectorizer / features
    vectorizer: Optional[TransformerMixin] = None,
    force_dense: bool = False,  # set True if your cleanlab version needs dense
    # Behavior
    verbose: bool = False,
    # Progress: either a tqdm object or a callable phase,p in [0,1]
    progress: Optional[Any] = None,
) -> Tuple[Optional[pd.DataFrame], Dict]:
    """
    Detect near-duplicates using Cleanlab's NearDuplicateIssueManager.

    Parameters
    ----------
    df : DataFrame
        Rows to analyze. If no vectorizer is passed, all columns are joined as strings for TF-IDF.
    metric : {"cosine", "euclidean", "manhattan"}
        Distance metric for kNN graph.
    threshold : float
        Near-duplicate radius is based on threshold × median NN distance (internal to Cleanlab).
    k : int or None
        Neighborhood size. If None, uses sqrt(N) clipped to [5, 50] and ≤ N-1.
    vectorizer : sklearn Transformer
        Any transformer with fit_transform/transform. If None, uses TF-IDF (float32).
    force_dense : bool
        If True, densify features before passing to Cleanlab.
    verbose : bool
        Print timing breakdown.
    progress : tqdm or Callable[[str, float], None]
        Phase-aware progress reporting.

    Returns
    -------
    (output_df, stats)
        output_df : DataFrame after deduplication
        stats : dict of counts/timings/params

    Notes
    -----
    NOTE(review): despite the parameter doc, a plain *callable* passed as
    ``progress`` is currently ignored — only tqdm-like objects (having
    ``set_description``/``update``) are wrapped in PhaseProgress. Confirm
    whether the callable path should be implemented.
    """
    # Progress setup (one bar per call unless user provided one)
    local_bar = None
    pp = None
    if progress is None:
        # Own the bar: create it here and close it at the end of this call.
        local_bar = tqdm(total=100, leave=True)
        pp = PhaseProgress(local_bar, weights={"vectorize": .2, "find_issues": .7, "grouping": .1})
    else:
        if hasattr(progress, "set_description") and hasattr(progress, "update"):
            # treat given tqdm as the bar
            # PhaseProgress tracks the last absolute value on the bar object itself.
            if not hasattr(progress, "_last_val"):
                progress._last_val = 0
            pp = PhaseProgress(progress, weights={"vectorize": .2, "find_issues": .7, "grouping": .1})

    timings: Dict[str, float] = {}
    t0 = perf_counter()
    N = int(len(df))

    # --- Vectorize ---
    # All `pp and pp.X(...)` guards are no-ops when no progress sink is attached.
    pp and pp.start("vectorize", extra={"N": N})
    t_vec0 = perf_counter()
    # Concatenate every column as text so the default TF-IDF sees whole rows.
    text_series = df.astype(str).agg(" ".join, axis=1)

    if vectorizer is None:
        vectorizer = TfidfVectorizer(dtype=np.float32)

    if hasattr(vectorizer, "fit_transform"):
        X = vectorizer.fit_transform(text_series.tolist())
    elif hasattr(vectorizer, "transform"):
        # Pre-fitted transformer: use as-is.
        X = vectorizer.transform(text_series.tolist())
    else:
        raise TypeError("`vectorizer` must implement fit_transform or transform.")

    if force_dense:
        X = _ensure_dense32(X)

    timings["vectorize"] = perf_counter() - t_vec0
    pp and pp.tick_abs("vectorize", 1.0)
    pp and pp.end("vectorize")

    # --- Cleanlab duplicate finder ---
    pp and pp.start("find_issues", extra={"metric": metric})
    t_cl0 = perf_counter()

    if k is None:
        k = choose_k(N)

    # Datalab only needs row identities; the real signal goes in via `features`.
    lab = Datalab(data={"__row__": list(range(N))})
    ndm = NearDuplicateIssueManager(datalab=lab, metric=metric, threshold=threshold, k=k)

    pp and pp.tick_abs("find_issues", 0.0, extra={"k": k})

    try:
        ndm.find_issues(features=X)
    except Exception:
        # Some cleanlab versions reject sparse input — retry once densified.
        if issparse(X):
            X_dense = _ensure_dense32(X)
            ndm.find_issues(features=X_dense)  # retry dense
        else:
            raise

    pp and pp.tick_abs("find_issues", 1.0)
    timings["cleanlab_find_issues"] = perf_counter() - t_cl0
    pp and pp.end("find_issues", extra={"k": k})

    # Groups of mutually near-duplicate row indices reported by cleanlab.
    near_dup_sets: List[List[int]] = getattr(ndm, "near_duplicate_sets", []) or []

    # --- Representatives & output ---
    pp and pp.start("grouping")
    t_out0 = perf_counter()

    # Keep smallest index as representative; skip any empty groups defensively
    reps = [int(np.min(g)) for g in near_dup_sets if np.size(g) > 0]
    in_any = {int(i) for g in near_dup_sets for i in np.asarray(g).ravel()}
    # Keep = one representative per group + every row not in any group.
    keep_set = set(reps) | (set(range(N)) - in_any)

    out_df = None
    if N > 0:
        keep_mask = np.zeros(N, dtype=bool)
        if keep_set:
            keep_mask[list(keep_set)] = True
        out_df = df.iloc[np.where(keep_mask)[0]].copy()
    else:
        out_df = df.copy()

    # Stats
    n_groups = sum(1 for g in near_dup_sets if np.size(g) > 0)
    group_sizes = [int(np.size(g)) for g in near_dup_sets if np.size(g) > 0]
    n_flagged = sum(max(0, s - 1) for s in group_sizes)  # rows we'd drop if keeping 1 per group
    n_pairs = int(sum(comb(s, 2, exact=True) for s in group_sizes))

    timings["groups_and_output"] = perf_counter() - t_out0
    total_time = perf_counter() - t0

    stats = {
        "n_rows_before_dedup": N,
        "n_near_dupe_pairs": n_pairs,
        "n_groups": n_groups,
        "avg_group_size": float(np.mean(group_sizes)) if group_sizes else 0.0,
        "max_group_size": max(group_sizes) if group_sizes else 0,
        "n_rows_flagged_duplicates": n_flagged,
        "n_rows_after_dedup": int(len(out_df)) if out_df is not None else N,
        "metric": metric,
        "threshold": threshold,
        "k": int(k),
        "timings": timings,
        "total_time_sec": total_time,
    }

    if verbose:
        print(f"[timing] TOTAL: {total_time*1000:.1f} ms")
        for k_, v in timings.items():
            print(f" - {k_}: {v*1000:.1f} ms")

    pp and pp.tick_abs("grouping", 1.0, extra={"groups": n_groups, "pairs": n_pairs})
    pp and pp.end("grouping")

    if local_bar is not None:
        # Only close bars we created ourselves; caller-owned bars stay open.
        pp.close()

    # Return the standardized pair expected by your pipeline
    return out_df, stats
pipeline/featurizer.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.impute import SimpleImputer
9
+ from sklearn.pipeline import Pipeline
10
+ from sklearn.preprocessing import FunctionTransformer, StandardScaler
11
+
12
+ from .utils_cool import PhaseProgress
13
+
14
+
15
+ def _safe_apply_pipeline(
16
+ group_name: str,
17
+ df_in: pd.DataFrame,
18
+ cols: List[str],
19
+ pipeline: Pipeline,
20
+ *,
21
+ drop_on_error: bool,
22
+ warnings: List[str],
23
+ max_new_cols: int = 2000,
24
+ ) -> pd.DataFrame:
25
+ """
26
+ Apply a user pipeline to df_in[cols] safely.
27
+ - If success and returns DataFrame with same index: replace group columns with returned columns (prefixed).
28
+ - If success and returns ndarray/sparse: if width <= max_new_cols, expand with auto names; else drop group.
29
+ - If failure and drop_on_error: drop group columns and record a warning.
30
+ """
31
+ Xsub = df_in[cols]
32
+ try:
33
+ out = pipeline.fit_transform(Xsub)
34
+ except Exception as e:
35
+ if drop_on_error:
36
+ warnings.append(
37
+ f"[custom_featurizer] dropped '{group_name}' columns ({len(cols)}): {cols[:6]}{'...' if len(cols)>6 else ''} "
38
+ f"reason={type(e).__name__}: {str(e)[:180]}"
39
+ )
40
+ return df_in.drop(columns=cols)
41
+ raise
42
+
43
+ # If pipeline returns a DataFrame -> use its columns (prefixed)
44
+ if isinstance(out, pd.DataFrame):
45
+ out_df = out.copy()
46
+ # ensure index alignment
47
+ out_df.index = df_in.index
48
+ # prefix to avoid collisions
49
+ out_df.columns = [f"{group_name}__{c}" for c in out_df.columns]
50
+ df_out = df_in.drop(columns=cols).join(out_df)
51
+ return df_out
52
+
53
+ # If ndarray / sparse matrix -> expand to columns cautiously
54
+ try:
55
+ import numpy as _np
56
+ import scipy.sparse as _sp
57
+ if _sp.issparse(out):
58
+ out = out.toarray()
59
+ out = _np.asarray(out)
60
+ n_rows, n_cols = out.shape[0], (out.shape[1] if out.ndim == 2 else 1)
61
+ if n_rows != len(df_in):
62
+ raise ValueError(f"Pipeline for '{group_name}' returned {n_rows} rows; expected {len(df_in)}.")
63
+ if n_cols > max_new_cols:
64
+ warnings.append(
65
+ f"[custom_featurizer] '{group_name}' produced {n_cols} columns (> {max_new_cols}); dropping this group to avoid explosion."
66
+ )
67
+ return df_in.drop(columns=cols)
68
+ if out.ndim == 1:
69
+ out = out.reshape(-1, 1)
70
+ n_cols = 1
71
+ new_cols = [f"{group_name}__f{i}" for i in range(n_cols)]
72
+ out_df = pd.DataFrame(out, index=df_in.index, columns=new_cols)
73
+ df_out = df_in.drop(columns=cols).join(out_df)
74
+ return df_out
75
+ except Exception as e:
76
+ if drop_on_error:
77
+ warnings.append(
78
+ f"[custom_featurizer] failed to materialize output for '{group_name}'; dropping group. "
79
+ f"reason={type(e).__name__}: {str(e)[:180]}"
80
+ )
81
+ return df_in.drop(columns=cols)
82
+ raise
83
+
84
+
85
+ def custom_featurizer(
86
+ df: pd.DataFrame,
87
+ *,
88
+ # optional label just for cleaning rows with NaN in label; not returned as y
89
+ label: Optional[str] = None,
90
+
91
+ # user-provided pipelines (override defaults for that group)
92
+ numeric_pipeline: Optional[Pipeline] = None,
93
+ low_card_pipeline: Optional[Pipeline] = None,
94
+ text_pipeline: Optional[Pipeline] = None,
95
+
96
+ # defaults (used ONLY if the corresponding pipeline is NOT provided)
97
+ numeric_scale: bool = True, # scale numeric after impute
98
+ text_lowercase: bool = True, # lower+strip text
99
+ max_ohe_cardinality: int = 50, # threshold to split low-card vs text
100
+
101
+ # NaN handling for DEFAULTS ONLY (ignored if a custom pipeline is provided for that group)
102
+ nan_strategy: str = "impute", # "impute" or "drop"
103
+
104
+ # failure policy for user pipelines
105
+ on_pipeline_error: str = "drop", # "drop" -> drop group; "raise" -> bubble error
106
+
107
+ # control expansion when user pipelines return big matrices
108
+ max_new_cols_per_group: int = 2000,
109
+
110
+ # progress / logging
111
+ progress: Optional[Any] = None, # pass a tqdm here; None -> auto-create
112
+
113
+ # logging
114
+ verbose: bool = False,
115
+ ) -> Tuple[pd.DataFrame, Dict[str, Any]]:
116
+ """
117
+ Featurize a mixed DataFrame while keeping the output as a DataFrame for downstream steps.
118
+
119
+ Overview
120
+ --------
121
+ Columns are split into three groups:
122
+ • numeric (including boolean)
123
+ • low-cardinality categoricals (nunique ≤ `max_ohe_cardinality`)
124
+ • text/high-cardinality (nunique > `max_ohe_cardinality`)
125
+
126
+ You may pass custom sklearn `Pipeline`s per group. If a user pipeline raises and
127
+ `on_pipeline_error="drop"`, that entire group of columns is dropped and a short
128
+ warning is recorded in `stats["warnings"]` (the step does not crash). If `"raise"`,
129
+ the error is propagated.
130
+
131
+ Defaults when pipelines are NOT provided
132
+ ---------------------------------------
133
+ • Numeric (when `numeric_pipeline is None`)
134
+ - NaNs: if `nan_strategy="impute"`, apply `SimpleImputer(strategy="median")`;
135
+ if `"drop"`, rows with NaNs in numeric columns are dropped BEFORE this step.
136
+ - Scaling: if `numeric_scale=True`, apply `StandardScaler(with_mean=False)`.
137
+ - Columns: replaced in place (same column names remain; values become numeric/float).
138
+ • Low-cardinality categoricals (when `low_card_pipeline is None`)
139
+ - NaNs: if `nan_strategy="impute"`, apply `SimpleImputer(strategy="most_frequent")`;
140
+ if `"drop"`, rows with NaNs in these columns are dropped BEFORE this step.
141
+ - Encoding: **no one-hot by default** (values stay as cleaned strings/categories).
142
+ If you want encodings, pass your own `low_card_pipeline` (e.g., OneHotEncoder/CatBoostEncoder).
143
+ - Columns: preserved (same names).
144
+ • Text / high-cardinality (when `text_pipeline is None`)
145
+ - Build a tiny pipeline:
146
+ concat selected text cols → `TfidfVectorizer(dtype=float32, lowercase=text_lowercase)`
147
+ - Output: numeric TF-IDF features. Original text columns are **replaced** by new
148
+ columns named `txt__f0`, `txt__f1`, … (or `txt__<col>` if your pipeline returns a DataFrame).
149
+ - Feature explosion guard: if the produced matrix has more than `max_new_cols_per_group` columns,
150
+ the entire text group is dropped and a warning is recorded.
151
+ - If TF-IDF fails (e.g., empty vocabulary on tiny data), the text group is dropped with a warning.
152
+
153
+ NaN handling
154
+ ------------
155
+ `nan_strategy` applies only to groups using the **default** pipeline:
156
+ - "impute": impute NaNs (as above)
157
+ - "drop" : drop rows containing NaNs in any default-handled feature column **before**
158
+ transforming. For groups with a **custom** pipeline, NaN handling is your pipeline’s responsibility.
159
+
160
+ Parameters
161
+ ----------
162
+ df : pd.DataFrame
163
+ Input table. If `label` is provided, rows with NaN in `label` are dropped first.
164
+ label : Optional[str], default=None
165
+ Name of the target column (only used to drop NaN labels). Not returned as `y`.
166
+ numeric_pipeline : Optional[sklearn.pipeline.Pipeline], default=None
167
+ Custom pipeline for numeric/boolean columns. If provided, `nan_strategy` is ignored for this group.
168
+ By default: median imputation (+ optional scaling) and columns are preserved.
169
+ low_card_pipeline : Optional[sklearn.pipeline.Pipeline], default=None
170
+ Custom pipeline for low-card categorical columns. If provided, `nan_strategy` is ignored for this group.
171
+ By default: most-frequent imputation only; **no encoding**; columns are preserved.
172
+ text_pipeline : Optional[sklearn.pipeline.Pipeline], default=None
173
+ Custom pipeline for text/high-card columns. If provided, it replaces the text columns with whatever
174
+ it outputs (DataFrame → prefixed columns; array/sparse → `txt__f*` columns). If not provided, the
175
+ built-in concat+TF-IDF is used (see defaults above).
176
+ numeric_scale : bool, default=True
177
+ Applies `StandardScaler(with_mean=False)` to numeric columns in the default numeric pipeline.
178
+ text_lowercase : bool, default=True
179
+ Forwarded to the default `TfidfVectorizer(lowercase=...)`.
180
+ max_ohe_cardinality : int, default=50
181
+ Threshold to classify non-numeric columns as low-card (≤ threshold) vs text/high-card (> threshold).
182
+ nan_strategy : {"impute","drop"}, default="impute"
183
+ Strategy for **default** pipelines only (custom pipelines manage their own NaNs).
184
+ on_pipeline_error : {"drop","raise"}, default="drop"
185
+ If a **user** pipeline raises: "drop" → drop that group and record a warning; "raise" → propagate.
186
+ max_new_cols_per_group : int, default=2000
187
+ Upper bound on the number of columns a group is allowed to add (applies to array/sparse outputs).
188
+ If exceeded, the group is dropped and a warning is recorded.
189
+ progress : Optional[tqdm], default=None
190
+ Phase-aware progress bar (clean → split → numeric → low_card → text → finalize).
191
+ If None, a local bar is created and closed automatically.
192
+ verbose : bool, default=False
193
+ When True, prints recorded warnings (in addition to returning them in `stats["warnings"]`).
194
+
195
+ Returns
196
+ -------
197
+ output_df : pd.DataFrame
198
+ Transformed DataFrame ready for the next pipeline step. Numeric/low-card columns are
199
+ updated in place by defaults; text columns are replaced by TF-IDF features when using
200
+ the default text path.
201
+ stats : Dict[str, Any]
202
+ Minimal metadata:
203
+ - "warnings": List[str]
204
+ - "cols": {"numeric": [...], "low_card": [...], "text": [...]}
205
+ - "n_rows_before": int
206
+ - "n_rows_after": int
207
+
208
+ Notes
209
+ -----
210
+ • If a user pipeline returns a DataFrame, columns are prefixed with the group name to avoid collisions.
211
+ If it returns an array/sparse matrix, columns are auto-named `"{group}__f{i}"`.
212
+ • The feature explosion guard applies after transformation; if you often hit it with text, consider passing
213
+ your own `TfidfVectorizer(max_features=...)` in `text_pipeline` to cap the width proactively.
214
+ """
215
+ # progress setup
216
+ local_bar = None
217
+ pp = None
218
+ if progress is None:
219
+ from tqdm.auto import tqdm # local import to keep module lightweight
220
+ local_bar = tqdm(total=100, leave=True)
221
+ pp = PhaseProgress(local_bar, weights={
222
+ "clean": .10, "split": .10, "numeric": .30, "low_card": .25, "text": .20, "finalize": .05
223
+ })
224
+ elif hasattr(progress, "set_description") and hasattr(progress, "update"):
225
+ if not hasattr(progress, "_last_val"):
226
+ progress._last_val = 0
227
+ pp = PhaseProgress(progress, weights={
228
+ "clean": .10, "split": .10, "numeric": .30, "low_card": .25, "text": .20, "finalize": .05
229
+ })
230
+
231
+ warnings: List[str] = []
232
+
233
+ # 1) drop rows with NaN label if present
234
+ n_before = len(df)
235
+ pp and pp.start("clean", extra={"N": n_before})
236
+ if label is not None and label not in df.columns:
237
+ # ensure the bar is closed even on error
238
+ if local_bar is not None:
239
+ pp.close()
240
+ raise KeyError(f"Target column '{label}' not in DataFrame.")
241
+ if label is not None:
242
+ df = df.dropna(subset=[label]).reset_index(drop=True)
243
+ pp and pp.tick_abs("clean", 1.0, extra={"N": len(df)})
244
+ pp and pp.end("clean")
245
+
246
+ # 2) split columns
247
+ pp and pp.start("split")
248
+ Xf = df if label is None else df.drop(columns=[label])
249
+ numeric_cols: List[str] = Xf.select_dtypes(include=[np.number, "boolean"]).columns.tolist()
250
+ non_numeric = [c for c in Xf.columns if c not in numeric_cols]
251
+
252
+ low_card_cols: List[str] = []
253
+ text_cols: List[str] = []
254
+ for c in non_numeric:
255
+ nunq = Xf[c].nunique(dropna=True)
256
+ (low_card_cols if nunq <= max_ohe_cardinality else text_cols).append(c)
257
+ pp and pp.tick_abs("split", 1.0, extra={
258
+ "num": len(numeric_cols), "low": len(low_card_cols), "txt": len(text_cols)
259
+ })
260
+ pp and pp.end("split")
261
+
262
+ # 3) NaN strategy for default groups (custom pipelines assumed to handle NaNs)
263
+ if nan_strategy not in {"impute", "drop"}:
264
+ if local_bar is not None:
265
+ pp.close()
266
+ raise ValueError("nan_strategy must be 'impute' or 'drop'")
267
+ if nan_strategy == "drop":
268
+ cols_to_check = []
269
+ if numeric_cols and numeric_pipeline is None:
270
+ cols_to_check += numeric_cols
271
+ if low_card_cols and low_card_pipeline is None:
272
+ cols_to_check += low_card_cols
273
+ if text_cols and text_pipeline is None:
274
+ cols_to_check += text_cols
275
+ if cols_to_check:
276
+ mask = ~Xf[cols_to_check].isna().any(axis=1)
277
+ df = df.loc[mask].reset_index(drop=True)
278
+ Xf = df if label is None else df.drop(columns=[label])
279
+
280
+ # 4) numeric
281
+ pp and pp.start("numeric", extra={"cols": len(numeric_cols)})
282
+ if numeric_cols:
283
+ if numeric_pipeline is not None:
284
+ df = _safe_apply_pipeline(
285
+ "num", df, numeric_cols, numeric_pipeline,
286
+ drop_on_error=(on_pipeline_error == "drop"),
287
+ warnings=warnings,
288
+ max_new_cols=max_new_cols_per_group,
289
+ )
290
+ else:
291
+ if nan_strategy == "impute":
292
+ imputed = SimpleImputer(strategy="median").fit_transform(df[numeric_cols])
293
+ df.loc[:, numeric_cols] = imputed
294
+ if numeric_scale:
295
+ scaler = StandardScaler(with_mean=False)
296
+ scaled = scaler.fit_transform(df[numeric_cols].astype(float, copy=False))
297
+ df.loc[:, numeric_cols] = np.asarray(scaled)
298
+ pp and pp.tick_abs("numeric", 1.0)
299
+ pp and pp.end("numeric", extra={"cols": len(numeric_cols)})
300
+
301
+ # 5) low-card categoricals
302
+ pp and pp.start("low_card", extra={"cols": len(low_card_cols)})
303
+ if low_card_cols:
304
+ if low_card_pipeline is not None:
305
+ df = _safe_apply_pipeline(
306
+ "cat", df, low_card_cols, low_card_pipeline,
307
+ drop_on_error=(on_pipeline_error == "drop"),
308
+ warnings=warnings,
309
+ max_new_cols=max_new_cols_per_group,
310
+ )
311
+ else:
312
+ if nan_strategy == "impute":
313
+ df.loc[:, low_card_cols] = SimpleImputer(strategy="most_frequent").fit_transform(df[low_card_cols])
314
+ pp and pp.tick_abs("low_card", 1.0)
315
+ pp and pp.end("low_card", extra={"cols": len(low_card_cols)})
316
+
317
+ # 6) text
318
+ pp and pp.start("text", extra={"cols": len(text_cols)})
319
+ if text_cols:
320
+ if text_pipeline is None:
321
+ def _concat_cols(Xframe: pd.DataFrame):
322
+ return Xframe.fillna("").astype(str).agg(" ".join, axis=1).values
323
+
324
+ text_pipeline = Pipeline([
325
+ ("concat", FunctionTransformer(_concat_cols, validate=False)),
326
+ ("tfidf", TfidfVectorizer(
327
+ dtype=np.float32,
328
+ lowercase=bool(text_lowercase),
329
+ )),
330
+ ])
331
+
332
+ df = _safe_apply_pipeline(
333
+ "txt", df, text_cols, text_pipeline,
334
+ drop_on_error=(on_pipeline_error == "drop"),
335
+ warnings=warnings,
336
+ max_new_cols=max_new_cols_per_group,
337
+ )
338
+
339
+ pp and pp.tick_abs("text", 1.0)
340
+ pp and pp.end("text", extra={"cols": len(text_cols)})
341
+
342
+ # 7) finalize
343
+ pp and pp.start("finalize")
344
+ stats: Dict[str, Any] = {
345
+ "warnings": warnings,
346
+ "cols": {"numeric": numeric_cols, "low_card": low_card_cols, "text": text_cols},
347
+ "n_rows_before": n_before,
348
+ "n_rows_after": len(df),
349
+ }
350
+ pp and pp.tick_abs("finalize", 1.0, extra={"warn": len(warnings)})
351
+ pp and pp.end("finalize", extra={"warn": len(warnings)})
352
+
353
+ if verbose and warnings:
354
+ print("\n".join(warnings))
355
+ if local_bar is not None:
356
+ pp.close()
357
+
358
+ return df, stats
pipeline/issues.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from scipy.sparse import issparse
8
+ from sklearn.linear_model import LogisticRegression, SGDRegressor
9
+ from tqdm.auto import tqdm
10
+
11
+ from .utils_cool import PhaseProgress, _ensure_dense32, _infer_task
12
+
13
+
14
def find_issues(
    df: pd.DataFrame,
    *,
    label: str,
    task: Optional[str] = None,       # "classification" | "regression"; if None we infer
    model: Optional[Any] = None,      # sklearn estimator; default chosen by task
    progress: Optional[Any] = None,   # tqdm bar; None -> auto-create
    verbose: bool = False,
) -> Tuple[Optional[pd.DataFrame], Dict]:
    """
    Detect label issues using Cleanlab's CleanLearning and drop the flagged rows.

    Parameters
    ----------
    df : DataFrame
        Input table containing features and a label column.
    label : str
        Name of the label column. Rows with NaN in label are dropped.
    task : {"classification","regression"}, optional
        If not provided, inferred from y.
    model : sklearn estimator, optional
        If not provided, defaults to LogisticRegression or SGDRegressor.
    progress : tqdm, optional
        Phase-aware progress bar. If None, a local bar is created and closed.
    verbose : bool, default False
        Print warnings in addition to returning them in stats.

    Returns
    -------
    (df_out, stats)
        df_out : DataFrame with rows flagged as label issues removed
                 (the NaN-label-cleaned frame if Cleanlab flagged nothing).
        stats : dict with minimal metadata and counts.

    Raises
    ------
    KeyError
        If `label` is not a column of `df`.
    RuntimeError
        If Cleanlab fails on both sparse and densified features.
    """
    # ---- progress setup ----
    local_bar = None
    pp = None
    if progress is None:
        local_bar = tqdm(total=100, leave=True)
        pp = PhaseProgress(local_bar, weights={"clean": .15, "cleanlab": .75, "finalize": .10})
    elif hasattr(progress, "set_description") and hasattr(progress, "update"):
        if not hasattr(progress, "_last_val"):
            progress._last_val = 0
        pp = PhaseProgress(progress, weights={"clean": .15, "cleanlab": .75, "finalize": .10})

    warnings: List[str] = []
    n_before = len(df)

    # ---- clean: ensure label exists, drop NaNs in label ----
    pp and pp.start("clean", extra={"N": n_before})
    if label not in df.columns:
        if local_bar is not None:
            pp.close()  # don't leak the auto-created bar on error
        raise KeyError(f"Label column '{label}' not found in DataFrame.")
    df_in = df.dropna(subset=[label]).reset_index(drop=True)
    # BUGFIX: exclude the label from the feature matrix. Previously X = df_in
    # kept the label as a feature, letting the internal model read the answer
    # off its own input, so Cleanlab would essentially never flag anything.
    X = df_in.drop(columns=[label])
    y_raw = df_in[label].to_numpy()
    pp and pp.tick_abs("clean", 1.0, extra={"N": len(df_in)})
    pp and pp.end("clean")

    # ---- task/model selection ----
    task_applied = _infer_task(y_raw, task)
    if model is None:
        if task_applied == "classification":
            model = LogisticRegression(solver="saga", n_jobs=-1)
        else:
            model = SGDRegressor()

    # y encoding for classification (Cleanlab expects numeric labels)
    if task_applied == "classification":
        classes, y = np.unique(y_raw, return_inverse=True)
    else:
        classes = None
        y = y_raw.astype(np.float64, copy=False)

    # ---- cleanlab: find label issues (auto-dense fallback for sparse edge cases) ----
    pp and pp.start("cleanlab", extra={"task": task_applied})
    used_dense_fallback = False
    _CL = None  # pre-bind so the except-branch can tell "import failed" apart
    try:
        if task_applied == "classification":
            from cleanlab.classification import CleanLearning as _CL
        else:
            from cleanlab.regression.learn import CleanLearning as _CL
        issues = _CL(model).find_label_issues(X, y)
    except Exception as e:
        # Retry dense only when the import succeeded and features are sparse
        # (cleanlab/small-N often prefers dense). BUGFIX: previously `_CL`
        # could be unbound here, turning the real error into a NameError.
        if _CL is not None and issparse(X):
            Xd = _ensure_dense32(X)
            try:
                issues = _CL(model).find_label_issues(Xd, y)
                used_dense_fallback = True
            except Exception as e2:
                if local_bar is not None:
                    pp.close()
                raise RuntimeError(f"Cleanlab failed on sparse and dense features: {e2}") from e
        else:
            if local_bar is not None:
                pp.close()
            raise

    # ---- parse outputs robustly ----
    if isinstance(issues, pd.DataFrame):
        is_issue = issues.get("is_label_issue", None)
        label_quality = issues.get("label_quality", None)
    else:
        is_issue = None
        label_quality = None

    n = len(issues) if hasattr(issues, "__len__") else len(y)
    n_issues = int(np.asarray(is_issue).sum()) if isinstance(is_issue, (pd.Series, np.ndarray)) else 0
    pct = round((n_issues / n) * 100.0, 3) if n else 0.0
    avg_quality = float(np.nanmean(label_quality.to_numpy())) if isinstance(label_quality, pd.Series) else float("nan")

    pp and pp.tick_abs("cleanlab", 1.0, extra={"issues": n_issues})
    pp and pp.end("cleanlab", extra={"issues": n_issues})

    # ---- finalize: drop flagged rows ----
    pp and pp.start("finalize")
    df_out = df_in
    if isinstance(is_issue, (pd.Series, np.ndarray)):
        # BUGFIX: np.asarray works for both Series and ndarray; the old
        # `.values` attribute access crashed on plain ndarrays.
        mask_keep = ~np.asarray(is_issue, dtype=bool)
        df_out = df_in.loc[mask_keep].copy()

    stats: Dict[str, Any] = {
        "n_rows_before_cleanlab": int(len(df_in)),
        "n_label_issues": int(n_issues),
        "pct_label_issues": float(pct),
        "avg_label_quality": float(avg_quality),
        "n_rows_after_cleanlab": int(len(df_out)),
        "task_applied": task_applied,
        "model_name": type(model).__name__,
        "used_dense_fallback": bool(used_dense_fallback),
        "warnings": warnings,
    }
    pp and pp.tick_abs("finalize", 1.0)
    pp and pp.end("finalize")

    if verbose and warnings:
        print("\n".join(warnings))
    if local_bar is not None:
        pp.close()

    return df_out, stats
pipeline/pipeline.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Callable, Optional, Any, Dict, Tuple, Sequence, List
3
+
4
+ import inspect
5
+ from typing import Optional, Callable, Tuple, Dict, Any
6
+ import pandas as pd
7
+
8
def make_step(func: Callable, *, name: Optional[str] = None, use_original_df: bool = False):
    """
    Turn `func` into a two-stage pipeline step factory.

    `func` is expected to return a `(output_df, stats_dict)` pair.

    If `func` accepts a `df` parameter, the wrapped step feeds it:
      * the previous step's output (default),
      * the original pipeline input when `use_original_df=True`,
      * or a `df` explicitly bound at build time, which overrides both.
    """
    signature = inspect.signature(func)
    step_name = name or func.__name__
    accepts_df = "df" in signature.parameters

    def builder(**params):
        # Fail fast on unknown/invalid keyword arguments.
        signature.bind_partial(**params)

        def _run(prev_df: pd.DataFrame, orig_df: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, pd.DataFrame, Dict[str, Any], str]:
            kwargs = dict(params)

            # A df bound at build time always wins; otherwise pick the source.
            if accepts_df and "df" not in kwargs:
                want_original = use_original_df and orig_df is not None
                kwargs["df"] = orig_df if want_original else prev_df

            out_df, stats = func(**kwargs)
            if not isinstance(stats, dict):
                raise TypeError(f"{step_name}: expected stats to be dict, got {type(stats)}")

            # Report the df the function actually consumed (falls back to prev_df).
            return kwargs.get("df", prev_df), out_df, stats, step_name

        _run.__name__ = step_name
        _run.__doc__ = func.__doc__
        _run.__signature__ = signature
        return _run

    builder.__name__ = f"{step_name}_step"
    builder.__doc__ = f"Step builder for `{func.__name__}`.\n\n" + (func.__doc__ or "")
    builder.__signature__ = signature
    return builder
49
+
50
def run_pipeline(steps: Sequence[Callable], df: pd.DataFrame):
    """
    Execute `steps` sequentially, chaining each output into the next step.

    Every step is invoked as step(prev_df, orig_df) and must return
    (input_df, output_df, stats, name). A None output keeps the previous
    frame. Returns (final_df, logs) where logs is one dict per step.
    """
    original = df
    current = df
    logs: List[Dict[str, Any]] = []
    for run_step in steps:
        _, produced, stats, step_name = run_step(current, original)
        logs.append({"step": step_name, **stats})
        if produced is not None:
            current = produced
    return current, logs
63
+
pipeline/utils_cool.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ import re
4
+ from dataclasses import dataclass
5
+ from time import perf_counter
6
+ from typing import Any, Dict, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from scipy.sparse import issparse
11
+ from tqdm import tqdm
12
+
13
+
14
@dataclass
class PhaseProgress:
    """Map named, weighted phases onto a single 0..`total` tqdm bar.

    The wrapped bar gains a `_last_val` attribute tracking the absolute
    position already emitted; callers may pre-seed it before construction.
    """
    bar: "tqdm"
    weights: Dict[str, float]
    total: int = 100

    def __post_init__(self):
        # Normalise weights so they need not sum to 1; guard against all-zero.
        self._norm = sum(self.weights.values()) or 1.0
        self._done = 0.0       # fraction of the bar owned by finished phases
        self._phase = None     # currently running phase name
        self._phase_t0 = None  # perf_counter() at phase start
        if not hasattr(self.bar, "_last_val"):
            self.bar._last_val = 0

    def start(self, phase: str, extra: Optional[Dict] = None):
        """Begin `phase`: record its start time and label the bar."""
        self._phase = phase
        self._phase_t0 = perf_counter()
        self.bar.set_description_str(phase)
        if extra:
            self.bar.set_postfix(extra, refresh=False)

    def tick_abs(self, phase: str, p01: float, extra: Optional[Dict] = None):
        """Advance the bar to within-phase progress p01 ∈ [0, 1] (monotone only)."""
        fraction = min(1.0, max(0.0, float(p01)))
        share = self.weights.get(phase, 0.0) / self._norm
        absolute = int(round(self.total * (self._done + share * fraction)))
        step = absolute - self.bar._last_val
        if step > 0:
            self.bar.update(step)
            self.bar._last_val = absolute
        self.bar.set_description_str(f"{phase} {int(100 * fraction)}%")
        if extra:
            self.bar.set_postfix(extra, refresh=False)

    def end(self, phase: str, extra: Optional[Dict] = None):
        """Mark `phase` complete and show its elapsed wall time as postfix."""
        self._done += self.weights.get(phase, 0.0) / self._norm
        started = self._phase_t0 if self._phase_t0 is not None else perf_counter()
        postfix = dict(extra or {})
        postfix["t"] = f"{(perf_counter() - started) * 1000:.0f}ms"
        self.bar.set_postfix(postfix, refresh=False)

    def close(self):
        """Top the bar up to `total`, then close it (close even if update fails)."""
        try:
            remaining = self.total - self.bar._last_val
            if remaining > 0:
                self.bar.update(remaining)
        finally:
            self.bar.close()
+
64
def choose_k(N: int, k_min: int = 5, k_max: int = 50) -> int:
    """Heuristic neighbour count: sqrt(N) clamped to [k_min, k_max], capped at N-1."""
    if N <= 1:
        # Degenerate inputs get a single neighbour.
        return 1
    clamped = min(max(int(math.sqrt(N)), k_min), k_max)
    return min(clamped, N - 1)
71
+
72
+ def _ensure_dense32(X) -> np.ndarray:
73
+ """Convert to contiguous float32 ndarray (densify only if needed)."""
74
+ if issparse(X):
75
+ X = X.toarray()
76
+ return np.asarray(X, dtype=np.float32, order="C")
77
+
78
def decide_task_and_model(
    y: np.ndarray,
    series: pd.Series,
    *,
    is_categorical: bool = False,
    few_class_floor: int = 20,
    few_class_frac: float = 0.05,
):
    """Decide the prediction task for target `y`.

    Returns "classification" when the target is explicitly categorical,
    boolean, non-numeric, or numeric with few distinct values (at most
    max(few_class_floor, ceil(few_class_frac * N))); otherwise "regression".
    """
    n_rows = len(y)

    # dtype checks on the pandas Series view of the target
    bool_target = pd.api.types.is_bool_dtype(series)
    numeric_target = pd.api.types.is_numeric_dtype(series)

    # distinct non-null values
    distinct = len(pd.unique(y[~pd.isnull(y)]))

    # numeric column with only a handful of distinct values -> class labels
    class_cap = max(few_class_floor, int(np.ceil(few_class_frac * max(n_rows, 1))))
    looks_like_classes = numeric_target and distinct <= class_cap

    if is_categorical or bool_target or (not numeric_target) or looks_like_classes:
        return "classification"
    return "regression"
111
+
112
+ def _infer_task(y: np.ndarray, task: Optional[str]) -> str:
113
+ """Decide task if not provided: numeric with many uniques -> regression, else classification."""
114
+ if task in {"classification", "regression"}:
115
+ return task
116
+
117
+ if np.issubdtype(y.dtype, np.number):
118
+ nunq = len(np.unique(y[~pd.isna(y)]))
119
+ is_categorical = nunq <= max(2, int(0.02 * max(1, len(y))))
120
+ else:
121
+ is_categorical = True
122
+ return "classification" if is_categorical else "regression"
123
+
124
+
125
+ # --------- DataFrame payload helpers (for tool IO) ---------
126
+
127
def df_to_payload(df: pd.DataFrame) -> Dict[str, Any]:
    """Serialize a DataFrame into a JSON-friendly 'split'-oriented payload."""
    split_form = df.to_dict(orient="split")
    return {"orient": "split", "data": split_form}
129
+
130
+
131
def df_from_payload(p: Dict[str, Any]) -> pd.DataFrame:
    """Rebuild a DataFrame from a 'split'-oriented payload.

    BUGFIX: the original discarded the stored index, making the
    `df_to_payload` -> `df_from_payload` round trip lossy for any
    non-default index. For the default RangeIndex the result is unchanged.
    """
    d = p["data"]
    # `index` may be absent in hand-built payloads; None yields a RangeIndex.
    return pd.DataFrame(d["data"], index=d.get("index"), columns=d["columns"])
134
+
135
+ # --------- Light heuristics for task/label guess ---------
136
+
137
def guess_task_and_label(df: pd.DataFrame) -> Dict[str, Any]:
    """Guess the ML task and label column for `df`.

    The label is the first column whose lowercased name is one of
    {"label", "target", "y", "class", "outcome"}. Task heuristic:
      * integer/boolean label with few distinct values -> classification
      * float label -> regression
      * any other label dtype (strings/categories) -> classification
      * no label column -> unsupervised

    BUGFIX: non-numeric labels previously fell through to "unsupervised"
    even though a label column had been found; string/categorical targets
    are class labels.
    """
    cols = list(df.columns)
    known_names = {"label", "target", "y", "class", "outcome"}
    label_candidates = [c for c in cols if c.lower() in known_names]
    label = label_candidates[0] if label_candidates else None

    if label is None:
        task = "unsupervised"
    elif pd.api.types.is_integer_dtype(df[label]) or pd.api.types.is_bool_dtype(df[label]):
        nuniq = df[label].nunique(dropna=True)
        # few distinct integer values -> class labels; many -> treat as continuous
        task = "classification" if nuniq <= max(20, int(0.05 * len(df))) else "regression"
    elif pd.api.types.is_float_dtype(df[label]):
        task = "regression"
    else:
        # non-numeric label (strings, categories) -> classification
        task = "classification"

    issues = []
    if label and df[label].isna().any():
        issues.append(f"Missing values in label `{label}`")
    if label and df[label].nunique() == 1:
        issues.append(f"Label `{label}` has a single class")

    return {
        "columns": cols,
        "dtypes": {c: str(df[c].dtype) for c in cols},
        "label_guess": label,
        "task_guess": task,
        "issues": issues,
        "shape": df.shape,
    }
168
+
169
+ # --------- Signature extraction for asking params ---------
170
+
171
+
172
def get_signature_dict(fn) -> Dict[str, Any]:
    """Describe `fn`'s parameters (excluding `df`) and docstring.

    Returns ``{"params": [{"name", "default", "annotation", "kind"}, ...],
    "doc": str}``. Missing defaults/annotations are reported as None.

    FIX: uses the public `inspect.Parameter.empty` sentinel instead of the
    private `inspect._empty` attribute.
    """
    sig = inspect.signature(fn)
    doc = (fn.__doc__ or "").strip()
    empty = inspect.Parameter.empty
    params = []
    for p in sig.parameters.values():
        if p.name == "df":
            continue  # the pipeline injects df itself; don't surface it
        default = None if (p.default is empty) else p.default
        annotation = None if (p.annotation is empty) else str(p.annotation)
        params.append({"name": p.name, "default": default, "annotation": annotation, "kind": str(p.kind)})
    return {"params": params, "doc": doc}
183
+
184
+ # --------- Parse free-text confirmation like "Run dedup threshold=0.93 metric=cosine" ---------
185
# Free-text aliases a user may type for each pipeline step.
STEP_ALIASES = {
    "dedup": {"dedup","de-dup","duplicates","near-dup"},
    "featurize": {"featurize","features","featureize","engineering"},
    "find_label_issues": {"find_label_issues","label issues","cleanlab","label noise"},
}


def parse_user_choice(text: str) -> Tuple[Optional[str], Dict[str, Any]]:
    """Parse a free-text command like "Run dedup threshold=0.93 metric=cosine".

    Returns (step_name_or_None, params). `key=value` pairs are coerced to
    int/float/bool where possible; everything else stays a string.

    BUGFIX: the value regex deliberately captures a leading '-', but the old
    digit check rejected it, so negative numbers stayed raw strings. A single
    leading '-' is now stripped before the numeric test.
    """
    lowered = text.lower()
    chosen = None
    for step, aliases in STEP_ALIASES.items():
        if any(alias in lowered for alias in aliases):
            chosen = step
            break

    params: Dict[str, Any] = {}
    for m in re.finditer(r"(\w+)\s*=\s*([\-\w\.]+)", text):
        key, raw = m.group(1), m.group(2)
        magnitude = raw[1:] if raw.startswith("-") else raw
        if magnitude.replace(".", "", 1).isdigit():
            value: Any = float(raw) if "." in raw else int(raw)
        elif raw.lower() in {"true", "false"}:
            value = (raw.lower() == "true")
        else:
            value = raw
        params[key] = value
    return chosen, params
requirements.txt ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.12.15
4
+ aiosignal==1.4.0
5
+ annotated-types==0.7.0
6
+ anyio==4.11.0
7
+ async-timeout==4.0.3
8
+ attrs==25.3.0
9
+ cachetools==6.2.0
10
+ certifi==2025.8.3
11
+ charset-normalizer==3.4.3
12
+ cleanlab==2.7.1
13
+ click==8.1.8
14
+ contourpy==1.3.0
15
+ cycler==0.12.1
16
+ dataclasses-json==0.6.7
17
+ datasets==4.1.1
18
+ dill==0.4.0
19
+ dotenv==0.9.9
20
+ exceptiongroup==1.3.0
21
+ fastapi==0.118.0
22
+ ffmpy==0.6.1
23
+ filelock==3.19.1
24
+ filetype==1.2.0
25
+ fonttools==4.60.1
26
+ frozenlist==1.7.0
27
+ fsspec==2025.9.0
28
+ google-ai-generativelanguage==0.7.0
29
+ google-api-core==2.25.2
30
+ google-auth==2.41.1
31
+ googleapis-common-protos==1.70.0
32
+ gradio==4.44.1
33
+ gradio_client==1.3.0
34
+ grpcio==1.75.1
35
+ grpcio-status==1.75.1
36
+ h11==0.16.0
37
+ hf-xet==1.1.10
38
+ httpcore==1.0.9
39
+ httpx==0.28.1
40
+ httpx-sse==0.4.1
41
+ huggingface-hub==0.35.3
42
+ idna==3.10
43
+ importlib_resources==6.5.2
44
+ Jinja2==3.1.6
45
+ joblib==1.5.2
46
+ jsonpatch==1.33
47
+ jsonpointer==3.0.0
48
+ kiwisolver==1.4.7
49
+ langchain==0.3.27
50
+ langchain-community==0.3.30
51
+ langchain-core==0.3.78
52
+ langchain-google-genai==2.1.12
53
+ langchain-text-splitters==0.3.11
54
+ langgraph==0.6.8
55
+ langgraph-checkpoint==2.1.1
56
+ langgraph-prebuilt==0.6.4
57
+ langgraph-sdk==0.2.9
58
+ langsmith==0.4.32
59
+ markdown-it-py==3.0.0
60
+ MarkupSafe==2.1.5
61
+ marshmallow==3.26.1
62
+ matplotlib==3.9.4
63
+ mdurl==0.1.2
64
+ multidict==6.6.4
65
+ multiprocess==0.70.16
66
+ mypy_extensions==1.1.0
67
+ narwhals==2.6.0
68
+ numpy==1.26.4
69
+ orjson==3.11.3
70
+ ormsgpack==1.10.0
71
+ packaging==25.0
72
+ pandas==2.3.3
73
+ pandoc==2.4
74
+ pillow==10.4.0
75
+ plotly==6.3.1
76
+ plumbum==1.9.0
77
+ ply==3.11
78
+ propcache==0.4.0
79
+ proto-plus==1.26.1
80
+ protobuf==6.32.1
81
+ pyarrow==21.0.0
82
+ pyasn1==0.6.1
83
+ pyasn1_modules==0.4.2
84
+ pydantic==2.10.6
85
+ pydantic-settings==2.11.0
86
+ pydantic_core==2.27.2
87
+ pydub==0.25.1
88
+ Pygments==2.19.2
89
+ pypandoc==1.15
90
+ pyparsing==3.2.5
91
+ python-dateutil==2.9.0.post0
92
+ python-dotenv==1.1.1
93
+ python-multipart==0.0.20
94
+ pytz==2025.2
95
+ PyYAML==6.0.3
96
+ reportlab==4.4.4
97
+ requests==2.32.5
98
+ requests-toolbelt==1.0.0
99
+ rich==14.1.0
100
+ rsa==4.9.1
101
+ ruff==0.13.3
102
+ scikit-learn==1.6.1
103
+ scipy==1.13.1
104
+ semantic-version==2.10.0
105
+ shellingham==1.5.4
106
+ six==1.17.0
107
+ sniffio==1.3.1
108
+ SQLAlchemy==2.0.43
109
+ starlette==0.48.0
110
+ tenacity==9.1.2
111
+ termcolor==3.1.0
112
+ threadpoolctl==3.6.0
113
+ tomlkit==0.12.0
114
+ tqdm==4.67.1
115
+ typer==0.19.2
116
+ typing-inspect==0.9.0
117
+ typing-inspection==0.4.2
118
+ typing_extensions==4.15.0
119
+ tzdata==2025.2
120
+ urllib3==2.5.0
121
+ uvicorn==0.37.0
122
+ websockets==12.0
123
+ xxhash==3.6.0
124
+ yarl==1.20.1
125
+ zipp==3.23.0
126
+ zstandard==0.25.0