| """ |
| app.py — AutoML Engineer Agent · Streamlit UI |
| |
| Run: streamlit run app.py |
| """ |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import html |
| import io |
| import json |
| import os |
| import re |
| import sys |
| import tempfile |
| import time |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import pandas as pd |
| import streamlit as st |
|
|
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from agent.report import ( |
| _build_html, |
| _build_markdown, |
| _generate_next_steps, |
| count_embedded_plots_html, |
| ) |
| from config import OUTPUT_DIR |
|
|
|
|
def _init_state() -> None:
    """Seed ``st.session_state`` with the app's default values.

    Only keys that are not yet present are written, so existing session
    values are never overwritten.
    """
    initial_values = {
        "df": None,
        "filename": "",
        "result": None,
        "log_lines": [],
        "step_cards": [],
        "pipeline_track": [],
        "running": False,
        "error": None,
        "theme": "dark",
        "report_export": None,
        "agent": None,
        "saved_model_path": None,
        "inference_bundle": None,
        "inference_predictions": None,
        "demo_dataset": "healthcare",
    }
    absent_keys = [key for key in initial_values if key not in st.session_state]
    for key in absent_keys:
        st.session_state[key] = initial_values[key]
|
|
|
|
# Resolved directory containing this file; used in this module as the base
# for locating demo snapshot files and relative plot/dataset paths.
APP_ROOT = Path(__file__).parent.resolve()
|
|
|
|
def _load_demo_json_file() -> dict | None:
    """Read the demo snapshot JSON for the selected demo dataset.

    The dataset-specific file ``demo_result_<dataset>.json`` is tried first,
    then the generic ``demo_result.json``; each name is looked up in
    ``APP_ROOT`` and then in its parent directory.

    Returns:
        The parsed JSON of the first file found, or ``None`` if none exists.
    """
    dataset = st.session_state.get("demo_dataset", "healthcare")
    file_names = (f"demo_result_{dataset}.json", "demo_result.json")
    search_dirs = (APP_ROOT, APP_ROOT.parent)
    for file_name in file_names:
        for directory in search_dirs:
            candidate = directory / file_name
            if not candidate.is_file():
                continue
            with open(candidate, encoding="utf-8") as handle:
                return json.load(handle)
    return None
|
|
|
|
def _load_demo_result() -> dict | None:
    """Return the pipeline ``result`` payload from the demo snapshot.

    When the snapshot is a dict with a ``"result"`` key, that nested value is
    returned; otherwise the loaded object itself (or ``None`` when no file
    was found) is returned unchanged.
    """
    snapshot = _load_demo_json_file()
    has_wrapper = isinstance(snapshot, dict) and "result" in snapshot
    return snapshot["result"] if has_wrapper else snapshot
|
|
|
|
def _hydrate_comparison_dfs(obj: Any) -> None:
    """Replace list-valued ``comparison_df`` entries with DataFrames, in place.

    Walks nested dicts and lists recursively. Any other value, including a
    ``comparison_df`` entry that is not a list, is descended into rather than
    converted.
    """
    if isinstance(obj, list):
        for element in obj:
            _hydrate_comparison_dfs(element)
        return
    if not isinstance(obj, dict):
        return
    # Iterate over a snapshot because entries are reassigned during the walk.
    for key, value in list(obj.items()):
        if key != "comparison_df" or not isinstance(value, list):
            _hydrate_comparison_dfs(value)
            continue
        obj[key] = pd.DataFrame(value)
|
|
|
|
def _resolve_plot_paths_relative(obj: Any, base: Path) -> None:
    """Rewrite relative ``plot_paths`` entries to absolute paths, in place.

    For every dict found while walking ``obj``, each non-blank, relative
    string value under its ``plot_paths`` dict is joined onto ``base``. The
    entry is replaced only when the resolved path is an existing file.
    """
    if isinstance(obj, list):
        for element in obj:
            _resolve_plot_paths_relative(element, base)
        return
    if not isinstance(obj, dict):
        return

    plot_paths = obj.get("plot_paths")
    if isinstance(plot_paths, dict):
        for plot_key, raw_path in list(plot_paths.items()):
            if not isinstance(raw_path, str) or not raw_path.strip():
                continue
            if Path(raw_path).is_absolute():
                continue
            resolved = (base / raw_path).resolve()
            if resolved.is_file():
                plot_paths[plot_key] = str(resolved)

    for child in obj.values():
        _resolve_plot_paths_relative(child, base)
|
|
|
|
def _apply_demo_payload(data: dict) -> None:
    """Load a demo snapshot into ``st.session_state``.

    Stores the hydrated result and pipeline track, copies the log lines,
    clears the run-related state (step cards, error, report export, agent,
    running flag) and, when the snapshot names a dataset file that exists
    under ``APP_ROOT``, loads it as the active DataFrame. A non-empty
    ``demo_goal`` is written to the ``user_goal_input`` session key.

    Args:
        data: Parsed snapshot JSON. It may wrap the pipeline result under a
            ``"result"`` key or be the bare result dict itself.
    """
    # A snapshot without a "result" wrapper is treated as the result itself,
    # the same shape _load_demo_result accepts, so it does not raise KeyError.
    raw_res = copy.deepcopy(data["result"] if "result" in data else data)
    _hydrate_comparison_dfs(raw_res)
    _resolve_plot_paths_relative(raw_res, APP_ROOT)
    st.session_state.result = raw_res

    track = copy.deepcopy(data.get("pipeline_track", []))
    for s in track:
        d = s.get("data")
        if isinstance(d, dict):
            _hydrate_comparison_dfs(d)
            _resolve_plot_paths_relative(d, APP_ROOT)
    st.session_state.pipeline_track = track

    st.session_state.log_lines = list(data.get("log_lines", []))
    st.session_state.step_cards = []
    st.session_state.error = None
    st.session_state.report_export = None
    st.session_state["agent"] = None
    st.session_state.running = False

    dp = data.get("demo_dataset_path")
    if dp:
        p = APP_ROOT / dp
        # A missing dataset file leaves the current DataFrame untouched.
        if p.is_file():
            st.session_state.df = pd.read_csv(p)
            st.session_state.filename = p.name
    goal = data.get("demo_goal", "")
    if goal:
        st.session_state["user_goal_input"] = goal
|
|
|
|
def _on_demo_mode_change() -> None:
    """React to the ``demo_mode_toggle`` session value changing.

    Toggle off: clear the demo flags and reset result, log, pipeline track,
    error, report export and agent. Toggle on: load the demo snapshot and
    apply it, or set ``_demo_data_missing`` when no snapshot is available.
    """
    state = st.session_state

    if not state.get("demo_mode_toggle"):
        state.pop("_demo_snapshot_error", None)
        state.pop("_demo_data_missing", None)
        state.result = None
        state.log_lines = []
        state.pipeline_track = []
        state.error = None
        state.report_export = None
        state["agent"] = None
        return

    snapshot = _load_demo_json_file()
    if not snapshot:
        state["_demo_data_missing"] = True
        return

    state["_demo_data_missing"] = False
    state.pop("_demo_snapshot_error", None)
    _apply_demo_payload(snapshot)
|
|
|
|
def _pal() -> dict[str, str]:
    """Return the colour palette for inline HTML, keyed by role.

    The light palette is returned when the session ``theme`` is ``"light"``;
    any other value (including a missing key) yields the dark palette.
    """
    if st.session_state.get("theme", "dark") != "light":
        return dict(
            text="#e2e0d8",
            muted="#888780",
            muted2="#444441",
            accent="#7c3aed",
            accent_soft="#c084fc",
            blue="#60a5fa",
            green="#4ade80",
            red="#fb7185",
            amber="#fbbf24",
            border="#2a2a2e",
            card_bg="#141416",
            pre_bg="#0a0a0c",
            section="#444441",
            summary_bg="#1a1a22",
            summary_border="#2f2f38",
            empty_icon="#2a2a2e",
            empty_sub="#444441",
            empty_body="#333330",
        )
    return dict(
        text="#111111",
        muted="#525252",
        muted2="#737373",
        accent="#6d28d9",
        accent_soft="#7c3aed",
        blue="#2563eb",
        green="#15803d",
        red="#dc2626",
        amber="#b45309",
        border="#d4d4d4",
        card_bg="#ffffff",
        pre_bg="#f5f5f5",
        section="#737373",
        summary_bg="#f4f4f5",
        summary_border="#e4e4e7",
        empty_icon="#d4d4d4",
        empty_sub="#6b7280",
        empty_body="#374151",
    )
|
|
|
|
| |
# Page-level settings; this is the first Streamlit command the script issues.
st.set_page_config(
    page_title="AutoML Engineer",
    page_icon="⚡",
    layout="wide",
    initial_sidebar_state="expanded",
)


# Seed session defaults, then derive the theme name from the light-mode
# toggle. The toggle value is initialised from the stored theme when absent.
_init_state()
if "theme_light_toggle" not in st.session_state:
    st.session_state.theme_light_toggle = st.session_state.get("theme", "dark") == "light"
st.session_state.theme = "light" if st.session_state.theme_light_toggle else "dark"
|
|
| |
# Inject the base stylesheet (dark palette): fonts, layout, and styling for
# the custom HTML classes and Streamlit widgets used by this app.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&family=DM+Sans:wght@300;400;500&display=swap');

html, body, [class*="css"] {
font-family: 'DM Sans', sans-serif;
}

/* Dark theme overrides */
.stApp {
background-color: #0e0e10;
color: #e2e0d8;
}

/* Sidebar */
[data-testid="stSidebar"] {
background-color: #141416;
border-right: 1px solid #2a2a2e;
}

/* Header strip */
.agent-header {
display: flex;
align-items: baseline;
gap: 12px;
padding: 8px 0 20px 0;
border-bottom: 1px solid #2a2a2e;
margin-bottom: 24px;
}
.agent-title {
font-family: 'JetBrains Mono', monospace;
font-size: 22px;
font-weight: 600;
color: #e2e0d8;
letter-spacing: -0.5px;
}
.agent-badge {
font-family: 'JetBrains Mono', monospace;
font-size: 11px;
background: #1e3a2a;
color: #4ade80;
padding: 2px 8px;
border-radius: 4px;
border: 1px solid #2d5e3e;
}

/* Metric cards */
.metric-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(160px, 1fr));
gap: 12px;
margin: 16px 0;
}
.metric-card {
background: #141416;
border: 1px solid #2a2a2e;
border-radius: 8px;
padding: 16px;
font-family: 'JetBrains Mono', monospace;
}
.metric-label {
font-size: 10px;
color: #888780;
text-transform: uppercase;
letter-spacing: 1px;
margin-bottom: 6px;
}
.metric-value {
font-size: 26px;
font-weight: 600;
color: #7c3aed;
}
.metric-value.green { color: #4ade80; }
.metric-value.amber { color: #fbbf24; }
.metric-value.coral { color: #fb7185; }

/* Activity log */
.log-container {
background: #0a0a0c;
border: 1px solid #2a2a2e;
border-radius: 8px;
padding: 16px;
font-family: 'JetBrains Mono', monospace;
font-size: 12px;
max-height: 420px;
overflow-y: auto;
}
.log-line { padding: 3px 0; line-height: 1.6; }
.log-text { color: #a8a49c; }
.log-tool-run { color: #60a5fa; }
.log-tool-done { color: #4ade80; }
.log-error { color: #fb7185; }
.log-ts { color: #444441; margin-right: 8px; }

/* Tool chip */
.tool-chip {
display: inline-block;
background: #1a1a1e;
border: 1px solid #3a3a3e;
border-radius: 4px;
padding: 1px 6px;
font-size: 11px;
font-family: 'JetBrains Mono', monospace;
color: #c084fc;
margin: 0 2px;
}

/* Section heading */
.section-head {
font-family: 'JetBrains Mono', monospace;
font-size: 11px;
text-transform: uppercase;
letter-spacing: 2px;
color: #444441;
margin: 28px 0 12px 0;
padding-bottom: 8px;
border-bottom: 1px solid #2a2a2e;
}

/* Feature importance bar */
.fi-bar-bg {
background: #1a1a1e;
border-radius: 3px;
height: 8px;
margin-top: 4px;
}
.fi-bar-fill {
height: 8px;
border-radius: 3px;
background: #7c3aed;
}

/* Dataframe override */
[data-testid="stDataFrame"] {
background: #141416;
}

/* Buttons — keep label horizontal in narrow columns */
.stButton > button {
background: #1e1e22;
color: #e2e0d8;
border: 1px solid #3a3a3e;
border-radius: 6px;
font-family: 'JetBrains Mono', monospace;
font-size: 13px;
padding: 8px 20px;
transition: all 0.15s;
white-space: nowrap;
}
.stButton > button:hover {
background: #7c3aed;
border-color: #7c3aed;
color: #fff;
}
.stButton > button[kind="primary"] {
background: #7c3aed;
border-color: #7c3aed;
color: #fff;
}
.stButton > button[kind="primary"]:hover {
background: #6d28d9;
border-color: #6d28d9;
}

/* File uploader */
[data-testid="stFileUploader"] {
background: #141416;
border: 1px dashed #3a3a3e;
border-radius: 8px;
}

.stTabs [data-baseweb="tab-list"] {
background-color: #141416 !important;
border-bottom: 2px solid #3a3a3e !important;
padding: 4px 0 0 0 !important;
margin-bottom: 16px !important;
}

.stTabs [data-baseweb="tab"] {
color: #aaa8a0 !important;
font-size: 14px !important;
font-weight: 500 !important;
padding: 10px 28px !important;
background-color: transparent !important;
border-bottom: 3px solid transparent !important;
margin-bottom: -2px !important;
}

.stTabs [aria-selected="true"] {
color: #ffffff !important;
border-bottom: 3px solid #7c3aed !important;
background-color: transparent !important;
}

.stTabs [data-baseweb="tab"]:hover {
color: #ffffff !important;
background-color: #1e1e22 !important;
border-radius: 4px 4px 0 0 !important;
}

/* Input */
.stTextArea textarea, .stTextInput input, .stSelectbox select {
background: #141416 !important;
border-color: #2a2a2e !important;
color: #e2e0d8 !important;
font-family: 'DM Sans', sans-serif !important;
}

/* Step cards */
.step-card {
background: #141416;
border: 1px solid #2a2a2e;
border-radius: 8px;
padding: 20px;
margin-bottom: 20px;
font-family: 'DM Sans', sans-serif;
}
.step-card-failed {
background: rgba(251, 113, 133, 0.08);
border-color: #fb7185;
}
.step-card-waiting {
opacity: 0.65;
}
.step-card table {
border-collapse: collapse;
margin: 12px 0;
font-size: 13px;
}
.step-card th, .step-card td {
padding: 8px 12px;
border: 1px solid #2a2a2e;
text-align: left;
}
.step-card th { background: #1a1a1e; color: #888780; }
.step-card td { color: #e2e0d8; }
.step-card-header {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 16px;
padding-bottom: 12px;
border-bottom: 1px solid #2a2a2e;
}
.step-badge {
font-family: 'JetBrains Mono', monospace;
font-size: 10px;
padding: 3px 8px;
border-radius: 4px;
font-weight: 600;
}
.step-badge-waiting { background: #2a2a2e; color: #888780; border: 1px solid #3a3a3e; }
.step-badge-running { background: #1e3a5f; color: #60a5fa; border: 1px solid #2d5e8e; }
.step-badge-done { background: #1e3a2a; color: #4ade80; border: 1px solid #2d5e3e; }
.step-badge-failed { background: #4a1e1e; color: #fb7185; border: 1px solid #6e2d2d; }
.warning-banner {
padding: 12px 16px;
border-radius: 6px;
margin: 12px 0;
font-family: 'DM Sans', sans-serif;
font-size: 13px;
}
.warning-banner-amber { background: rgba(251, 191, 36, 0.15); border: 1px solid #fbbf24; color: #fbbf24; }
.warning-banner-red { background: rgba(251, 113, 133, 0.15); border: 1px solid #fb7185; color: #fb7185; }

/* Hide menu/footer only — do NOT hide the top header or the sidebar toggle becomes inaccessible */
#MainMenu { visibility: hidden; }
footer { visibility: hidden; }
.block-container { padding-top: 24px; padding-bottom: 40px; }
</style>
""", unsafe_allow_html=True)
|
|
# When the session theme is "light", inject a second stylesheet whose
# !important rules recolour the same selectors with the light palette.
if st.session_state.get("theme", "dark") == "light":
    st.markdown(
        """
<style>
/* Light theme — professional, high contrast */
.stApp { background-color: #ffffff !important; color: #111111 !important; }
html, body, [class*="css"] { color: #111111 !important; }
[data-testid="stSidebar"] { background-color: #fafafa !important; border-right: 1px solid #d4d4d4 !important; }
[data-testid="stSidebar"] * { color: #111111 !important; }
.agent-header { border-bottom-color: #d4d4d4 !important; }
.agent-title { color: #111111 !important; }
.agent-badge { background: #ecfdf5 !important; color: #15803d !important; border-color: #bbf7d0 !important; }
.metric-card { background: #ffffff !important; border: 1px solid #d4d4d4 !important; }
.metric-label { color: #525252 !important; }
.metric-value { color: #6d28d9 !important; }
.metric-value.green { color: #15803d !important; }
.metric-value.amber { color: #b45309 !important; }
.metric-value.coral { color: #dc2626 !important; }
.log-container { background: #f5f5f5 !important; border: 1px solid #d4d4d4 !important; }
.log-text { color: #525252 !important; }
.log-tool-run { color: #2563eb !important; }
.log-tool-done { color: #15803d !important; }
.log-error { color: #dc2626 !important; }
.log-ts { color: #a3a3a3 !important; }
.tool-chip { background: #f4f4f5 !important; border-color: #d4d4d4 !important; color: #6d28d9 !important; }
.section-head { color: #737373 !important; border-bottom-color: #d4d4d4 !important; }
.fi-bar-bg { background: #e5e5e5 !important; }
.fi-bar-fill { background: #6d28d9 !important; }
[data-testid="stDataFrame"] { background: #ffffff !important; }
.stButton > button { background: #f4f4f5 !important; color: #111111 !important; border: 1px solid #d4d4d4 !important; white-space: nowrap !important; }
.stButton > button:hover { background: #6d28d9 !important; border-color: #6d28d9 !important; color: #fff !important; }
.stButton > button[kind="primary"] { background: #6d28d9 !important; border-color: #6d28d9 !important; color: #fff !important; }
[data-testid="stFileUploader"] { background: #ffffff !important; border: 1px dashed #d4d4d4 !important; }
.stTabs [data-baseweb="tab-list"] { background: #fafafa !important; border-bottom: 1px solid #d4d4d4 !important; }
.stTabs [data-baseweb="tab"] { color: #525252 !important; }
.stTabs [aria-selected="true"] { color: #6d28d9 !important; border-bottom-color: #6d28d9 !important; }
.stTextArea textarea, .stTextInput input, .stSelectbox select {
background: #ffffff !important; border-color: #d4d4d4 !important; color: #111111 !important;
}
.step-card { background: #ffffff !important; border: 1px solid #d4d4d4 !important; color: #111111 !important; }
.step-card-failed { background: rgba(220, 38, 38, 0.06) !important; border-color: #f87171 !important; }
.step-card th { background: #f4f4f5 !important; color: #525252 !important; }
.step-card td { color: #111111 !important; border-color: #d4d4d4 !important; }
.step-card-header { border-bottom-color: #d4d4d4 !important; }
.step-badge-waiting { background: #e5e5e5 !important; color: #525252 !important; border-color: #d4d4d4 !important; }
.step-badge-running { background: #dbeafe !important; color: #1d4ed8 !important; border-color: #93c5fd !important; }
.step-badge-done { background: #dcfce7 !important; color: #15803d !important; border-color: #86efac !important; }
.step-badge-failed { background: #fee2e2 !important; color: #dc2626 !important; border-color: #fca5a5 !important; }
.warning-banner-amber { background: rgba(180, 83, 9, 0.1) !important; border-color: #d97706 !important; color: #b45309 !important; }
.warning-banner-red { background: rgba(220, 38, 38, 0.08) !important; border-color: #f87171 !important; color: #dc2626 !important; }
[data-testid="stMetric"] { background: transparent !important; }
[data-testid="stMetricValue"] { color: #111111 !important; }
[data-testid="stMetricLabel"] { color: #525252 !important; }
[data-testid="stExpander"] { background: #fafafa !important; border: 1px solid #d4d4d4 !important; }
.stAlert { color: #111111 !important; }
[data-testid="stToggle"] label { color: #111111 !important; }
[data-testid="stToggle"] label p { color: #111111 !important; }
</style>
""",
        unsafe_allow_html=True,
    )
|
|
# Maps each pipeline step name to its (1-based step number, display label).
STEP_NAMES: dict[str, tuple[int, str]] = {
    step_name: (position, label)
    for position, (step_name, label) in enumerate(
        (
            ("run_eda", "EDA"),
            ("detect_task", "Task detection"),
            ("preprocess", "Preprocessing"),
            ("plan_training", "Training Plan"),
            ("train_models", "Model training"),
            ("tune_model", "Hyperparameter tuning"),
            ("evaluate_model", "Evaluation"),
            ("final", "Final recommendation"),
        ),
        start=1,
    )
}
|
|
|
|
def _new_pipeline_track() -> list[dict]:
    """Build a fresh eight-step pipeline track with every step waiting.

    Each entry holds ``step`` (1-based), ``name``, ``label``, ``status``
    (``"waiting"``), ``data`` and ``error`` (both ``None``). A new list of
    new dicts is returned on every call.
    """
    ordered_steps = (
        ("run_eda", "EDA"),
        ("detect_task", "Task detection"),
        ("preprocess", "Preprocessing"),
        ("plan_training", "Training Plan"),
        ("train_models", "Model training"),
        ("tune_model", "Hyperparameter tuning"),
        ("evaluate_model", "Evaluation"),
        ("final", "Final recommendation"),
    )
    return [
        {
            "step": number,
            "name": step_name,
            "label": label,
            "status": "waiting",
            "data": None,
            "error": None,
        }
        for number, (step_name, label) in enumerate(ordered_steps, start=1)
    ]
|
|
|
|
def _pipeline_track_update_running(name: str, tune_model_name: str | None = None):
    """Mark the first tracked step called ``name`` as running.

    ``running_detail`` is set to ``tune_model_name`` only for the
    ``tune_model`` step when a non-empty name is given; otherwise it is set
    to ``None``. Nothing happens when no step matches.
    """
    track = st.session_state.pipeline_track
    entry = next((item for item in track if item["name"] == name), None)
    if entry is None:
        return
    entry["status"] = "running"
    wants_detail = name == "tune_model" and bool(tune_model_name)
    entry["running_detail"] = tune_model_name if wants_detail else None
|
|
|
|
def _pipeline_track_update_done(name: str, step_data: dict | None):
    """Mark the first tracked step called ``name`` as done and attach its data.

    Also drops any ``running_detail`` left over from the running state.
    Nothing happens when no step matches.
    """
    track = st.session_state.pipeline_track
    entry = next((item for item in track if item["name"] == name), None)
    if entry is None:
        return
    entry["status"] = "done"
    entry["data"] = step_data
    entry.pop("running_detail", None)
|
|
|
|
def _pipeline_track_fail_running(err: str):
    """Flag the last step that is currently running as failed.

    The track is scanned from the end; the first running step found gets
    status ``"failed"`` and ``err`` as its error. No-op when none is running.
    """
    backwards = reversed(st.session_state.pipeline_track)
    entry = next((item for item in backwards if item["status"] == "running"), None)
    if entry is not None:
        entry["status"] = "failed"
        entry["error"] = err
|
|
|
|
def _pipeline_track_finalize(result: dict):
    """Mark the ``final`` step as done and store the pipeline result on it."""
    track = st.session_state.pipeline_track
    final_entry = next((item for item in track if item["name"] == "final"), None)
    if final_entry is not None:
        final_entry["status"] = "done"
        final_entry["data"] = result
|
|
|
|
def _load_csv_from_upload(uploaded_file) -> None:
    """Load an uploaded CSV into session state and reset the previous run.

    A file whose name equals the stored ``filename`` is treated as already
    loaded and skipped. Any exception raised while reading is reported with
    ``st.error`` rather than propagated.
    """
    if uploaded_file is None:
        return
    try:
        same_file = uploaded_file.name == st.session_state.get("filename", "")
        if same_file:
            return
        st.session_state.df = pd.read_csv(uploaded_file)
        st.session_state.filename = uploaded_file.name
        # A new dataset invalidates everything produced by the previous run.
        st.session_state.result = None
        st.session_state.log_lines = []
        st.session_state.step_cards = []
        st.session_state.pipeline_track = []
        st.session_state.error = None
    except Exception as exc:
        st.error(f"Could not read CSV: {exc}")
|
|
|
|
| |
def _ts() -> str:
    """Return the current local time formatted as ``HH:MM:SS``."""
    now = time.localtime()
    return time.strftime("%H:%M:%S", now)
|
|
def _log(html: str):
    """Append one pre-rendered HTML line to the session's activity log."""
    log_lines = st.session_state.log_lines
    log_lines.append(html)
|
|
def _log_text(msg: str):
    """Log a timestamped plain message line; ``msg`` is inserted as given."""
    stamp = f'<span class="log-ts">{_ts()}</span>'
    _log(f'<div class="log-line log-text">{stamp}{msg}</div>')
|
|
def _log_tool(name: str, status: str, output: str = ""):
    """Log a tool event line to the activity log.

    Args:
        name: Tool name shown in the chip.
        status: ``"running"`` logs a start line; any other value logs a
            completion line.
        output: Tool output; only its first 120 characters are shown, with
            newlines replaced by spaces.

    The name and the output excerpt are HTML-escaped because the log is
    rendered as raw HTML, so characters such as ``<`` or ``&`` in tool output
    would otherwise corrupt or inject markup.
    """
    stamp = f'<span class="log-ts">{_ts()}</span>'
    chip = f'<span class="tool-chip">{html.escape(str(name))}</span>'
    if status == "running":
        _log(f'<div class="log-line log-tool-run">{stamp}▶ running {chip}...</div>')
        return
    text = "" if output is None else str(output)
    # Truncate before escaping so an entity is never cut in half.
    short = text[:120].replace("\n", " ") + ("…" if len(text) > 120 else "")
    _log(f'<div class="log-line log-tool-done">{stamp}✓ {chip} — {html.escape(short)}</div>')
|
|
def _log_error(msg: str):
    """Log a timestamped error line to the activity log.

    The message is HTML-escaped because the log is rendered as raw HTML;
    error text such as ``<class 'int'>`` would otherwise be swallowed as a
    tag.
    """
    stamp = f'<span class="log-ts">{_ts()}</span>'
    _log(f'<div class="log-line log-error">{stamp}✗ {html.escape(str(msg))}</div>')
|
|
def _metric_card(label: str, value: str, cls: str = "") -> str:
    """Return the HTML for one metric card.

    ``cls`` is appended to the value element's class list; ``label`` and
    ``value`` are inserted as given.
    """
    label_html = f'<div class="metric-label">{label}</div>'
    value_html = f'<div class="metric-value {cls}">{value}</div>'
    return "".join(('<div class="metric-card">', label_html, value_html, "</div>"))
|
|
def _color_for_metric(name: str, val: float) -> str:
    """Pick the CSS colour class for a metric value.

    Metrics whose lower-cased name contains ``r2``, ``auc`` or ``accuracy``
    are graded by value (>= 0.8 green, >= 0.6 amber, otherwise coral). Names
    containing ``rmse``, ``mae`` or ``mape`` are always amber. Anything else
    gets no class.
    """
    key = name.lower()
    if any(tag in key for tag in ("r2", "auc", "accuracy")):
        if val >= 0.8:
            return "green"
        if val >= 0.6:
            return "amber"
        return "coral"
    if any(tag in key for tag in ("rmse", "mae", "mape")):
        return "amber"
    return ""
|
|
|
|
def _running_step_ui(name: str, running_detail: str | None) -> tuple[str, str]:
    """Return ``(spinner label, HTML paragraph)`` for a running pipeline step.

    ``running_detail`` is only used for ``tune_model``, where it names the
    model being tuned (falling back to "the best model"). Unknown step names
    get a generic message.
    """
    palette = _pal()
    p_style = (
        f"color:{palette['blue']};font-family:'JetBrains Mono',monospace;font-size:13px;line-height:1.5;"
    )

    def paragraph(text: str) -> str:
        return f'<p style="{p_style}">{text}</p>'

    if name == "tune_model":
        model = running_detail or "the best model"
        return (
            f"Hyperparameter tuning (Optuna): {model}…",
            paragraph(
                f"Hyperparameter tuning · Optuna is optimizing <strong>{model}</strong> "
                "(Bayesian search; multiple trials — may take a minute)…"
            ),
        )

    fixed_messages = {
        "run_eda": (
            "Running exploratory data analysis…",
            "EDA: profiling columns, dtypes, missing values, and distributions…",
        ),
        "detect_task": (
            "Detecting task and target column…",
            "Task detection: inferring classification vs regression and the target column…",
        ),
        "preprocess": (
            "Building preprocessing pipeline…",
            "Preprocessing: encoding, scaling, and train/test split…",
        ),
        "plan_training": (
            "Building training plan…",
            "Training plan: sizing models, metrics, and Optuna budget to your dataset…",
        ),
        "train_models": (
            "Training and comparing models…",
            "Training: fitting and comparing multiple models on the training set…",
        ),
        "evaluate_model": (
            "Evaluating model and generating plots…",
            "Evaluation: test-set metrics, confusion matrix / curves, and SHAP…",
        ),
    }
    spinner_label, body_text = fixed_messages.get(name, ("Running pipeline step…", "Executing…"))
    return spinner_label, paragraph(body_text)
|
|
|
|
def _step_header_html(step: int, label: str, status: str) -> str:
    """Return the header HTML for a step card: number, label and status badge.

    ``waiting``, ``running`` and ``done`` each have their own badge; any
    other status is shown with the failed badge.
    """
    palette = _pal()
    badge_by_status = {
        "waiting": ("step-badge-waiting", "Waiting"),
        "running": ("step-badge-running", "Running…"),
        "done": ("step-badge-done", "✓ Done"),
    }
    badge_cls, badge_text = badge_by_status.get(status, ("step-badge-failed", "✗ Failed"))

    mono = "font-family:'JetBrains Mono',monospace"
    step_span = f'<span style="{mono};font-size:14px;color:{palette["muted"]};">Step {step}</span>'
    label_span = (
        f'<span style="{mono};font-size:16px;font-weight:600;color:{palette["text"]};">{label}</span>'
    )
    badge_span = f'<span class="step-badge {badge_cls}">{badge_text}</span>'
    return f'<div class="step-card-header">{step_span}{label_span}{badge_span}</div>'
|
|
|
|
def _render_pipeline_step(s: dict):
    """Render one ``pipeline_track`` entry (waiting | running | done | failed).

    Args:
        s: Track entry with ``step``, ``name``, ``label`` and ``status``;
            ``data``, ``error`` and ``running_detail`` are optional.

    Rendering by status:
        * waiting: dimmed card with a placeholder line.
        * running: header card plus a spinner with a step-specific message.
        * failed with an error: red card with the first 4000 characters of
          the error in a ``<pre>``, HTML-escaped so traceback fragments such
          as ``<module>`` stay visible.
        * done ``plan_training`` without a plan: header plus an amber notice.
          This is checked before the generic done branch, which would
          otherwise render any non-None data and hide the notice.
        * done with data: delegated to ``_render_step_content``.
        * anything else: header card only.
    """
    step, name, label, status = s["step"], s["name"], s["label"], s["status"]
    data, error = s.get("data"), s.get("error")
    header = _step_header_html(step, label, status)
    card_cls = "step-card step-card-failed" if status == "failed" else "step-card"

    if status == "waiting":
        pm = _pal()["muted"]
        st.markdown(
            f'<div class="step-card step-card-waiting">{header}'
            f'<p style="color:{pm};font-family:\'DM Sans\',sans-serif;">Waiting for previous steps…</p></div>',
            unsafe_allow_html=True,
        )
        return

    if status == "running":
        spin_msg, run_html = _running_step_ui(name, s.get("running_detail"))
        st.markdown(f'<div class="{card_cls}">{header}</div>', unsafe_allow_html=True)
        with st.spinner(spin_msg):
            st.markdown(run_html, unsafe_allow_html=True)
        return

    if status == "failed" and error:
        pr = _pal()["red"]
        # Truncate first, then escape, so an entity is never cut in half.
        error_text = html.escape(str(error)[:4000])
        st.markdown(
            f'<div class="{card_cls}">{header}'
            f'<pre style="color:{pr};font-family:\'JetBrains Mono\',monospace;font-size:12px;white-space:pre-wrap;">'
            f'{error_text}</pre></div>',
            unsafe_allow_html=True,
        )
        return

    if status == "done" and name == "plan_training":
        has_plan = isinstance(data, dict) and bool(data.get("plan"))
        if not has_plan:
            st.markdown(f'<div class="{card_cls}">{header}</div>', unsafe_allow_html=True)
            st.markdown(
                f'<div class="step-card-body">'
                f'<p style="font-family:\'DM Sans\',sans-serif;color:{_pal()["amber"]};">'
                "Training plan could not be generated (see activity log). "
                "Training will continue with default model set and settings."
                "</p></div>",
                unsafe_allow_html=True,
            )
            return

    if status == "done" and data is not None:
        _render_step_content(step, name, data, header, card_cls)
        return

    st.markdown(f'<div class="{card_cls}">{header}</div>', unsafe_allow_html=True)
|
|
|
|
def _render_step_content(step: int, name: str, data: dict, header: str, card_cls: str = "step-card"):
    """Render a finished step: the header card, then the step-specific body.

    The body renderer is chosen by ``name`` and runs only when the expected
    payload is present in ``data`` (the ``plan`` payload must also be
    truthy). The ``final`` step always renders, receiving ``data`` as is.
    ``step`` is accepted but not used here.
    """
    st.markdown(f'<div class="{card_cls}">{header}</div>', unsafe_allow_html=True)

    if name == "final":
        _render_step_6_final(data)
    elif name == "plan_training":
        plan = data.get("plan")
        if plan:
            _render_step_4_plan_training(plan)
    elif name == "run_eda":
        if "eda" in data:
            _render_step_1_eda(data["eda"])
    elif name == "detect_task":
        if "task" in data:
            _render_step_2_task(data["task"])
    elif name == "preprocess":
        if "prep" in data:
            _render_step_3_prep(data["prep"])
    elif name == "train_models":
        if "train" in data:
            _render_step_4_train(data["train"])
    elif name == "tune_model":
        if "tune" in data:
            _render_step_4b_tune(data["tune"])
    elif name == "evaluate_model":
        if "eval" in data:
            _render_step_5_eval(data["eval"])
|
|
|
|
def _render_step_1_eda(eda: dict):
    """Render the EDA step body and its data-quality warnings.

    Shows dataset shape, numeric/categorical columns, duplicate rows, the
    missing-value table, class distribution (classification targets only),
    strongly skewed columns, quality flags and preprocessing recommendations.
    It then emits Streamlit warnings for small datasets (< 500 rows), severe
    class imbalance (ratio > 5) and columns with more than 30% missing.

    Dataset-derived text (column names, class labels, flags,
    recommendations) is escaped with ``_esc`` because the markup is rendered
    with ``unsafe_allow_html=True``.

    Args:
        eda: EDA result dict; every section is optional.
    """
    overview = eda.get("overview", {})
    missing_by_col = eda.get("missing", {}).get("by_column") or {}
    columns = eda.get("columns", {})
    flags = eda.get("quality_flags", [])
    recs = eda.get("recommendations", [])
    target_info = eda.get("target_info")
    is_classification = bool(target_info) and target_info.get("inferred_task") == "classification"

    n_rows = overview.get("rows", 0)
    n_cols = overview.get("columns", 0)
    pal = _pal()

    def names_in_group(group: str) -> str:
        # Escaping also stringifies, so non-string column labels join safely.
        names = [_esc(col) for col, prof in columns.items() if prof.get("dtype_group") == group]
        return ", ".join(names) if names else "none"

    def bullet_list(title: str, items: list) -> str:
        # `items` must already be escaped by the caller.
        body = "".join(f"<li>{item}</li>" for item in items)
        return f"<p><strong>{title}</strong></p><ul>{body}</ul>"

    parts = [
        f'<p style="font-family:\'JetBrains Mono\',monospace;color:{pal["text"]};">'
        f"Dataset shape: {n_rows:,} rows × {n_cols} columns</p>",
        f"<p><strong>Numeric columns:</strong> {names_in_group('numeric')}</p>",
        f"<p><strong>Categorical columns:</strong> {names_in_group('categorical')}</p>",
        f"<p><strong>Duplicate rows:</strong> {_esc(overview.get('duplicate_rows', 0))}</p>",
    ]

    if missing_by_col:
        rows = []
        for col, info in missing_by_col.items():
            pct = info.get("pct", 0)
            if pct > 30:
                row_style = f"color:{pal['red']}"
            elif pct > 10:
                row_style = f"color:{pal['amber']}"
            else:
                row_style = ""
            rows.append(
                f"<tr><td>{_esc(col)}</td><td>{_esc(info.get('count', 0))}</td>"
                f"<td style='{row_style}'>{pct:.1f}%</td></tr>"
            )
        parts.append(
            "<table><thead><tr><th>Column</th><th>Missing count</th><th>Missing %</th></tr></thead><tbody>"
            + "".join(rows)
            + "</tbody></table>"
        )

    if is_classification:
        dist = target_info.get("class_distribution", {})
        class_rows = "".join(
            f"<tr><td>{_esc(cls)}</td><td>{_esc(cnt)}</td></tr>" for cls, cnt in dist.items()
        )
        parts.append(
            "<p><strong>Class distribution</strong></p>"
            "<table><thead><tr><th>Class</th><th>Count</th></tr></thead><tbody>"
            + class_rows
            + "</tbody></table>"
        )
        if "imbalance_ratio" in target_info:
            parts.append(
                f"<p><strong>Imbalance ratio:</strong> {_esc(target_info['imbalance_ratio'])}:1</p>"
            )

    # Convert skewness to float once so the filter and the formatting agree
    # even when the value arrives as a numeric string.
    skewed = []
    for col, prof in columns.items():
        if prof.get("dtype_group") != "numeric" or prof.get("skewness") is None:
            continue
        skew = float(prof["skewness"])
        if abs(skew) > 2.0:
            skewed.append((col, skew))
    if skewed:
        parts.append(
            "<p><strong>Skewed columns (|skewness| > 2.0):</strong></p><ul>"
            + "".join(f"<li>{_esc(col)}: skewness = {skew:.4f}</li>" for col, skew in skewed)
            + "</ul>"
        )

    if flags:
        parts.append(bullet_list("Quality flags", [_esc(flag) for flag in flags]))
    if recs:
        parts.append(bullet_list("Preprocessing recommendations", [_esc(rec) for rec in recs]))

    st.markdown(f'<div class="step-card-body">{"".join(parts)}</div>', unsafe_allow_html=True)

    if n_rows < 500:
        st.warning("Small dataset detected — model performance may be unreliable", icon="⚠️")
    if is_classification:
        ratio = target_info.get("imbalance_ratio", 0)
        if ratio and ratio > 5:
            st.error("Severe class imbalance detected — consider SMOTE or class_weight=balanced", icon="🚨")
    if any(info.get("pct", 0) > 30 for info in missing_by_col.values()):
        st.warning("High missing rate detected in one or more columns", icon="⚠️")
|
|
|
|
def _render_step_2_task(task: dict):
    """Render the task-detection step body.

    Shows target column, task type, confidence, reasoning and, when present,
    the alternative candidate columns. Missing fields are shown as an em
    dash. All values are escaped with ``_esc`` because the markup is rendered
    with ``unsafe_allow_html=True`` and the values come from the dataset or
    from generated text.

    Args:
        task: Task-detection result dict; every field is optional.
    """
    target_col = _esc(task.get("target_col", "—"))
    parts = [
        f"<p><strong>Target column:</strong> <code>{target_col}</code></p>",
        f"<p><strong>Task type:</strong> {_esc(task.get('task_type', '—'))}</p>",
        f"<p><strong>Confidence:</strong> {_esc(task.get('confidence', '—'))}</p>",
        f"<p><strong>Reasoning:</strong> {_esc(task.get('reasoning', '—'))}</p>",
    ]
    alternatives = task.get("alternatives")
    if alternatives:
        joined = ", ".join(_esc(alt) for alt in alternatives)
        parts.append(f"<p><strong>Alternative candidate columns:</strong> {joined}</p>")
    st.markdown(f'<div class="step-card-body">{"".join(parts)}</div>', unsafe_allow_html=True)
|
|
|
|
def _render_step_3_prep(prep: dict):
    """Render the preprocessing step body and any target-leakage alert.

    Shows the numeric/categorical columns used, up to 20 encoding-strategy
    entries, dropped columns, the preprocessing log, the final feature count
    and the train/test sizes. When ``target_leakage_suspicion`` is set, a
    Streamlit error names the column parsed from that message, or "unknown"
    if none can be parsed.

    Dataset-derived text is escaped with ``_esc`` because the markup is
    rendered with ``unsafe_allow_html=True``. Escaping also stringifies, so
    non-string column labels no longer break ``join``.

    Args:
        prep: Preprocessing result dict; every field is optional.
    """

    def joined_or_none(values) -> str:
        return ", ".join(_esc(v) for v in values) or "none"

    def bullet_items(items) -> str:
        # `items` must already be escaped by the caller.
        return "<ul>" + "".join(f"<li>{item}</li>" for item in items) + "</ul>"

    parts = [
        f"<p><strong>Numeric columns used:</strong> {joined_or_none(prep.get('num_cols', []))}</p>",
        f"<p><strong>Categorical columns used:</strong> {joined_or_none(prep.get('cat_cols', []))}</p>",
    ]

    enc = prep.get("encoding_summary", {})
    parts.append("<p><strong>Encoding strategy:</strong></p>")
    # Cap the list at 20 entries to keep the card readable on wide datasets.
    parts.append(bullet_items(f"{_esc(col)}: {_esc(strat)}" for col, strat in list(enc.items())[:20]))

    dropped = prep.get("dropped_cols")
    if dropped:
        parts.append("<p><strong>Columns dropped (see log for reasons):</strong></p>")
        parts.append(bullet_items(_esc(col) for col in dropped))

    parts.append("<p><strong>Preprocessing log</strong></p>")
    parts.extend(
        f"<p style='font-family:JetBrains Mono,monospace;font-size:12px;'>{_esc(line)}</p>"
        for line in prep.get("preprocessing_log", [])
    )

    n_feat = prep.get("final_feature_count")
    parts.append(f"<p><strong>Final feature count (after encoding):</strong> {_esc(n_feat)}</p>")

    if prep.get("train_size") is not None:
        # test_size may be absent even when train_size is present.
        parts.append(
            f"<p><strong>Train size:</strong> {_esc(prep['train_size'])} rows · "
            f"<strong>Test size:</strong> {_esc(prep.get('test_size', '—'))} rows</p>"
        )

    st.markdown(f'<div class="step-card-body">{"".join(parts)}</div>', unsafe_allow_html=True)

    leak = prep.get("target_leakage_suspicion")
    if leak:
        leak_text = str(leak)  # tolerate non-string truthy values
        m = re.search(r"Column '([^']+)'", leak_text) or re.search(
            r"column ([A-Za-z0-9_]+)", leak_text, re.I
        )
        col = m.group(1) if m else "unknown"
        st.error(
            f"Potential target leakage detected in column {col} — this may inflate your metrics",
            icon="🚨",
        )
|
|
|
|
| def _esc(s: object) -> str: |
| return html.escape(str(s), quote=True) if s is not None else "" |
|
|
|
|
| def _plan_dataset_size_label(n_rows: int) -> str: |
| if n_rows < 1000: |
| return "Small (< 1000 rows)" |
| if n_rows <= 10000: |
| return "Medium" |
| return "Large (> 10000 rows)" |
|
|
|
|
| def _plan_tuning_budget_why(n_rows: int) -> str: |
| if n_rows < 1000: |
| return ( |
| f"Only {n_rows} rows — fewer Optuna trials and a short timeout to limit " |
| "overfitting risk and keep the UI responsive." |
| ) |
| if n_rows <= 10000: |
| return ( |
| f"At {n_rows} rows, a mid-sized budget balances search quality with runtime." |
| ) |
| return ( |
| f"Large dataset ({n_rows:,} rows) — a higher trial count and longer timeout " |
| "let Optuna explore the hyperparameter space properly." |
| ) |
|
|
|
|
| def _plan_why_included(model_name: str, plan: dict) -> str: |
| """One-line, dataset-specific rationale for including a model.""" |
| dp = plan.get("dataset_profile") or {} |
| n = int(dp.get("n_rows", 0)) |
| nf = int(dp.get("n_features", 0)) |
| ir = float(dp.get("imbalance_ratio", 1.0) or 1.0) |
| is_small = bool(dp.get("is_small")) |
| is_large = bool(dp.get("is_large")) |
| is_wide = bool(dp.get("is_wide")) |
| is_bin = bool(dp.get("is_binary")) |
| smote = bool(dp.get("smote_applied")) |
| adj = (plan.get("adjusted_params") or {}).get(model_name, {}) |
| pm = plan.get("primary_metric") or "" |
|
|
| if model_name == "Logistic Regression": |
| if is_bin and ir > 2 and not smote: |
| return ( |
| f"Included: your {n:,} rows show a {ir:.2f}:1 class ratio — " |
| f"a linear model with balanced weights gives a clear, fast baseline before trees." |
| ) |
| return ( |
| f"Included: with {n:,} rows and {nf} encoded features, " |
| f"logistic regression is a strong, interpretable baseline for comparison." |
| ) |
| if model_name == "Linear Regression": |
| return ( |
| f"Included: {n:,} rows × {nf} features — a closed-form linear fit " |
| f"anchors the leaderboard before non-linear models." |
| ) |
| if model_name == "Random Forest": |
| parts = [f"Included: {n:,} rows and {nf} features suit tree ensembles that handle mixed data."] |
| if adj.get("max_depth") is not None: |
| parts.append( |
| f" We capped max_depth={adj['max_depth']} because the sample is small " |
| f"({n} rows) to curb overfitting." |
| ) |
| elif is_large and adj.get("n_estimators"): |
| parts.append( |
| f" At this size we use {adj.get('n_estimators', 300)} trees for stable estimates." |
| ) |
| return "".join(parts) |
| if model_name == "XGBoost": |
| if is_small and nf < 5: |
| return "" |
| if is_large and adj.get("tree_method") == "hist": |
| return ( |
| f"Included: {n:,} rows justify gradient boosting; " |
| f"tree_method=hist keeps each trial fast on this volume." |
| ) |
| return ( |
| f"Included: {n:,} rows × {nf} features — XGBoost captures non-linear " |
| f"interactions typical in tabular benchmarks." |
| ) |
| if model_name == "LightGBM": |
| if is_wide: |
| return ( |
| f"Included: {nf} features is relatively wide — LightGBM scales well " |
| f"to many columns on your {n:,} rows." |
| ) |
| if is_large: |
| return ( |
| f"Included: at {n:,} rows, LightGBM is preferred for speed while " |
| f"still fitting strong tree ensembles." |
| ) |
| return ( |
| f"Included: complements other models on {n:,} rows × {nf} features " |
| f"with efficient leaf-wise growth." |
| ) |
| return ( |
| f"Included for this run ({n:,} rows, {nf} features, primary metric {pm})." |
| ) |
|
|
|
|
| def _plan_skip_dataset_hook(plan: dict) -> str: |
| dp = plan.get("dataset_profile") or {} |
| n = int(dp.get("n_rows", 0)) |
| nf = int(dp.get("n_features", 0)) |
| return f"Context: your processed data has {n:,} rows and {nf} features after preprocessing." |
|
|
|
|
def _render_step_4_plan_training(plan: dict) -> None:
    """Step 4 — Training Plan card (after preprocess, before model training).

    Builds six HTML sections (dataset profile, model selection, evaluation
    metric, tuning budget, warnings/notes, plan summary) from ``plan`` and
    renders them in a single ``st.markdown`` call. Dynamic values are escaped
    with ``_esc``; colours come from the active palette returned by ``_pal()``.
    """
    pal = _pal()
    profile = plan.get("dataset_profile") or {}
    n_rows = int(profile.get("n_rows", 0))
    n_features = int(profile.get("n_features", 0))
    # The task is treated as regression only when the primary metric is "r2".
    is_reg = plan.get("primary_metric") == "r2"
    ir = float(profile.get("imbalance_ratio", 1.0) or 1.0)
    smote = bool(profile.get("smote_applied"))

    # Font stacks reused by every inline style below.
    mono = "'JetBrains Mono',monospace"
    sans = "'DM Sans',sans-serif"

    def _heading(title: str, top_margin: str) -> str:
        """Small accent-coloured section caption."""
        return (
            f'<p style="font-family:{mono};font-size:12px;color:{pal["accent_soft"]};'
            f'margin:{top_margin} 0 10px 0;">{title}</p>'
        )

    # --- Section 1: dataset profile table ---------------------------------
    task_label = "regression" if is_reg else "classification"
    chunks = [
        _heading("SECTION 1 — Dataset profile", "0"),
        "<table><thead><tr><th>Field</th><th>Value</th></tr></thead><tbody>",
        f"<tr><td>Rows</td><td>{_esc(n_rows)}</td></tr>",
        f"<tr><td>Features</td><td>{_esc(n_features)}</td></tr>",
        f"<tr><td>Dataset size</td><td>{_esc(_plan_dataset_size_label(n_rows))}</td></tr>",
        f"<tr><td>Task type</td><td>{task_label}</td></tr>",
    ]
    if not is_reg:
        # Imbalance and SMOTE rows only make sense for classification.
        ratio_text = f"{ir:.2f}"
        smote_label = "Yes" if smote else "No"
        chunks.append(f"<tr><td>Class imbalance ratio</td><td>{_esc(ratio_text)}</td></tr>")
        chunks.append(f"<tr><td>SMOTE applied</td><td>{smote_label}</td></tr>")
    chunks.append("</tbody></table>")

    # --- Section 2: included and skipped models ---------------------------
    chunks.append(_heading("SECTION 2 — Model selection", "20px"))
    adjusted = plan.get("adjusted_params") or {}
    for model in plan.get("recommended_models") or []:
        params = adjusted.get(model) or {}
        if params:
            param_str = ", ".join(f"{k}={v!r}" for k, v in sorted(params.items()))
        else:
            param_str = "defaults"
        why = _plan_why_included(model, plan)
        chunks.append(
            f'<p style="font-family:{sans};margin:10px 0 4px 0;">'
            f'<span style="color:{pal["green"]};font-size:16px;">✓</span> '
            f'<strong style="font-family:{mono};">{_esc(model)}</strong></p>'
            f'<p style="font-family:{mono};font-size:11px;color:{pal["muted"]};margin:0 0 4px 0;">'
            f"Parameters: {_esc(param_str)}</p>"
            f'<p style="font-family:{sans};font-size:13px;color:{pal["text"]};margin:0 0 12px 0;">'
            f"{_esc(why)}</p>"
        )

    skip_reasons = plan.get("skip_reasons") or {}
    hook = _plan_skip_dataset_hook(plan)
    for model in plan.get("skip_models") or []:
        why_not = skip_reasons.get(model, "Excluded by training plan rules.")
        chunks.append(
            f'<p style="font-family:{sans};margin:10px 0 4px 0;">'
            f'<span style="color:{pal["red"]};font-size:16px;">✗</span> '
            f'<strong style="font-family:{mono};">{_esc(model)}</strong></p>'
            f'<p style="font-family:{sans};font-size:13px;color:{pal["text"]};margin:0 0 4px 0;">'
            f"{_esc(why_not)}</p>"
            f'<p style="font-family:{sans};font-size:12px;color:{pal["muted"]};margin:0 0 12px 0;">'
            f"{_esc(hook)}</p>"
        )

    # --- Section 3: evaluation metric --------------------------------------
    metric_name = plan.get("primary_metric") or "—"
    metric_why = plan.get("metric_reasoning") or ""
    chunks.append(_heading("SECTION 3 — Evaluation metric", "20px"))
    chunks.append(
        f'<p style="font-family:{mono};color:{pal["text"]};">'
        f"Primary metric: <strong>{_esc(metric_name)}</strong></p>"
        f'<p style="font-family:{sans};font-size:13px;">{_esc(metric_why)}</p>'
    )

    # --- Section 4: tuning budget -------------------------------------------
    trials = plan.get("n_trials", "—")
    timeout = plan.get("timeout", "—")
    chunks.append(_heading("SECTION 4 — Tuning budget", "20px"))
    chunks.append(
        f'<p style="font-family:{mono};color:{pal["text"]};">'
        f"Optuna trials: <strong>{_esc(trials)}</strong> · "
        f"Timeout: <strong>{_esc(timeout)}</strong> seconds</p>"
        f'<p style="font-family:{sans};font-size:13px;">'
        f"{_esc(_plan_tuning_budget_why(n_rows))}</p>"
    )

    # --- Section 5: warnings and notes --------------------------------------
    warns = plan.get("warnings") or []
    notes = plan.get("notes") or []
    chunks.append(_heading("SECTION 5 — Warnings and notes", "20px"))
    chunks.extend(
        '<div class="warning-banner warning-banner-amber" style="margin-bottom:8px;">'
        f"⚠ {_esc(w)}</div>"
        for w in warns
    )
    if notes:
        chunks.append(f'<ul style="font-family:{sans};color:{pal["text"]};">')
        chunks.extend(f"<li>{_esc(note)}</li>" for note in notes)
        chunks.append("</ul>")
    elif not warns:
        # Placeholder shown only when there is neither a warning nor a note.
        chunks.append(
            f'<p style="font-family:{sans};color:{pal["muted"]};">'
            "No additional warnings or notes for this plan.</p>"
        )

    # --- Section 6: plan summary --------------------------------------------
    summary = plan.get("plan_summary") or ""
    chunks.append(_heading("SECTION 6 — Plan summary", "20px"))
    chunks.append(
        f'<div style="background:{pal["summary_bg"]};border:1px solid {pal["summary_border"]};'
        'border-radius:8px;padding:16px 18px;">'
        f'<p style="font-family:{mono};font-size:11px;color:{pal["muted"]};margin:0 0 8px 0;">'
        "Agent reasoning</p>"
        f'<p style="font-family:{sans};font-size:14px;color:{pal["text"]};margin:0;line-height:1.55;">'
        f"{_esc(summary)}</p>"
        "</div>"
    )

    body = "".join(chunks)
    st.markdown(f'<div class="step-card-body">{body}</div>', unsafe_allow_html=True)
|
|
|
|
| def _gap_threshold_for_task(task_type: str) -> float: |
| return 0.15 if task_type == "classification" else 0.20 |
|
|
|
|
def _cv_reliability_label(cv_std: float | None) -> tuple[str, str]:
    """Classify cross-validation spread into a badge label and palette colour.

    Returns:
        ``(label, colour)`` where the colour is looked up in ``_pal()``:
        ``None`` -> "—"/muted, below 0.02 -> "Reliable"/green,
        up to 0.05 -> "Moderate"/amber, otherwise "Unstable"/red.
    """
    if cv_std is None:
        label, colour_key = "—", "muted"
    elif cv_std < 0.02:
        label, colour_key = "Reliable", "green"
    elif cv_std <= 0.05:
        label, colour_key = "Moderate", "amber"
    else:
        label, colour_key = "Unstable", "red"
    return label, _pal()[colour_key]
|
|
|
|
def _render_step_4_train(train: dict):
    """Render the model-training card: leaderboard, best model, CV charts, overfit alerts.

    Args:
        train: Training result. Reads ``results`` (list of per-model dicts),
            ``comparison_df`` (DataFrame or ``None``), ``best_name``,
            ``metric_name``, ``task_type`` and ``overfitting_warnings``.

    Changes relative to the previous version: the best-model name is
    HTML-escaped before being injected with ``unsafe_allow_html=True``;
    ``results`` / ``overfitting_warnings`` stored as ``None`` are treated as
    empty; overfit alerts tolerate a missing ``name`` and ``None`` scores
    instead of raising.
    """
    results = train.get("results") or []
    comp_df = train.get("comparison_df")
    best_name = train.get("best_name", "")
    primary = train.get("metric_name", "roc_auc")
    # Fall back to inferring the task from the primary metric name.
    task_type = train.get("task_type") or (
        "classification"
        if primary in ("roc_auc", "f1", "f1_weighted", "accuracy")
        else "regression"
    )
    gap_thr = _gap_threshold_for_task(task_type)
    overfit_warnings = train.get("overfitting_warnings") or []
    severe_note = any(
        "All models showed severe overfitting" in str(w) for w in overfit_warnings
    )

    if comp_df is not None and not comp_df.empty:
        try:
            if "Gap" in comp_df.columns:
                gp = _pal()

                def _gap_style(s: pd.Series):
                    """Highlight cells whose gap exceeds the task threshold."""
                    hi_bg = (
                        "rgba(251,113,133,0.25)"
                        if st.session_state.get("theme", "dark") == "dark"
                        else "rgba(220,38,38,0.12)"
                    )
                    hi_fg = gp["red"]
                    return [
                        f"background-color: {hi_bg}; color: {hi_fg}; font-weight: 600;"
                        if isinstance(v, (int, float)) and not pd.isna(v) and float(v) > gap_thr
                        else ""
                        for v in s
                    ]

                st.dataframe(
                    comp_df.style.apply(_gap_style, subset=["Gap"]),
                    use_container_width=True,
                    hide_index=True,
                )
            else:
                st.dataframe(comp_df, use_container_width=True, hide_index=True)
        except Exception:
            # Deliberate best-effort: if styling or rendering fails, fall back
            # to the plain table rather than losing the leaderboard.
            st.dataframe(comp_df, use_container_width=True, hide_index=True)
    else:
        st.info("No comparison table available.")

    st.caption(
        "Cross-validation splits the training data into several folds, trains on all but one each time, "
        "and scores on the held-out fold. **CV Mean** is the average of those scores — often more reliable "
        "than a single train/test split because every row is used for validation once."
    )

    if best_name:
        # Results are keyed by the untuned model name.
        base = best_name.replace(" (tuned)", "").strip()
        best_r = next((r for r in results if r.get("name") == base), None)
        cv_std = best_r.get("cv_std") if best_r else None
        lbl, col = _cv_reliability_label(cv_std)
        reason = (
            "Least overfit model (all others exceeded severe gap threshold)."
            if severe_note
            else (
                "Highest CV mean on the primary metric (when CV ran); otherwise ranked by held-out test score."
            )
        )
        st.markdown(
            f'<p style="font-family:\'JetBrains Mono\',monospace;color:{_pal()["text"]};">'
            f"<strong>Selected best model:</strong> {_esc(best_name)} "
            f'<span style="margin-left:10px;padding:2px 8px;border-radius:4px;font-size:12px;background:rgba(128,128,128,0.15);color:{col};border:1px solid {col};">{lbl}</span>'
            f" — {reason}</p>",
            unsafe_allow_html=True,
        )

    if results and any(r.get("cv_scores") for r in results):
        st.markdown("**Cross-validation details**")
        for r in results:
            scores = r.get("cv_scores")
            if not scores:
                continue
            st.markdown(f"*{r.get('name', 'Model')}* — scores per fold")
            chart_df = pd.DataFrame(
                {"Score": [float(s) for s in scores]},
                index=[f"Fold {i + 1}" for i in range(len(scores))],
            )
            st.bar_chart(chart_df)

    label = "ROC-AUC" if task_type == "classification" else "R²"
    for r in results:
        m = r.get("metrics") or {}
        if not m.get("overfit"):
            continue
        # None scores are shown as 0.0000 instead of breaking the format spec.
        ts = float(m.get("train_score") or 0)
        tss = float(m.get("test_score") or 0)
        gap = float(m.get("generalization_gap") or 0)
        st.error(
            f"Overfitting detected in {r.get('name', 'Model')} — train {label} {ts:.4f} vs test {label} {tss:.4f}, gap {gap:.4f}",
            icon="🚨",
        )
|
|
|
|
def _render_step_4b_tune(tune: dict):
    """Render the hyperparameter-tuning card.

    Args:
        tune: Tuning result. Reads ``model_name``, ``success``, ``error``,
            ``best_params``, ``baseline_score``, ``best_score``,
            ``improvement``, ``n_trials_run``, ``tuning_time_s``,
            ``generalization_gap`` and ``overfit``.

    On failure only the model name and the error are shown. The model name and
    hyperparameter keys/values are HTML-escaped because the card is rendered
    with ``unsafe_allow_html=True``; numeric fields stored as ``None`` are
    formatted as zero instead of raising.
    """
    model = _esc(tune.get("model_name") or "—")
    if not tune.get("success", False):
        err = tune.get("error", "Unknown error")
        st.markdown(
            f'<p style="font-family:\'JetBrains Mono\',monospace;color:{_pal()["text"]};">'
            f"<strong>Model:</strong> {model}</p>",
            unsafe_allow_html=True,
        )
        st.error(f"Hyperparameter tuning failed: {err}")
        return

    def _num(key: str) -> float:
        # Absent or None values fall back to 0 so the format specs never fail.
        value = tune.get(key)
        return float(value) if value is not None else 0.0

    bp = tune.get("best_params") or {}
    parts = [
        f"<p><strong>Model tuned:</strong> {model}</p>"
        f"<p><strong>Baseline score (test):</strong> {_num('baseline_score'):.4f} · "
        f"<strong>After tuning:</strong> {_num('best_score'):.4f} · "
        f"<strong>Improvement:</strong> {_num('improvement'):+.4f}</p>"
        f"<p><strong>Optuna trials:</strong> {_esc(tune.get('n_trials_run') or 0)} · "
        f"<strong>Time:</strong> {_num('tuning_time_s'):.1f}s · "
        f"<strong>Train–test gap:</strong> {_num('generalization_gap'):.4f}</p>"
    ]
    if bp:
        parts.append(
            "<p><strong>Best hyperparameters</strong></p>"
            "<table><thead><tr><th>Parameter</th><th>Value</th></tr></thead><tbody>"
        )
        # At most 24 parameters are listed, in sorted key order.
        parts.extend(
            f"<tr><td><code>{_esc(k)}</code></td><td>{_esc(v)}</td></tr>"
            for k, v in sorted(bp.items())[:24]
        )
        parts.append("</tbody></table>")
    body = "".join(parts)
    st.markdown(f'<div class="step-card-body">{body}</div>', unsafe_allow_html=True)
    if tune.get("overfit"):
        st.warning(
            "Tuned model still shows an elevated train–test gap — consider more data or stronger regularization.",
            icon="⚠️",
        )
|
|
|
|
| def _is_shap_plot_key(k: str) -> bool: |
| return k in ("shap_bar", "shap_summary", "shap_waterfall") or k.startswith("shap_dependence_") |
|
|
|
|
| def _render_shap_dependence_deep_dive(plot_paths: dict) -> None: |
| """Full-width SHAP dependence plots (dynamic keys shap_dependence_*).""" |
| dep_keys = sorted(k for k in plot_paths if str(k).startswith("shap_dependence_")) |
| if not dep_keys: |
| return |
| st.markdown("---") |
| st.markdown("### SHAP Feature Deep Dive") |
| st.caption( |
| "These plots show how each top feature individually drives the model's predictions across all test samples." |
| ) |
| for key in dep_keys: |
| path = plot_paths.get(key) |
| if not path or not Path(path).exists(): |
| continue |
| feat_label = str(key).replace("shap_dependence_", "", 1).replace("_", " ") |
| st.image(path, use_container_width=True) |
| st.caption( |
| f"Feature: {feat_label} — each point is one test sample. " |
| "Color indicates the interacting feature value." |
| ) |
|
|
|
|
| def _render_shap_bar_and_summary(plot_paths: dict) -> None: |
| """Mean |SHAP| bar and beeswarm summary.""" |
| if not any(plot_paths.get(k) for k in ("shap_bar", "shap_summary")): |
| return |
|
|
| st.markdown("---") |
| st.markdown( |
| '<p style="margin:24px 0 16px 0;">SHAP explainability</p>', |
| unsafe_allow_html=True, |
| ) |
|
|
| pbar = plot_paths.get("shap_bar") |
| if pbar and Path(pbar).exists(): |
| st.markdown("### Feature Importance") |
| st.caption( |
| "Mean absolute SHAP values show which features push predictions the most, on average, " |
| "across the explained samples." |
| ) |
| st.image(pbar, use_container_width=True) |
| st.markdown("<div style='height:24px'></div>", unsafe_allow_html=True) |
|
|
| psum = plot_paths.get("shap_summary") |
| if psum and Path(psum).exists(): |
| st.markdown("### Feature Impact Distribution") |
| st.caption( |
| "Each dot represents a data point. Red = high feature value, Blue = low. " |
| "Position shows impact on prediction." |
| ) |
| st.image(psum, use_container_width=True) |
| st.markdown("<div style='height:20px'></div>", unsafe_allow_html=True) |
|
|
|
|
| def _render_shap_waterfall_only(plot_paths: dict) -> None: |
| pw = plot_paths.get("shap_waterfall") |
| if not pw or not Path(pw).exists(): |
| return |
| st.markdown("#### Local explanation (one example row)") |
| st.image(pw, use_container_width=True) |
| st.markdown("<div style='height:12px'></div>", unsafe_allow_html=True) |
|
|
|
|
| def _render_shap_ui(plot_paths: dict) -> None: |
| """SHAP: bar & summary → dependence deep dive → waterfall.""" |
| if not any(_is_shap_plot_key(k) and plot_paths.get(k) for k in plot_paths): |
| return |
| _render_shap_bar_and_summary(plot_paths) |
| _render_shap_dependence_deep_dive(plot_paths) |
| _render_shap_waterfall_only(plot_paths) |
|
|
|
|
def _render_step_5_eval(eval_data: dict):
    """Render the evaluation card: metric tiles, a low-score warning, plots, SHAP.

    Args:
        eval_data: Evaluation result. Reads ``metrics``, ``plot_paths`` and
            ``task_type`` (defaults to ``"classification"``).
    """
    metrics = eval_data.get("metrics", {})
    plot_paths = eval_data.get("plot_paths", {})
    task = eval_data.get("task_type", "classification")

    st.markdown(
        f'<p style="font-family:\'JetBrains Mono\',monospace;color:{_pal()["muted"]};">Evaluation metrics</p>',
        unsafe_allow_html=True,
    )

    # Per task: (metric key, tile label, format) triples, plus the headline
    # metric and the floor below which the low-score warning is shown.
    if task == "classification":
        tiles = (
            ("accuracy", "Accuracy", "{:.4f}"),
            ("f1", "F1 (weighted)", "{:.4f}"),
            ("roc_auc", "ROC-AUC", "{:.4f}"),
        )
        headline_key, floor = "roc_auc", 0.6
    else:
        tiles = (
            ("r2", "R²", "{:.4f}"),
            ("rmse", "RMSE", "{:.4f}"),
            ("mae", "MAE", "{:.4f}"),
            ("mape", "MAPE", "{:.2f}%"),
        )
        headline_key, floor = "r2", 0.3

    # One column per tile; a tile stays empty when its metric is missing.
    for slot, (key, label, fmt) in zip(st.columns(len(tiles)), tiles):
        value = metrics.get(key)
        if value is not None:
            slot.metric(label, fmt.format(value))

    headline = metrics.get(headline_key)
    if headline is not None and headline < floor:
        st.warning(
            "Model performance is near random — review your features and target column",
            icon="⚠️",
        )

    standard_plots = (
        "confusion_matrix",
        "roc_curve",
        "actual_vs_predicted",
        "residuals",
        "feature_importance",
    )
    for plot_name in standard_plots:
        image_path = plot_paths.get(plot_name)
        if image_path and Path(image_path).exists():
            caption = plot_name.replace("_", " ").title()
            st.image(image_path, caption=caption, use_container_width=True)

    _render_shap_ui(plot_paths)
|
|
|
|
| def _feature_interpretation_sentence( |
| feat: str, |
| rank: int, |
| task_type: str, |
| has_shap: bool, |
| ) -> str: |
| strong = rank == 0 |
| if task_type == "classification": |
| return ( |
| f"{feat} was {'the strongest predictor' if strong else 'among the strongest predictors'} — " |
| f"higher values tend to push the prediction toward the positive class " |
| f"({'SHAP and importance agree' if has_shap else 'per model feature importance'})." |
| ) |
| return ( |
| f"{feat} was {'the strongest driver' if strong else 'among the strongest drivers'} of predicted values — " |
| f"larger values are associated with higher predicted outcomes " |
| f"({'consistent with SHAP analysis' if has_shap else 'per feature importance'})." |
| ) |
|
|
|
|
| def _render_what_model_learned(result: dict) -> None: |
| """Top-5 importances + plain-English lines (uses feature_importances + eval SHAP flag).""" |
| fi = result.get("feature_importances") or {} |
| if not fi: |
| return |
| ev = result.get("eval") or {} |
| has_shap = bool(ev.get("has_shap")) |
| task = str(result.get("task_type", "classification") or "classification") |
| top5 = sorted(fi.items(), key=lambda x: -x[1])[:5] |
| if not top5: |
| return |
|
|
| st.markdown("**What the model learned**") |
| names = [t[0] for t in top5] |
| vals = [float(t[1]) for t in top5] |
| chart_df = pd.DataFrame({"importance": vals}, index=names) |
| st.bar_chart(chart_df) |
| for i, (feat, _) in enumerate(top5): |
| st.caption(_feature_interpretation_sentence(feat, i, task, has_shap)) |
|
|
|
|
def _render_step_6_final(result: dict):
    """Render the final summary card: best model, warnings, next steps, exports.

    Args:
        result: Full pipeline result. Reads ``best_model_name``,
            ``best_metrics``, ``task_type``, ``train``, ``eda``, ``prep``,
            ``overfitting_search_results`` and ``feature_importances``
            (the latter through ``_render_what_model_learned``).

    Side effects: may write the HTML/Markdown reports under ``OUTPUT_DIR``,
    may save the trained model through ``predict.save_model``, and updates
    ``st.session_state`` (``report_export``, ``saved_model_path``).

    Changes relative to the previous version: the best-model name is
    HTML-escaped before being injected with ``unsafe_allow_html=True``; a
    primary metric that is missing or ``None`` is shown as "—" instead of
    raising on the ``:.4f`` format; ``None`` snippets and ``None`` sub-dicts no
    longer raise.
    """
    best = result.get("best_model_name") or "—"
    metrics = result.get("best_metrics") or {}
    task = result.get("task_type", "")
    primary = "roc_auc" if task == "classification" else "r2"
    train_data = result.get("train") or {}
    overfit_warnings = list(train_data.get("overfitting_warnings") or [])
    search_results = (
        result.get("overfitting_search_results")
        or train_data.get("overfitting_search_results")
        or []
    )

    primary_value = metrics.get(primary)
    primary_text = f"{primary_value:.4f}" if primary_value is not None else "—"
    st.markdown(
        f'<div class="metric-card" style="margin-bottom:16px;">'
        f'<div class="metric-label">Best model</div>'
        f'<div class="metric-value" style="font-size:22px;">{_esc(best)}</div>'
        f'<p style="font-family:\'JetBrains Mono\',monospace;margin-top:12px;color:{_pal()["muted"]};">'
        f"Primary metric ({primary}): {primary_text}</p></div>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "This model was chosen because it achieved the best test score on the primary metric among models "
        "that were not excluded for severe overfitting (generalization gap above 0.25). "
        "If every model was severely overfit, the least overfit model was selected.",
    )

    # --- Warnings gathered from EDA, preprocessing and training ------------
    st.markdown("**Active warnings from this run**")
    warnings_list: list[str] = []
    eda = result.get("eda") or {}
    if ((eda.get("overview") or {}).get("rows") or 0) < 500:
        warnings_list.append("Small dataset (<500 rows) — unreliable metrics.")
    ti = eda.get("target_info") or {}
    if ti.get("inferred_task") == "classification" and (ti.get("imbalance_ratio") or 0) > 5:
        warnings_list.append("Severe class imbalance (>5:1).")
    miss = (eda.get("missing") or {}).get("by_column") or {}
    if any((info.get("pct") or 0) > 30 for info in miss.values()):
        warnings_list.append("High missing rate (>30%) in one or more columns.")
    prep_summary = result.get("prep") or {}
    if prep_summary.get("target_leakage_suspicion"):
        warnings_list.append("Possible target leakage (high correlation with target).")
    if overfit_warnings:
        warnings_list.extend(overfit_warnings)
    if not warnings_list:
        st.caption("No critical warnings.")
    else:
        for w in warnings_list:
            st.warning(w, icon="⚠️")

    if search_results and overfit_warnings:
        st.markdown("**Suggestions for fixing overfitting (web search)**")
        for r in search_results[:8]:
            if isinstance(r, dict) and "error" not in r:
                # ``or ''`` guards against a snippet stored as None.
                snippet = (r.get("snippet") or "")[:400]
                st.markdown(
                    f"- **{r.get('title', '')}** \n {snippet}",
                )

    _render_what_model_learned(result)

    st.markdown("**Recommended next actions**")
    for a in _generate_next_steps(result):
        st.markdown(f"- {a}")

    # --- Report export ------------------------------------------------------
    md = _build_markdown(result)
    # Named ``report_html`` so the stdlib ``html`` module is not shadowed.
    report_html = _build_html(result)
    n_embedded = count_embedded_plots_html(report_html)
    html_path = OUTPUT_DIR / "automl_report.html"
    md_path = OUTPUT_DIR / "automl_report.md"

    st.markdown("**Export**")
    ex1, ex2 = st.columns(2)
    with ex1:
        if st.button("Generate & save report to disk", key="btn_save_pipeline_reports", use_container_width=True):
            try:
                OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
                html_path.write_text(report_html, encoding="utf-8")
                md_path.write_text(md, encoding="utf-8")
                sz_kb = html_path.stat().st_size / 1024.0
                st.session_state.report_export = {
                    "html_path": str(html_path.resolve()),
                    "size_kb": sz_kb,
                    "n_plots": n_embedded,
                }
            except Exception as ex:
                # UI boundary: surface the failure and clear any stale export info.
                st.error(str(ex))
                st.session_state.report_export = None
    with ex2:
        if st.button("Save model", key="btn_save_model_bundle", use_container_width=True):
            from predict import save_model

            agent = st.session_state.get("agent")
            r = st.session_state.get("result")
            if agent is None or r is None or r.get("status") != "complete":
                st.error("Run a complete pipeline first, then save the model.")
            else:
                try:
                    prep = agent._prep_result
                    tr = agent._train_result
                    if (
                        prep is None
                        or tr is None
                        or prep.get("pipeline") is None
                        or tr.get("best_model") is None
                    ):
                        st.error("Model artifacts are not available. Re-run the pipeline.")
                    else:
                        run_id = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                        # Prefer feature names from the result; fall back to the agent's copy.
                        fn = (
                            (r.get("prep") or {}).get("feature_names")
                            or prep.get("feature_names")
                        )
                        X_raw = _original_feature_frame_from_agent(agent)
                        train_size = (r.get("prep") or {}).get("train_size")
                        path = save_model(
                            pipeline=prep["pipeline"],
                            model=tr["best_model"],
                            label_encoder=prep.get("label_encoder"),
                            feature_names=fn,
                            task_type=r["task_type"],
                            target_col=r["target_col"],
                            best_metrics=r["best_metrics"],
                            model_name=r["best_model_name"],
                            run_id=run_id,
                            X_train=X_raw,
                            num_cols=prep.get("num_cols"),
                            cat_cols=prep.get("cat_cols"),
                            n_training_rows=int(train_size) if train_size is not None else None,
                        )
                        st.session_state["saved_model_path"] = path
                        st.success(f"Model saved to outputs/{run_id}_model.pkl")
                except Exception as ex:
                    # UI boundary: show the error instead of crashing the page.
                    st.error(str(ex))
    exp = st.session_state.get("report_export")
    if exp:
        st.success(
            f"Saved HTML report to `{exp['html_path']}` — **{exp['size_kb']:.1f} KB**, "
            f"**{exp['n_plots']}** plot(s) embedded. "
            "The HTML report is fully self-contained — all plots are embedded. Safe to email or share."
        )

    d1, d2 = st.columns(2)
    with d1:
        st.download_button(
            "Download HTML",
            data=report_html,
            file_name="automl_report.html",
            mime="text/html",
            use_container_width=True,
        )
    with d2:
        st.download_button(
            "Download Markdown",
            data=md,
            file_name="automl_report.md",
            mime="text/markdown",
            use_container_width=True,
        )

    # Preview shows the first 50 Markdown lines, with an ellipsis when truncated.
    md_lines = md.splitlines()
    preview_lines = md_lines[:50]
    with st.expander("Preview report summary", expanded=False):
        st.code("\n".join(preview_lines) + ("\n…" if len(md_lines) > 50 else ""), language="markdown")
|
|
|
|
| def _original_feature_frame_from_agent(agent) -> pd.DataFrame | None: |
| """Raw feature columns as used before the ColumnTransformer (for save_model stats).""" |
| prep = getattr(agent, "_prep_result", None) |
| task = getattr(agent, "_task_result", None) |
| if prep is None or task is None: |
| return None |
| target = task["target_col"] |
| num = prep.get("num_cols") or [] |
| cat = prep.get("cat_cols") or [] |
| cols = list(num) + list(cat) |
| if not cols: |
| return None |
| X = agent.df.drop(columns=[target], errors="ignore") |
| use = [c for c in cols if c in X.columns] |
| if not use: |
| return None |
| return X[use].copy() |
|
|
|
|
| def _render_inference_tab() -> None: |
| """Inference UI: load model, predict, explain.""" |
| p = _pal() |
| st.markdown( |
| f'<p style="font-family:\'JetBrains Mono\',monospace;font-size:18px;color:{p["accent"]};">' |
| f"Inference</p>", |
| unsafe_allow_html=True, |
| ) |
| st.markdown( |
| f'<p style="font-family:\'DM Sans\',sans-serif;font-size:14px;color:{p["muted"]};">' |
| f"Load a saved model and run predictions on new data.</p>", |
| unsafe_allow_html=True, |
| ) |
|
|
| bundle = st.session_state.get("inference_bundle") |
|
|
| st.markdown("### Load a trained model") |
| saved = st.session_state.get("saved_model_path") |
| if saved and Path(saved).exists(): |
| mname = "session model" |
| ag = st.session_state.get("agent") |
| r = st.session_state.get("result") or {} |
| if r.get("best_model_name"): |
| mname = str(r["best_model_name"]) |
| if st.button(f"Use current session model: {mname}", key="inf_use_session_pkl"): |
| try: |
| from predict import load_model |
|
|
| st.session_state["inference_bundle"] = load_model(saved) |
| st.session_state["inference_predictions"] = None |
| st.success("Model loaded from this session.") |
| st.rerun() |
| except Exception as ex: |
| st.error(str(ex)) |
|
|
| up = st.file_uploader("Upload a saved model (.pkl)", type=["pkl"], key="inf_model_upload") |
| if up is not None: |
| try: |
| from predict import load_model |
|
|
| tmp = Path(tempfile.gettempdir()) / f"inf_model_{up.name}" |
| tmp.write_bytes(up.getvalue()) |
| st.session_state["inference_bundle"] = load_model(str(tmp)) |
| st.session_state["inference_predictions"] = None |
| st.success("Model loaded from file.") |
| except Exception as ex: |
| st.error(str(ex)) |
|
|
| if bundle is None: |
| st.info( |
| "Train a model in the Pipeline tab first, then save it to use here — " |
| "or upload a `.pkl` produced by this app." |
| ) |
| return |
|
|
| from predict import get_model_summary, predict |
|
|
| summary = get_model_summary(bundle) |
| st.markdown( |
| f'<div class="metric-card" style="margin:16px 0;">' |
| f'<span style="font-family:\'JetBrains Mono\',monospace;font-size:11px;color:{p["green"]};">' |
| f"● Model loaded</span>" |
| f'<p style="font-family:\'JetBrains Mono\',monospace;margin:12px 0 4px 0;color:{p["text"]};">' |
| f'<strong>{summary["model_name"]}</strong></p>' |
| f'<p style="font-family:\'DM Sans\',sans-serif;font-size:13px;color:{p["muted"]};">' |
| f"Task: {summary['task_type']} · Target: {summary['target_col']} · " |
| f"Encoded features: {summary['n_features']} · " |
| f"Input columns: {len(summary['original_features'])}</p>" |
| f'<p style="font-family:\'DM Sans\',sans-serif;font-size:12px;color:{p["muted"]};">' |
| f"Trained: {summary['training_date']} · Rows in training: {summary['n_training_rows']}</p>" |
| f'<p style="font-family:\'JetBrains Mono\',monospace;font-size:11px;color:{p["muted"]};">' |
| f"Metrics: {', '.join(f'{k}={v}' for k, v in list(summary['metrics'].items())[:8])}</p>" |
| f'<p style="font-family:\'JetBrains Mono\',monospace;font-size:11px;color:{p["muted"]};">' |
| f"Expected columns: {', '.join(summary['expected_input_columns'][:24])}" |
| f"{'…' if len(summary['expected_input_columns']) > 24 else ''}</p></div>", |
| unsafe_allow_html=True, |
| ) |
|
|
| expected = summary["expected_input_columns"] |
| target_col = bundle["target_col"] |
| task_type = bundle["task_type"] |
| num_cols = bundle.get("num_cols") or [] |
| cat_cols = bundle.get("cat_cols") or [] |
| means = bundle.get("feature_means") or {} |
| modes = bundle.get("feature_modes") or {} |
| cats = bundle.get("categorical_uniques") or {} |
|
|
| st.markdown("### Upload new data for prediction") |
| pred_df: pd.DataFrame | None = None |
| csv_up = st.file_uploader("Upload CSV for prediction", type=["csv"], key="inf_csv_pred") |
| csv_ready: pd.DataFrame | None = None |
| if csv_up is not None: |
| csv_ready = pd.read_csv(csv_up) |
| st.dataframe(csv_ready.head(5), use_container_width=True) |
| miss = [c for c in expected if c not in csv_ready.columns] |
| if miss: |
| st.warning(f"Missing expected feature columns (will be imputed): {miss}") |
| if target_col in csv_ready.columns: |
| st.warning( |
| f"Target column **{target_col}** is present — it will be ignored for prediction." |
| ) |
| if st.button("Run prediction on CSV", key="inf_run_csv_pred"): |
| pred_df = csv_ready |
|
|
| n_orig = len(num_cols) + len(cat_cols) if (num_cols or cat_cols) else len(expected) |
| if csv_up is None and n_orig <= 10: |
| st.markdown("**Or enter a single row manually**") |
| vals: dict[str, Any] = {} |
| with st.form("inf_manual_form"): |
| for col in num_cols: |
| if col not in expected: |
| continue |
| default = float(means.get(col, 0.0)) |
| vals[col] = st.number_input( |
| col, |
| value=default, |
| format="%.6f", |
| key=f"inf_num_{col}", |
| ) |
| for col in cat_cols: |
| if col not in expected: |
| continue |
| opts = cats.get(col) or [str(modes.get(col, ""))] |
| vals[col] = st.selectbox(col, opts, key=f"inf_cat_{col}") |
| submitted = st.form_submit_button("Predict") |
| if submitted: |
| pred_df = pd.DataFrame([vals]) |
|
|
| if pred_df is None: |
| return |
|
|
| st.markdown("### Prediction results") |
| try: |
| out_df, fill_log = predict(bundle, pred_df) |
| st.session_state["inference_predictions"] = out_df |
| for line in fill_log: |
| st.caption(line) |
| except Exception as ex: |
| st.error(f"Prediction failed: {ex}") |
| return |
|
|
| out_df = st.session_state["inference_predictions"] |
| if out_df is None: |
| return |
|
|
| pred_col = out_df["prediction"] |
| proba = out_df["probability"] if "probability" in out_df.columns else None |
|
|
| if task_type == "classification": |
| le = bundle.get("label_encoder") |
| classes = list(le.classes_) if le is not None else [] |
| pos = classes[1] if len(classes) >= 2 else None |
| st.markdown( |
| f'<div class="metric-card" style="border-color:{p["accent"]};">' |
| f'<div class="metric-label">Prediction</div>' |
| f'<div class="metric-value" style="font-size:32px;">{pred_col.iloc[0]}</div></div>', |
| unsafe_allow_html=True, |
| ) |
| if proba is not None: |
| conf = float(proba.iloc[0]) * 100.0 |
| st.progress(min(conf / 100.0, 1.0)) |
| st.markdown(f"**Confidence:** {conf:.1f}%") |
| if len(out_df) > 1: |
| tbl = pd.DataFrame( |
| { |
| "Row": range(1, len(out_df) + 1), |
| "Predicted Class": out_df["prediction"].astype(str), |
| } |
| ) |
| if proba is not None: |
| tbl["Confidence %"] = (proba * 100.0).round(2) |
| st.dataframe(tbl, use_container_width=True, hide_index=True) |
| vc = out_df["prediction"].astype(str).value_counts() |
| st.bar_chart(vc) |
| if len(out_df) == 1 and pos is not None: |
| c1 = "#4ade80" if str(pred_col.iloc[0]) == str(pos) else "#fb7185" |
| st.markdown( |
| f'<div style="padding:12px;border-radius:8px;border:2px solid {c1};' |
| f'font-family:\'JetBrains Mono\',monospace;">' |
| f"Class highlighted: <strong>{pred_col.iloc[0]}</strong></div>", |
| unsafe_allow_html=True, |
| ) |
| else: |
| st.markdown( |
| f'<div class="metric-card" style="border-color:{p["accent"]};">' |
| f'<div class="metric-label">{target_col} (predicted)</div>' |
| f'<div class="metric-value" style="font-size:32px;">{float(pred_col.iloc[0]):,.4f}</div></div>', |
| unsafe_allow_html=True, |
| ) |
| if len(out_df) > 1: |
| st.dataframe( |
| pd.DataFrame( |
| { |
| "Row": range(1, len(out_df) + 1), |
| "Predicted Value": out_df["prediction"], |
| } |
| ), |
| use_container_width=True, |
| hide_index=True, |
| ) |
| st.bar_chart(out_df["prediction"]) |
| c1, c2, c3 = st.columns(3) |
| c1.metric("Min", f"{out_df['prediction'].min():.4f}") |
| c2.metric("Mean", f"{out_df['prediction'].mean():.4f}") |
| c3.metric("Max", f"{out_df['prediction'].max():.4f}") |
|
|
| csv_bytes = out_df.to_csv(index=False).encode("utf-8") |
| st.download_button( |
| "Download predictions as CSV", |
| data=csv_bytes, |
| file_name="predictions.csv", |
| mime="text/csv", |
| use_container_width=True, |
| ) |
|
|
| model = bundle["model"] |
| feat_names = bundle.get("feature_names") or [] |
| pred0 = out_df["prediction"].iloc[0] |
| st.markdown( |
| f'<p style="font-family:\'DM Sans\',sans-serif;font-size:15px;color:{p["text"]};">' |
| f"Why did the model predict <strong>{html.escape(str(pred0))}</strong> for row 1? " |
| f"The chart below shows which encoded features pushed the score up or down.</p>", |
| unsafe_allow_html=True, |
| ) |
| shap_ok = False |
| try: |
| import matplotlib.pyplot as plt |
| import shap |
| from predict import prepare_transformed_features |
|
|
| Xt, _ = prepare_transformed_features(bundle, pred_df) |
| n = min(len(Xt), 50) |
| Xt = Xt[:n] |
| explainer = shap.Explainer(model) |
| sv = explainer(Xt) |
| shap.plots.waterfall(sv[0], max_display=12, show=False) |
| st.pyplot(plt.gcf(), clear_figure=True) |
| plt.close("all") |
| names = feat_names[: sv[0].values.shape[0]] if feat_names else [] |
| vals = np.asarray(sv[0].values).ravel() |
| if len(names) >= len(vals): |
| order = np.argsort(np.abs(vals))[-5:][::-1] |
| bits = [f"{names[i]} ({vals[i]:+.4f})" for i in order if i < len(vals)] |
| if bits: |
| st.caption("Top 5 drivers: " + " · ".join(bits)) |
| shap_ok = True |
| except Exception: |
| shap_ok = False |
|
|
| if not shap_ok: |
| st.markdown( |
| f'<p style="font-family:\'JetBrains Mono\',monospace;font-size:12px;color:{p["muted"]};">' |
| f"Feature influence</p>", |
| unsafe_allow_html=True, |
| ) |
| bfi = bundle.get("bundle_feature_importances") or {} |
| if bfi: |
| top = sorted(bfi.items(), key=lambda x: -x[1])[:12] |
| for feat, imp in top: |
| st.markdown(f"- **{feat}** — {imp:.4f}") |
| elif hasattr(model, "feature_importances_"): |
| imp = model.feature_importances_ |
| order = np.argsort(imp)[-5:][::-1] |
| for i in order: |
| if i < len(feat_names): |
| st.markdown(f"- **{feat_names[i]}** — {float(imp[i]):.4f}") |
|
|
|
|
| |
def _activate_sample_dataset(
    dataset_key: str,
    candidates: tuple[str, ...],
    missing_msg: str,
) -> None:
    """Load a bundled sample CSV into session state after its button is clicked.

    Args:
        dataset_key: Value stored in ``st.session_state["demo_dataset"]``.
        candidates: CSV paths tried in order; the first one that exists is
            loaded. The paths are relative, so they resolve against the
            current working directory.
        missing_msg: Warning shown when none of the candidates exists.
    """
    # Record the choice first: _load_demo_json_file() reads "demo_dataset"
    # from session state to pick the matching demo snapshot.
    st.session_state["demo_dataset"] = dataset_key

    sample_path = next((Path(c) for c in candidates if Path(c).exists()), None)
    if sample_path is None:
        st.warning(missing_msg)
        return

    st.session_state.df = pd.read_csv(sample_path)
    st.session_state.filename = sample_path.name
    # A new dataset invalidates whatever the previous run left behind.
    st.session_state.result = None
    st.session_state.log_lines = []
    st.session_state.step_cards = []
    st.session_state.pipeline_track = []
    st.session_state.error = None

    # In demo mode, swap in the pre-computed snapshot for this dataset.
    if st.session_state.get("demo_mode_toggle"):
        snapshot = _load_demo_json_file()
        st.session_state["_demo_data_missing"] = not snapshot
        if snapshot:
            _apply_demo_payload(snapshot)


# Sidebar: branding, demo toggle, API key entry, dataset upload, sample
# datasets, a small preview of the loaded data, and the theme toggle.
with st.sidebar:
    st.markdown(
        f'<div style="font-family:\'JetBrains Mono\',monospace;font-size:13px;'
        f'color:{_pal()["accent_soft"]};padding:12px 0 20px 0;letter-spacing:1px;">⚡ AutoML Engineer</div>',
        unsafe_allow_html=True,
    )

    st.markdown('<div class="section-head">Demo</div>', unsafe_allow_html=True)
    st.toggle(
        "Demo Mode",
        key="demo_mode_toggle",
        help="Browse a pre-computed example without an API key.",
        on_change=_on_demo_mode_change,
    )

    # The API key section is only relevant when demo mode is off.
    if not st.session_state.get("demo_mode_toggle", False):
        st.markdown('<div class="section-head">API key</div>', unsafe_allow_html=True)
        _env_key = (os.getenv("ANTHROPIC_API_KEY") or "").strip()
        if _env_key:
            st.caption("Using ANTHROPIC_API_KEY from environment (.env).")
        else:
            st.text_input(
                "ANTHROPIC_API_KEY",
                type="password",
                key="anthropic_api_key_input",
                placeholder="sk-ant-...",
            )
            st.caption("Your key is never stored or logged.")

    st.markdown('<div class="section-head">Dataset</div>', unsafe_allow_html=True)
    uploaded = st.file_uploader(
        "Upload CSV",
        type=["csv"],
        label_visibility="collapsed",
        key="sidebar_csv_uploader",
    )
    _load_csv_from_upload(uploaded)

    # Sample datasets: two buttons per column, same order as before.
    st.markdown('<div class="section-head">Or use a sample</div>', unsafe_allow_html=True)
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Titanic", use_container_width=True):
            _activate_sample_dataset(
                "titanic",
                ("datasets/titanic_demo_synth.csv", "datasets/titanic.csv"),
                "datasets/titanic_demo_synth.csv or datasets/titanic.csv not found.",
            )
        if st.button("Healthcare", use_container_width=True):
            _activate_sample_dataset(
                "healthcare",
                (
                    "datasets/healthcare_demo_synth.csv",
                    "datasets/sample_healthcare_classification.csv",
                ),
                "Run generate_samples.py first.",
            )
    with col2:
        if st.button("Diabetes", use_container_width=True):
            _activate_sample_dataset(
                "diabetes",
                ("datasets/diabetes_sklearn_demo.csv", "datasets/diabetes.csv"),
                "datasets/diabetes_sklearn_demo.csv or datasets/diabetes.csv not found.",
            )
        if st.button("Housing", use_container_width=True):
            _activate_sample_dataset(
                "housing",
                ("datasets/sample_housing_regression.csv",),
                "Run generate_samples.py first.",
            )

    # Compact preview of whatever dataset is currently loaded.
    if st.session_state.df is not None:
        df = st.session_state.df
        st.markdown('<div class="section-head">Preview</div>', unsafe_allow_html=True)
        st.markdown(
            f'<div style="font-family:\'JetBrains Mono\',monospace;font-size:11px;'
            f'color:{_pal()["muted"]};margin-bottom:8px;">'
            f'{len(df):,} rows · {len(df.columns)} cols · '
            f'{df.isnull().sum().sum()} nulls</div>',
            unsafe_allow_html=True,
        )
        st.dataframe(df.head(6), use_container_width=True, height=180)

    st.markdown('<div class="section-head">Settings</div>', unsafe_allow_html=True)
    # The toggle label is an icon that reflects the currently stored theme.
    _icon = "☀️" if st.session_state.get("theme", "dark") == "light" else "🌙"
    st.toggle(
        _icon,
        key="theme_light_toggle",
        help="Switch between light and dark theme",
    )
|
|
|
|
# Demo mode with no result loaded yet: hydrate session state from the demo
# snapshot file, and record whether that snapshot could be found so the
# Pipeline tab can warn about it.
_demo_requested = st.session_state.get("demo_mode_toggle")
if _demo_requested and st.session_state.get("result") is None:
    _snap = _load_demo_json_file()
    st.session_state["_demo_data_missing"] = not _snap
    if _snap:
        _apply_demo_payload(_snap)
|
|
|
|
# Main area: a "Pipeline" tab (goal input, run loop, live log, results) and an
# "Inference" tab rendered by _render_inference_tab().
tab_pipeline, tab_inference = st.tabs(["Pipeline", "Inference"])
with tab_pipeline:
    _demo_on = st.session_state.get("demo_mode_toggle", False)
    if _demo_on:
        _dname = st.session_state.get("demo_dataset", "healthcare")
        _label = str(_dname).replace("_", " ").title()
        st.info(
            f"Demo Mode — showing pre-computed results for the {_label} dataset. "
            "Toggle off to use your own API key."
        )
        if st.session_state.get("_demo_data_missing"):
            st.warning(
                "Demo data not found. Run `python generate_all_demos.py` locally "
                "and push the demo_result_*.json files to GitHub."
            )

    # Goal prompt and Run button share one row.
    goal_col, btn_col = st.columns([5, 1])
    with goal_col:
        user_goal = st.text_input(
            "What do you want to predict?",
            placeholder='e.g. "predict whether a patient will be readmitted"',
            key="user_goal_input",
        )
    with btn_col:
        st.write("")
        run_clicked = st.button(
            "▶ Run",
            type="primary",
            use_container_width=True,
            key="run_button",
            disabled=st.session_state.running or _demo_on,
        )
    if _demo_on:
        st.caption("Add your ANTHROPIC_API_KEY to run on your own data")

    # Offer an in-page uploader while no dataset is loaded.
    if st.session_state.df is None:
        st.markdown('<div class="section-head">Upload dataset</div>', unsafe_allow_html=True)
        st.caption(
            "Drag and drop a CSV here, or use the file browser. "
            "You can also upload from the left sidebar, or load a sample dataset there."
        )
        main_upload = st.file_uploader(
            "Choose a CSV file",
            type=["csv"],
            accept_multiple_files=False,
            key="main_csv_uploader",
            label_visibility="collapsed",
        )
        _load_csv_from_upload(main_upload)

    # Run handler: validate inputs, then stream agent events into the UI.
    if run_clicked:
        if st.session_state.df is None:
            st.error("Upload a CSV first (or load a sample dataset from the sidebar).")
        elif not user_goal.strip():
            st.error("Describe your goal so the agent knows what to predict.")
        else:
            # The environment variable wins over a key pasted in the sidebar.
            _env_k = (os.getenv("ANTHROPIC_API_KEY") or "").strip()
            _paste_k = (st.session_state.get("anthropic_api_key_input") or "").strip()
            _effective_key = _env_k or _paste_k
            if not _effective_key:
                st.error("Add ANTHROPIC_API_KEY to your .env file or paste it in the sidebar.")
            else:
                # NOTE(review): os.environ is process-wide, so a pasted key stays
                # in this server process and is presumably visible to other
                # sessions it serves — confirm this is acceptable for deployment.
                os.environ["ANTHROPIC_API_KEY"] = _effective_key
                st.session_state.running = True
                st.session_state.result = None
                st.session_state.log_lines = []
                st.session_state.step_cards = []
                st.session_state.pipeline_track = _new_pipeline_track()
                st.session_state.error = None
                st.session_state.report_export = None

                df = st.session_state.df.copy()
                log_placeholder = st.empty()
                status_placeholder = st.empty()
                step_cards_placeholder = st.empty()
                # Mutable holder so the agent created inside the generator can
                # be stored in session state after a successful run.
                agent_holder: dict = {"agent": None}

                def run_agent_events():
                    """Create the agent for the current data/goal and yield its events."""
                    from agent.core import AutoMLAgent

                    agent = AutoMLAgent(df, user_goal)
                    agent_holder["agent"] = agent
                    yield from agent.run()

                try:
                    status_placeholder.info("▶ Pipeline running… (this may take a minute)")
                    pipeline_failed = False
                    for event in run_agent_events():
                        etype = event["type"]
                        if etype == "text":
                            _log_text(event["content"])
                        elif etype == "tool":
                            name = event["name"]
                            status = event["status"]
                            _log_tool(name, status, event.get("output", ""))
                            if status == "running":
                                _pipeline_track_update_running(
                                    name, event.get("tune_model_name"),
                                )
                            elif status == "done":
                                _pipeline_track_update_done(name, event.get("step_data"))
                        elif etype == "error":
                            _log_error(event["content"])
                            st.session_state.error = event["content"]
                            _pipeline_track_fail_running(event["content"])
                            pipeline_failed = True
                        elif etype == "done":
                            st.session_state.result = event["result"]
                            _log_text("Pipeline complete.")
                            _pipeline_track_finalize(event["result"])

                        # Refresh the live log and step cards after every event.
                        log_html = "".join(st.session_state.log_lines)
                        log_placeholder.markdown(
                            f'<div class="log-container">{log_html}</div>',
                            unsafe_allow_html=True,
                        )
                        with step_cards_placeholder.container():
                            for s in st.session_state.pipeline_track:
                                _render_pipeline_step(s)

                        if pipeline_failed:
                            break

                    if not pipeline_failed:
                        st.session_state["agent"] = agent_holder.get("agent")

                    status_placeholder.empty()
                except Exception as e:
                    err_msg = str(e)
                    st.session_state.error = err_msg
                    _log_error(err_msg)
                    status_placeholder.error(f"Pipeline failed: {err_msg}")
                    st.error(f"Pipeline failed: {err_msg}")
                finally:
                    # Reset the flag on every exit path. A Streamlit stop/rerun
                    # interruption presumably bypasses ``except Exception``;
                    # without this reset the Run button would stay disabled for
                    # the rest of the session.
                    st.session_state.running = False

                st.rerun()

    # Persisted views: shown on every rerun from session state.
    if st.session_state.pipeline_track:
        st.markdown('<div class="section-head">Pipeline steps</div>', unsafe_allow_html=True)
        for s in st.session_state.pipeline_track:
            _render_pipeline_step(s)

    if st.session_state.log_lines:
        st.markdown('<div class="section-head">Activity log</div>', unsafe_allow_html=True)
        log_html = "".join(st.session_state.log_lines)
        st.markdown(
            f'<div class="log-container">{log_html}</div>',
            unsafe_allow_html=True,
        )

    if st.session_state.error:
        st.error(st.session_state.error)

    with st.expander("🔧 Debug", expanded=False):
        st.write("**Button:**", "Run clicked" if run_clicked else "Not clicked")
        result_debug = st.session_state.get("result")
        st.write("**Result exists:**", result_debug is not None)
        if result_debug is not None:
            keys = list(result_debug.keys()) if isinstance(result_debug, dict) else str(type(result_debug))
            st.write("**Result keys:**", keys)
            st.write("**Result status:**", result_debug.get("status") if isinstance(result_debug, dict) else "N/A")
        st.write("**Running:**", st.session_state.get("running", False))
        st.write("**Log lines count:**", len(st.session_state.get("log_lines", [])))

    # Results: headline metric cards followed by four detail tabs.
    result = st.session_state.get("result")
    if result and result.get("status") == "complete":
        st.markdown('<div class="section-head">Results</div>', unsafe_allow_html=True)

        metrics = result.get("metrics", {})
        # Tolerate an explicit ``None`` task type instead of failing on
        # ``None.capitalize()``.
        task = result.get("task_type") or ""
        best = result.get("best_model_name", "—")

        cards_html = '<div class="metric-grid">'
        cards_html += _metric_card("Best model", best)
        cards_html += _metric_card("Task", task.capitalize())
        cards_html += _metric_card("Target", result.get("target_col", "—"))

        if task == "classification":
            for key in ("accuracy", "f1", "roc_auc"):
                val = metrics.get(key)
                if val is not None:
                    label = {"accuracy": "Accuracy", "f1": "F1 (weighted)",
                             "roc_auc": "ROC-AUC"}[key]
                    cls = _color_for_metric(key, val)
                    cards_html += _metric_card(label, f"{val:.3f}", cls)
        else:
            for key in ("r2", "rmse", "mape"):
                val = metrics.get(key)
                if val is not None:
                    label = {"r2": "R²", "rmse": "RMSE", "mape": "MAPE %"}[key]
                    suffix = "%" if key == "mape" else ""
                    cls = _color_for_metric(key, val)
                    fmt = f"{val:.2f}{suffix}" if key in ("rmse", "mape") else f"{val:.4f}"
                    cards_html += _metric_card(label, fmt, cls)

        cards_html += '</div>'
        st.markdown(cards_html, unsafe_allow_html=True)

        tab_model, tab_plots, tab_features, tab_data = st.tabs(
            ["Model comparison", "Plots", "Features", "Data profile"]
        )

        with tab_model:
            comp_df = result.get("comparison_df")
            if comp_df is not None:
                st.dataframe(comp_df, use_container_width=True, hide_index=True)

            train_info = result.get("train", {})
            log_lines = train_info.get("training_log", [])
            if log_lines:
                with st.expander("Training log", expanded=False):
                    st.code("\n".join(log_lines), language=None)

        with tab_plots:
            plot_paths = result.get("plot_paths", {})
            if not plot_paths:
                st.info("No plots generated.")
            else:
                # Known plots first in a fixed order, then any other non-SHAP
                # plots in dict order.
                base_order = [
                    "confusion_matrix",
                    "roc_curve",
                    "actual_vs_predicted",
                    "residuals",
                    "feature_importance",
                ]
                ordered = [p for p in base_order if p in plot_paths]
                ordered += [
                    p for p in plot_paths
                    if p not in ordered and not _is_shap_plot_key(p)
                ]

                cols = st.columns(2)
                for i, name in enumerate(ordered):
                    path = plot_paths[name]
                    col = cols[i % 2]
                    if Path(path).exists():
                        col.image(
                            path,
                            caption=name.replace("_", " ").title(),
                            use_container_width=True,
                        )
                    else:
                        col.warning(f"Plot not found: {path}")

                _render_shap_ui(plot_paths)

        with tab_features:
            fi = result.get("feature_importances", {})
            if not fi:
                st.info("Feature importances not available for this model.")
            else:
                top = list(fi.items())[:20]
                # Scale bars against the largest score shown rather than
                # assuming the first entry is the maximum.
                max_val = max((score for _, score in top), default=0.0)

                fp = _pal()
                st.markdown(
                    f'<div style="font-family:\'JetBrains Mono\',monospace;'
                    f'font-size:11px;color:{fp["muted"]};margin-bottom:16px;">'
                    'Normalized importance scores</div>',
                    unsafe_allow_html=True,
                )
                for feat, imp in top:
                    # Guard the division (all-zero importances) and clamp the
                    # width so it is always a valid CSS percentage.
                    if max_val > 0:
                        bar_pct = max(0, min(100, int((imp / max_val) * 100)))
                    else:
                        bar_pct = 0
                    color = fp["accent_soft"] if imp == max_val else fp["green"]
                    # Feature names derive from dataset columns, so escape them
                    # before injecting into raw HTML.
                    safe_feat = html.escape(str(feat))
                    st.markdown(
                        f'<div style="margin-bottom:10px;">'
                        f'<div style="display:flex;justify-content:space-between;'
                        f'font-family:\'JetBrains Mono\',monospace;font-size:12px;'
                        f'color:{fp["text"]};margin-bottom:4px;">'
                        f'<span>{safe_feat}</span>'
                        f'<span style="color:{fp["muted"]}">{imp:.4f}</span></div>'
                        f'<div class="fi-bar-bg"><div class="fi-bar-fill" '
                        f'style="width:{bar_pct}%;background:{color};"></div></div>'
                        f'</div>',
                        unsafe_allow_html=True,
                    )

        with tab_data:
            eda = result.get("eda", {})
            if not eda:
                st.info("EDA data not available.")
            else:
                ov = eda.get("overview", {})
                c1, c2, c3, c4 = st.columns(4)
                c1.metric("Rows", f"{ov.get('rows', 0):,}")
                c2.metric("Columns", ov.get("columns", 0))
                c3.metric("Numeric", ov.get("numeric_cols", 0))
                c4.metric("Nulls", eda.get("missing", {}).get("total_missing", 0))

                col_profiles = eda.get("columns", {})
                if col_profiles:
                    rows = []
                    n_rows = ov.get("rows", 1)
                    for col_name, prof in col_profiles.items():
                        miss = prof.get("missing", 0)
                        miss_p = f"{miss / n_rows * 100:.1f}%" if n_rows else "—"
                        # A profile without "dtype_group" falls through to the
                        # categorical row instead of raising KeyError.
                        if prof.get("dtype_group") == "numeric":
                            rows.append({
                                "Column": col_name,
                                "Type": "numeric",
                                "Missing": miss_p,
                                "Mean": f"{prof.get('mean', 0):.2f}" if prof.get("mean") is not None else "—",
                                "Std": f"{prof.get('std', 0):.2f}" if prof.get("std") is not None else "—",
                                "Min": f"{prof.get('min', 0):.2f}" if prof.get("min") is not None else "—",
                                "Max": f"{prof.get('max', 0):.2f}" if prof.get("max") is not None else "—",
                            })
                        else:
                            rows.append({
                                "Column": col_name,
                                "Type": "categorical",
                                "Missing": miss_p,
                                "Mean": "—",
                                "Std": "—",
                                "Min": f"{prof.get('n_unique', 0)} unique",
                                "Max": f"top: {prof.get('top_value', '—')}",
                            })
                    st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)

                flags = eda.get("quality_flags", [])
                if flags:
                    st.markdown("**Quality flags**")
                    for f in flags:
                        st.warning(f, icon="⚠️")

                recs = eda.get("recommendations", [])
                if recs:
                    st.markdown("**Preprocessing applied**")
                    for r in recs:
                        st.markdown(f"- {r}")

    # Empty state: nothing has run and nothing is running.
    elif not st.session_state.get("running", False) and not st.session_state.get("log_lines", []):
        ep = _pal()
        st.markdown(
            f"""
<div style="text-align:center;padding:60px 0 40px 0;">
<div style="font-family:'JetBrains Mono',monospace;font-size:48px;
color:{ep['empty_icon']};margin-bottom:16px;">⚡</div>
<div style="font-family:'JetBrains Mono',monospace;font-size:15px;
color:{ep['empty_sub']};margin-bottom:8px;">Drop a CSV. Describe your goal. Run.</div>
<div style="font-family:'DM Sans',sans-serif;font-size:13px;color:{ep['empty_body']};">
Upload a CSV above (or in the sidebar) → describe your goal → Run
</div>
</div>
""",
            unsafe_allow_html=True,
        )

with tab_inference:
    _render_inference_tab()
|
|