|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
from typing import Optional, Tuple, List, Dict, Any |
|
|
import io |
|
|
import base64 |
|
|
from tqdm.auto import tqdm |
|
|
from dataclasses import dataclass |
|
|
import gradio as gr |
|
|
import json |
|
|
|
|
|
from pipeline.deduplication import find_near_duplicates |
|
|
from pipeline.featurizer import custom_featurizer |
|
|
from pipeline.issues import find_issues |
|
|
from pipeline.pipeline import make_step, run_pipeline |
|
|
|
|
|
from ecg_analyzer import ECGAnalyzer |
|
|
from agent.simple_chat import simple_chat |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TaskConfig:
    """Declarative record describing one analysis task.

    The record carries no behaviour of its own; it only groups the metadata
    that the registry hands out for a (data type, task) pair.
    """

    # Display name of the task (the registry uses the same text as its key).
    name: str
    # Data family the task belongs to, e.g. "EHR Data" or "ECG Data".
    data_type: str
    # Whether the task expects user-supplied parameters.
    requires_params: bool
    # One dict per parameter widget; the registry fills "type", "label"
    # and "elem_id" keys.
    param_components: List[Dict[str, Any]]
    # Keys of the result tabs this task contributes to
    # (e.g. "original", "processed", "summary", "visualization").
    output_tabs: List[str]
|
|
|
|
|
class TaskRegistry:
    """Lookup tables that tie each data type to its analysis tasks."""

    @staticmethod
    def get_config(data_type: str, task_name: str) -> Optional[TaskConfig]:
        """Return the configuration of ``task_name`` for ``data_type``.

        Args:
            data_type: Data family, "EHR Data" or "ECG Data".
            task_name: Display name of the task.

        Returns:
            A freshly built ``TaskConfig``, or ``None`` when the pair is not
            registered.
        """

        def widget(kind: str, label: str, elem_id: str) -> Dict[str, Any]:
            # Every parameter widget is described by the same three keys.
            return {"type": kind, "label": label, "elem_id": elem_id}

        # (data type, task name) -> (parameter widgets, output tabs)
        specs = {
            ("EHR Data", "Near-Duplicate Detection"): (
                [widget("dropdown", "Label Column", "ndd_label")],
                ["original", "processed", "summary"],
            ),
            ("EHR Data", "Find Mislabeled Data"): (
                [widget("dropdown", "Label Column", "mislabel_label")],
                ["original", "summary"],
            ),
            ("ECG Data", "ECG Visualization"): (
                [
                    widget("checkboxgroup", "Select Leads", "ecg_leads"),
                    widget("checkboxgroup", "Visualization Types", "ecg_viz_types"),
                ],
                ["visualization", "summary"],
            ),
            ("ECG Data", "Statistical Summary"): (
                [widget("checkboxgroup", "Select Leads", "ecg_stats_leads")],
                ["summary", "visualization"],
            ),
        }

        spec = specs.get((data_type, task_name))
        if spec is None:
            return None

        widgets, tabs = spec
        return TaskConfig(
            name=task_name,
            data_type=data_type,
            requires_params=True,
            param_components=widgets,
            output_tabs=tabs,
        )

    @staticmethod
    def get_tasks_for_data_type(data_type: str) -> List[str]:
        """Return the task names offered for ``data_type`` (empty if unknown)."""
        tasks_by_type = {
            "EHR Data": ["Near-Duplicate Detection", "Find Mislabeled Data"],
            "ECG Data": ["ECG Visualization", "Statistical Summary"],
        }
        try:
            return tasks_by_type[data_type]
        except KeyError:
            return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AnalysisExecutor:
    """Runs the individual analysis tasks and packages their results.

    Every ``execute_*`` method returns ``(status_message, results)`` where
    ``results`` is a dict keyed by output-tab name. Exceptions raised by the
    underlying helpers are caught and reported through the status message
    (prefixed with "✗ Error:") instead of propagating to the caller.
    """

    @staticmethod
    def execute_near_duplicate_detection(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Run the dedup -> featurize -> find-label-issues pipeline on ``df``.

        Args:
            df: Input table.
            label: Name of the label column; required.

        Returns:
            ``(status, results)`` with keys "original", "processed" and
            "summary". "processed"/"summary" are ``None`` when the label is
            missing or the pipeline fails.
        """
        bar = None
        try:
            if not label:
                return "⚠ Label column required", {"original": df, "processed": None, "summary": None}

            bar = tqdm(total=100, leave=False, desc="Pipeline Progress")
            steps = [
                make_step(find_near_duplicates, name="dedup")(progress=bar),
                make_step(custom_featurizer, name="featurize")(
                    label=label, nan_strategy="impute", on_pipeline_error="drop", progress=bar
                ),
                make_step(find_issues, name="find_label_issues")(label=label, progress=bar),
            ]
            results_df, summary_list = run_pipeline(steps, df=df)

            return "✓ Near-duplicate detection completed", {
                "original": df, "processed": results_df, "summary": summary_list
            }
        except Exception as e:
            return f"✗ Error: {str(e)}", {"original": df, "processed": None, "summary": None}
        finally:
            # Close the progress bar on every path, including pipeline
            # failures, so a failed run does not leave it open.
            if bar is not None:
                bar.close()

    @staticmethod
    def execute_find_mislabeled(df: pd.DataFrame, label: str) -> Tuple[str, Dict[str, Any]]:
        """Report on mislabeled data for ``df``.

        NOTE: no detection is actually performed here; the summary always
        reports ``suspicious_samples`` as 0.

        Args:
            df: Input table.
            label: Name of the label column; required.

        Returns:
            ``(status, results)`` with keys "original" and "summary".
        """
        try:
            if not label:
                return "⚠ Label column required", {"original": df, "summary": None}

            summary = {
                "task": "Find Mislabeled Data",
                "label_column": label,
                "total_samples": len(df),
                "suspicious_samples": 0,
                "message": "Mislabeled detection analysis completed",
            }
            return "✓ Mislabeled data analysis completed", {"original": df, "summary": summary}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"original": df, "summary": None}

    @staticmethod
    def execute_ecg_visualization(
        df: pd.DataFrame,
        leads: Optional[List[str]] = None,
        viz_types: Optional[List[str]] = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build ECG visualizations and statistics through ``ECGAnalyzer``.

        Args:
            df: Input table.
            leads: Leads to plot; when empty, the leads detected by
                ``ECGAnalyzer.detect_leads`` are used.
            viz_types: Visualization kinds; defaults to
                ``["Signal Waveform", "Histogram"]`` when empty.

        Returns:
            ``(status, results)`` with keys "visualization" and "summary".
        """
        try:
            available_leads = ECGAnalyzer.detect_leads(df)

            if not leads:
                leads = available_leads or []
            if not leads:
                return "⚠ No ECG leads found in data", {"visualization": None, "summary": None}

            if not viz_types:
                viz_types = ["Signal Waveform", "Histogram"]

            viz_html = ECGAnalyzer.create_all_visualizations(df, leads, viz_types)
            stats = ECGAnalyzer.generate_statistics(df, leads)

            summary = {
                "task": "ECG Visualization",
                "samples": len(df),
                "leads_analyzed": leads,
                "visualizations": viz_types,
                "statistics": stats,
            }
            return "✓ ECG visualization created", {"visualization": viz_html, "summary": summary}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"visualization": None, "summary": None}

    @staticmethod
    def _render_stats_table(stats: Dict[Any, Dict[str, Any]]) -> str:
        """Render per-lead statistics as the HTML table shown in the UI.

        Each value of ``stats`` must provide the keys "mean", "std", "min",
        "q25", "median", "q75" and "max" with values that accept the ``.4f``
        format. Lead names are HTML-escaped because they originate from the
        column names of an uploaded file.
        """
        # Function-scope import: this class is the only user of ``html``.
        import html

        columns = ("mean", "std", "min", "q25", "median", "q75", "max")
        parts = [
            "<table class='preview-table' style='margin: 20px auto; max-width: 900px;'>",
            "<thead><tr><th>Lead</th><th>Mean</th><th>Std</th><th>Min</th><th>Q25</th><th>Median</th><th>Q75</th><th>Max</th></tr></thead>",
            "<tbody>",
        ]
        for lead, lead_stats in stats.items():
            cells = "".join(f"<td>{lead_stats[key]:.4f}</td>" for key in columns)
            parts.append(f"<tr><td><strong>{html.escape(str(lead))}</strong></td>{cells}</tr>")
        parts.append("</tbody></table>")

        table = "".join(parts)
        return f"<div style='overflow-x:auto;'><h3 style='text-align:center;'>Statistical Summary</h3>{table}</div>"

    @staticmethod
    def execute_statistical_summary(
        df: pd.DataFrame, leads: Optional[List[str]] = None
    ) -> Tuple[str, Dict[str, Any]]:
        """Compute per-lead statistics through ``ECGAnalyzer`` and render them.

        Args:
            df: Input table.
            leads: Columns to summarize; when empty, detected ECG leads are
                used, falling back to all numeric columns of ``df``.

        Returns:
            ``(status, results)`` with keys "summary" (dict) and
            "visualization" (HTML table).
        """
        try:
            available_leads = ECGAnalyzer.detect_leads(df)

            if not leads:
                leads = available_leads if available_leads else list(
                    df.select_dtypes(include=[np.number]).columns
                )
            if not leads:
                return "⚠ No numeric columns found", {"summary": None, "visualization": None}

            stats = ECGAnalyzer.generate_statistics(df, leads)
            summary_html = AnalysisExecutor._render_stats_table(stats)

            summary = {
                "task": "Statistical Summary",
                "rows": len(df),
                "leads_analyzed": leads,
                "statistics": stats,
            }
            return "✓ Statistical summary generated", {"summary": summary, "visualization": summary_html}
        except Exception as e:
            return f"✗ Error: {str(e)}", {"summary": None, "visualization": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UIManager:
    """Holds UI state and implements the Gradio event callbacks.

    The callbacks return tuples of values / ``gr.update(...)`` objects whose
    length and order must match the ``outputs`` lists they are wired to.
    """

    # Visualization kinds offered for ECG data.
    _VIZ_TYPES = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]

    def __init__(self):
        # Most recently loaded table (None until a CSV has been read).
        self.current_df = None
        self.current_data_type = "EHR Data"
        # Data shared with the chatbot: file path, data type, table and the
        # latest analysis summary / visualization.
        self.chatbot_context = {}
        self.command_map = {
            "data_type": {"ehr": "EHR Data", "ecg": "ECG Data"},
            "task": {
                "deduplication": "Near-Duplicate Detection",
                "mislabeled": "Find Mislabeled Data",
                "visualize_ecg": "ECG Visualization",
                "stats": "Statistical Summary",
            },
        }

    @staticmethod
    def _file_path(file):
        """Return the path of an upload given as a path or an object with ``.name``."""
        return getattr(file, "name", file)

    @staticmethod
    def _hidden(count: int = 4) -> tuple:
        """Return ``count`` updates that hide their target component."""
        return tuple(gr.update(visible=False) for _ in range(count))

    @staticmethod
    def _ecg_control_updates(df: pd.DataFrame, data_type: str) -> tuple:
        """Return the four updates for the ECG lead / visualization selectors."""
        ecg_leads = ECGAnalyzer.detect_leads(df) if data_type == "ECG Data" else []
        viz_types = list(UIManager._VIZ_TYPES)
        return (
            gr.update(choices=ecg_leads, value=ecg_leads),
            gr.update(choices=viz_types, value=["Signal Waveform", "Histogram"]),
            gr.update(choices=ecg_leads, value=ecg_leads),
            gr.update(choices=viz_types, value=["Signal Waveform"]),
        )

    def load_csv(self, file) -> Tuple[str, Optional[pd.DataFrame]]:
        """Read the uploaded CSV and remember it as the current table.

        Args:
            file: Upload handle exposing ``.name`` (the path), a plain path
                string, or ``None`` when nothing was uploaded.

        Returns:
            ``(status, df)``; ``df`` is ``None`` when there is no file or
            reading fails.
        """
        if file is None:
            return "⚠ No file uploaded", None
        try:
            df = pd.read_csv(self._file_path(file))
        except Exception as e:
            return f"✗ Error: {str(e)}", None
        self.current_df = df
        return f"✓ Loaded {len(df)} rows, {len(df.columns)} columns", df

    def on_file_upload(self, file, data_type: str):
        """Handle a file upload.

        Returns 17 values: status, preview table, task selector, four
        parameter groups, two label dropdowns, four ECG selectors and four
        result tabs.
        """
        status, df = self.load_csv(file)
        if df is None:
            return (
                status, gr.update(value=None), gr.update(choices=[], value=[]),
                *self._hidden(),
                gr.update(choices=[]), gr.update(choices=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                gr.update(choices=[]), gr.update(choices=[], value=[]),
                *self._hidden(),
            )

        # Keep the cached data type in sync; on_tasks_change reads it.
        self.current_data_type = data_type
        self.chatbot_context = {"file": self._file_path(file), "type": data_type, "df": df}
        available_tasks = TaskRegistry.get_tasks_for_data_type(data_type)
        col_choices = list(df.columns)

        return (
            status,
            gr.update(value=df.head(200)),
            gr.update(choices=available_tasks, value=[], interactive=True),
            *self._hidden(),
            gr.update(choices=col_choices, value=None),
            gr.update(choices=col_choices, value=None),
            *self._ecg_control_updates(df, data_type),
            gr.update(visible=True),
            *self._hidden(3),
        )

    def on_data_type_change(self, data_type: str, file):
        """Handle a data-type change.

        Returns 10 values: task selector, four parameter groups, four ECG
        selectors and the status text. The ECG selectors are only refreshed
        when a file is present and a table has been loaded.
        """
        self.current_data_type = data_type
        task_update = gr.update(choices=TaskRegistry.get_tasks_for_data_type(data_type), value=[])

        if file and self.current_df is not None:
            self.chatbot_context["type"] = data_type
            ecg_updates = self._ecg_control_updates(self.current_df, data_type)
        else:
            # Nothing loaded yet: leave the ECG selectors untouched.
            ecg_updates = (gr.update(), gr.update(), gr.update(), gr.update())

        return (
            task_update,
            *self._hidden(),
            *ecg_updates,
            f"Data type changed to: {data_type}",
        )

    def on_tasks_change(self, selected_tasks: List[str]):
        """Show the parameter group of every selected task and hide the rest.

        Returns four visibility updates (near-duplicate, mislabeled, ECG
        visualization, ECG statistics). ``None`` is treated as no selection.
        """
        selected = selected_tasks or []
        show_ecg_stats = "Statistical Summary" in selected and self.current_data_type == "ECG Data"
        return (
            gr.update(visible="Near-Duplicate Detection" in selected),
            gr.update(visible="Find Mislabeled Data" in selected),
            gr.update(visible="ECG Visualization" in selected),
            gr.update(visible=show_ecg_stats),
        )

    def process_analysis(self, file, data_type: str, selected_tasks: List[str],
                         ndd_label: str, mislabel_label: str,
                         ecg_viz_leads: List[str], ecg_viz_types: List[str],
                         ecg_stats_leads: List[str]):
        """Reload the uploaded CSV and run the selected tasks with the UI parameters.

        Returns the same 9-tuple as ``_run_analysis``; when the file cannot
        be loaded, only the status is set and all result tabs are hidden.
        """
        status, df = self.load_csv(file)
        if df is None:
            return (status, None, None, None, None, *self._hidden())

        params = {
            "ndd_label": ndd_label,
            "mislabel_label": mislabel_label,
            "ecg_viz_leads": ecg_viz_leads,
            "ecg_viz_types": ecg_viz_types,
            "ecg_stats_leads": ecg_stats_leads,
        }
        return self._run_analysis(df, data_type, selected_tasks, params)

    def _run_analysis(self, df: pd.DataFrame, data_type: str, selected_tasks: List[str], params: Dict[str, Any]):
        """Run every selected task and merge the results.

        Returns 9 values: status text, original preview, processed table,
        summary list, visualization HTML and four tab-visibility updates
        (original, processed, summary, visualization).
        """
        if not selected_tasks:
            return (
                "⚠ No tasks selected", df.head(200), None, None, None,
                gr.update(visible=True), *self._hidden(3),
            )

        # Task name -> zero-argument callable returning (status, results).
        runners = {
            "Near-Duplicate Detection": lambda: AnalysisExecutor.execute_near_duplicate_detection(
                df, params.get("ndd_label")
            ),
            "Find Mislabeled Data": lambda: AnalysisExecutor.execute_find_mislabeled(
                df, params.get("mislabel_label")
            ),
            "ECG Visualization": lambda: AnalysisExecutor.execute_ecg_visualization(
                df, params.get("ecg_viz_leads"), params.get("ecg_viz_types")
            ),
            # The lead selection is only forwarded for ECG data.
            "Statistical Summary": lambda: AnalysisExecutor.execute_statistical_summary(
                df, params.get("ecg_stats_leads") if data_type == "ECG Data" else None
            ),
        }

        all_tabs = set()
        all_results = {"original": df.head(200), "processed": None, "summary": [], "visualization": ""}
        status_messages = []

        for task_name in selected_tasks:
            config = TaskRegistry.get_config(data_type, task_name)
            if not config:
                status_messages.append(f"✗ Unknown task: {task_name}")
                continue
            all_tabs.update(config.output_tabs)

            runner = runners.get(task_name)
            if runner is None:
                status_msg, results = "✗ Task not implemented", {}
            else:
                status_msg, results = runner()

            status_messages.append(f"{task_name}: {status_msg}")
            if results.get("processed") is not None:
                all_results["processed"] = results["processed"]
            if results.get("visualization"):
                all_results["visualization"] += results["visualization"]
            if results.get("summary") is not None:
                all_results["summary"].append({"task": task_name, "data": results["summary"]})

        summary = all_results["summary"] or None
        visualization = all_results["visualization"] or None
        # Keep the latest results available to the chatbot.
        self.chatbot_context["summary"] = summary
        self.chatbot_context["visualization"] = visualization

        return (
            "\n".join(status_messages), all_results["original"], all_results["processed"],
            summary, visualization,
            gr.update(visible="original" in all_tabs), gr.update(visible="processed" in all_tabs),
            gr.update(visible="summary" in all_tabs), gr.update(visible="visualization" in all_tabs),
        )

    def chatbot_respond(self, message: str, history: List):
        """Answer a chat message, passing the latest analysis summary as context.

        Returns 11 values: the updated history, an empty string that clears
        the input box, and nine no-op updates for the analysis outputs.
        """
        history = history or []

        summaries = self.chatbot_context.get("summary")
        # default=str keeps serialization from failing on values json cannot
        # encode natively (the summary contents come from the analysis helpers).
        summary = json.dumps(summaries[-1], default=str) if summaries else ""

        # The visualization is not forwarded to the chat model.
        visualization = ''

        # Debug trace of the chat inputs.
        separator = "# ============================================================================\n "
        print("history:", history)
        print(separator)
        print("message:", message)
        print(separator)
        print("summary:", summary)
        print(separator)
        print("visualization:", visualization)
        print(separator)

        ui_updates = tuple(gr.update() for _ in range(9))

        context = summary + visualization
        response = simple_chat(message, context)

        history.append((message, response))
        return (history, "") + ui_updates
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_interface():
    """Assemble the Gradio Blocks app and wire its events to a ``UIManager``.

    Layout: a header row (file upload + data type), then three panels: task
    selection with per-task parameters, result tabs, and the chat assistant.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    ui_manager = UIManager()

    custom_css = """
    * { box-sizing: border-box; } html, body { margin: 0; padding: 0; height: 100vh; overflow: hidden; }
    .gradio-container { height: 100vh !important; max-width: 100% !important; padding: 0 !important; }
    #app-container { height: 100vh; display: flex; flex-direction: column; padding: 0.75rem; gap: 0.75rem; }
    #main-row { flex: 1; min-height: 0; display: flex; gap: 0.75rem; }
    #left-panel { display: flex; flex-direction: column; height: 100%; background: #f9fafb; border-radius: 10px; padding: 0.75rem; gap: 0.5rem; }
    #task-section { flex: 1; min-height: 0; overflow-y: auto; display: flex; flex-direction: column; gap: 0.5rem; }
    #middle-panel, #chat-panel { display: flex; flex-direction: column; height: 100%; }
    #tabs-container { flex: 1; min-height: 0; display: flex; flex-direction: column; }
    #tabs-container .tabitem { flex: 1; min-height: 0; overflow: auto; }
    #chat-history { flex: 1; min-height: 0; overflow-y: auto; margin-bottom: 0.5rem; }
    #chat-input-row { flex-shrink: 0; display: flex; gap: 0.5rem; }
    .preview-table { border-collapse: collapse; width: 100%; font-size: 0.875rem; }
    .preview-table th { background-color: #3498db; color: white; padding: 8px; text-align: left; position: sticky; top: 0; }
    """

    initial_tasks = TaskRegistry.get_tasks_for_data_type("EHR Data")
    viz_choices = ["Signal Waveform", "Histogram", "Scatter Plot", "Rolling Average"]
    viz_defaults = ["Signal Waveform", "Histogram"]

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Medical Data Analysis Platform") as demo:
        with gr.Column(elem_id="app-container"):
            gr.Markdown("# 🏥 Medical Data Analysis Platform")

            # Header: upload and data-type selection.
            with gr.Row():
                file_input = gr.File(label="Upload CSV", file_types=[".csv"], scale=2)
                data_type = gr.Dropdown(choices=["EHR Data", "ECG Data"], value="EHR Data", label="Data Type", scale=1)

            with gr.Row(elem_id="main-row"):
                # Left panel: task selection and per-task parameter groups
                # (all groups start hidden).
                with gr.Column(scale=2, elem_id="left-panel"):
                    with gr.Group(elem_id="task-section"):
                        gr.Markdown("#### Analysis Tasks")
                        task_selector = gr.CheckboxGroup(choices=initial_tasks, label=None)

                        with gr.Group(visible=False) as ndd_param_group:
                            gr.Markdown("**Near-Duplicate Detection Parameters**")
                            ndd_label_dropdown = gr.Dropdown(choices=[], label="Label Column")

                        with gr.Group(visible=False) as mislabel_param_group:
                            gr.Markdown("**Find Mislabeled Data Parameters**")
                            mislabel_label_dropdown = gr.Dropdown(choices=[], label="Label Column")

                        with gr.Group(visible=False) as ecg_viz_param_group:
                            gr.Markdown("**ECG Visualization Parameters**")
                            ecg_viz_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])
                            ecg_viz_types = gr.CheckboxGroup(
                                choices=viz_choices,
                                label="Visualization Types",
                                value=viz_defaults,
                            )

                        with gr.Group(visible=False) as ecg_stats_param_group:
                            gr.Markdown("**Statistical Summary Parameters**")
                            ecg_stats_leads = gr.CheckboxGroup(choices=[], label="Select Leads", value=[])

                    process_btn = gr.Button("▶ Process", variant="primary")
                    status_output = gr.Textbox(label="Status", interactive=False, lines=2)

                # Middle panel: result tabs (all start hidden).
                with gr.Column(scale=7, elem_id="middle-panel"):
                    with gr.Tabs(elem_id="tabs-container"):
                        with gr.TabItem("Original Data", visible=False) as tab_original:
                            original_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Processed Data", visible=False) as tab_processed:
                            processed_df_output = gr.DataFrame(interactive=False)
                        with gr.TabItem("Summary", visible=False) as tab_summary:
                            summary_output = gr.JSON()
                        with gr.TabItem("Visualization", visible=False) as tab_viz:
                            viz_output = gr.HTML()

                # Right panel: chat assistant.
                with gr.Column(scale=3, elem_id="chat-panel"):
                    gr.Markdown("### 💬 AI Assistant")
                    chatbot = gr.Chatbot(elem_id="chat-history", height="100%")
                    with gr.Row(elem_id="chat-input-row"):
                        msg_input = gr.Textbox(placeholder="Ask or send a JSON command...", scale=4, container=False)
                        send_btn = gr.Button("Send", scale=1)

        # Reusable output groups; order must match the callbacks' return order.
        param_groups = [ndd_param_group, mislabel_param_group, ecg_viz_param_group, ecg_stats_param_group]
        result_tabs = [tab_original, tab_processed, tab_summary, tab_viz]
        # NOTE(review): ecg_viz_types appears twice here, so the callbacks'
        # fourth ECG update targets the same component as the second one.
        # Looks unintended - confirm before changing, since the callbacks'
        # return arity depends on this list.
        ecg_controls = [ecg_viz_leads, ecg_viz_types, ecg_stats_leads, ecg_viz_types]
        analysis_outputs = [
            status_output, original_df_output, processed_df_output, summary_output, viz_output,
        ] + result_tabs

        file_input.change(
            fn=ui_manager.on_file_upload,
            inputs=[file_input, data_type],
            outputs=(
                [status_output, original_df_output, task_selector]
                + param_groups
                + [ndd_label_dropdown, mislabel_label_dropdown]
                + ecg_controls
                + result_tabs
            ),
        )

        data_type.change(
            fn=ui_manager.on_data_type_change,
            inputs=[data_type, file_input],
            outputs=[task_selector] + param_groups + ecg_controls + [status_output],
        )

        task_selector.change(
            fn=ui_manager.on_tasks_change,
            inputs=[task_selector],
            outputs=param_groups,
        )

        process_btn.click(
            fn=ui_manager.process_analysis,
            inputs=[
                file_input, data_type, task_selector, ndd_label_dropdown, mislabel_label_dropdown,
                ecg_viz_leads, ecg_viz_types, ecg_stats_leads,
            ],
            outputs=analysis_outputs,
        )

        # The send button and pressing Enter in the textbox share one handler.
        for register in (send_btn.click, msg_input.submit):
            register(
                fn=ui_manager.chatbot_respond,
                inputs=[msg_input, chatbot],
                outputs=[chatbot, msg_input] + analysis_outputs,
            )

    return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the app and serve it on port 7890, listening
    # on all interfaces, without creating a public share link.
    launch_options = {"share": False, "server_name": "0.0.0.0", "server_port": 7890}
    demo = create_interface()
    demo.launch(**launch_options)