""" SQL Agent — Gradio app for Hugging Face Spaces (ZeroGPU). Apple x Claude minimalist design with progressive feedback during the multi-step pipeline. """ import logging import os import sys from pathlib import Path from typing import Generator, Optional, Tuple import pandas as pd ROOT = Path(__file__).parent sys.path.insert(0, str(ROOT)) logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s") logger = logging.getLogger(__name__) import gradio as gr # noqa: E402 try: import spaces # type: ignore HAS_SPACES = True except ImportError: HAS_SPACES = False class _SpacesShim: @staticmethod def GPU(duration: int = 60): # noqa: N802 def decorator(fn): return fn return decorator spaces = _SpacesShim() # type: ignore # CRITICAL: load all 3 models on cuda at module level per ZeroGPU best # practice. PyTorch CUDA emulation handles this when no real GPU is present; # inside @spaces.GPU calls, the real GPU is used and inference is fast. logger.info("Loading models at module level...") from src.models.sql_generator import SQLGenerator # noqa: E402 from src.models.chart_reasoner import ChartReasoner # noqa: E402 from src.models.svg_renderer import SVGRenderer # noqa: E402 from src.orchestrator.pipeline import SQLAgentOrchestrator # noqa: E402 _SQL_GEN = SQLGenerator() _CHART_REASONER = ChartReasoner() _SVG_RENDERER = SVGRenderer() logger.info("All models loaded") # ============================================================ THEME / CSS THEME_CSS = """ :root { --ink: #0E0E0E; --ink-muted: #5A5A5A; --ink-faint: #E5E5E5; --surface: #FAFAF9; --surface-raised: #FFFFFF; --accent: #C96442; --accent-soft: rgba(201, 100, 66, 0.08); --radius: 16px; --radius-sm: 10px; --shadow-sm: 0 1px 2px rgba(0,0,0,0.04); --shadow-md: 0 6px 24px rgba(0,0,0,0.08); --font: -apple-system, BlinkMacSystemFont, "SF Pro Text", "SF Pro Display", "Helvetica Neue", Arial, sans-serif; --font-mono: "SF Mono", ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; } @media (prefers-color-scheme: dark) { :root { --ink: #F4F4F2; --ink-muted: #8A8A8A; --ink-faint: #2A2A2A; --surface: #0E0E0E; --surface-raised: #161616; --accent: #E8866A; --accent-soft: rgba(232, 134, 106, 0.10); --shadow-sm: 0 1px 2px rgba(0,0,0,0.25); --shadow-md: 0 6px 24px rgba(0,0,0,0.45); } } /* Gradio container reset (Gradio 5 selectors) */ html, body, gradio-app, .gradio-container, .main, .app, .contain, .wrap, gradio-app > .main, gradio-app .gradio-container { background: var(--surface) !important; color: var(--ink) !important; font-family: var(--font) !important; } /* Fluid container — no max-width, fills the entire iframe to remove the dark side bars that appeared when content was capped at 1600px. */ .gradio-container { max-width: none !important; width: 100% !important; margin: 0 !important; padding: 28px 48px 48px !important; min-height: 100vh !important; box-sizing: border-box; } /* Two-column rectangular layout */ .split-layout { display: grid; grid-template-columns: minmax(480px, 580px) 1fr; gap: 36px; align-items: stretch; min-height: 720px; } @media (max-width: 900px) { .split-layout { grid-template-columns: 1fr; gap: 22px; min-height: auto; } .gradio-container { padding: 24px 18px 60px !important; min-height: auto !important; } } .split-left { position: sticky; top: 24px; } .split-right { min-height: 720px; display: flex; flex-direction: column; } .split-right > .panel-label { flex-shrink: 0; } .panel-label { font-size: 11px; font-weight: 600; text-transform: uppercase; letter-spacing: 0.06em; color: var(--ink-muted) !important; margin: 0 0 10px 2px; } footer { display: none !important; } .show-api { display: none !important; } /* Header */ .app-header { margin-bottom: 28px; padding-bottom: 20px; border-bottom: 1px solid var(--ink-faint); display: flex; align-items: baseline; justify-content: space-between; gap: 16px; } .app-title { font-size: 18px; font-weight: 600; letter-spacing: -0.015em; color: var(--ink) !important; } .app-subtitle { font-size: 13px; color: var(--ink-muted) !important; } /* HF Login button — make it Apple-style */ [data-testid="login-button"], .login-button, button[aria-label*="login"], button[aria-label*="Login"] { background: var(--ink) !important; color: var(--surface) !important; border: none !important; border-radius: var(--radius-sm) !important; font-weight: 500 !important; padding: 8px 14px !important; font-size: 13px !important; margin-bottom: 18px !important; } /* File upload — compact, Apple-style */ .upload-row { margin-bottom: 14px; } .upload-row .gr-file, .upload-row .file-preview { background: transparent !important; } .upload-row [data-testid="file"] { border: 1.5px dashed var(--ink-faint) !important; border-radius: var(--radius) !important; padding: 20px 16px !important; background: transparent !important; transition: all 200ms ease !important; min-height: 90px !important; } .upload-row [data-testid="file"]:hover, .upload-row [data-testid="file"].drag-over, .upload-row [data-testid="file"][data-drag-over="true"] { border-color: var(--accent) !important; background: var(--accent-soft) !important; transform: scale(1.005); } .upload-row [data-testid="file"] { cursor: pointer !important; } .upload-row [data-testid="file"] *, .upload-row .upload-text, .upload-row svg { color: var(--ink-muted) !important; fill: var(--ink-muted) !important; } .upload-row .file-preview, .upload-row [class*="FilePreview"] { display: none !important; } /* Hide the gr.File "X clear" button (we have our own Clear button) */ .upload-row [class*="Remove"], .upload-row button[aria-label*="lear"], .upload-row button[aria-label*="emove"], .upload-row .remove { display: none !important; } .upload-row .file { padding: 0 !important; background: transparent !important; } .upload-row [data-testid="file"] .download-link { display: none !important; } /* File chip (after upload) */ .file-chip { display: inline-flex; align-items: center; gap: 10px; padding: 8px 14px 8px 12px; background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: 999px; font-size: 13px; color: var(--ink) !important; } .file-chip span { color: var(--ink) !important; } .file-chip-meta { color: var(--ink-muted) !important; font-size: 12px; } .file-chip-dot { width: 6px; height: 6px; background: var(--accent); border-radius: 50%; flex-shrink: 0; } /* Hide the giant gr.File "uploaded file" display — we have our own chip */ .upload-row [data-testid="file"] .file-preview, .upload-row .file-preview-holder { display: none !important; } /* Question input */ .question-row { margin: 14px 0 8px; } textarea, .gr-text-input textarea, [data-testid="textbox"] textarea { background: var(--surface-raised) !important; border: 1px solid var(--ink-faint) !important; border-radius: var(--radius-sm) !important; color: var(--ink) !important; font-family: var(--font) !important; font-size: 15px !important; padding: 14px 16px !important; box-shadow: none !important; transition: border-color 150ms ease !important; line-height: 1.5 !important; } textarea:focus, [data-testid="textbox"] textarea:focus { border-color: var(--accent) !important; outline: none !important; box-shadow: 0 0 0 3px var(--accent-soft) !important; } textarea::placeholder { color: var(--ink-muted) !important; } .kb-hint { font-size: 11px; color: var(--ink-muted) !important; margin: 4px 4px 0; background: transparent !important; padding: 0 !important; } /* Question group container — make sure it's not boxed/dark */ .question-row, .question-row > div, .question-row .gr-block, .question-row .gr-form { background: transparent !important; border: none !important; box-shadow: none !important; padding: 0 !important; } /* Button row: keep all buttons in a single line, give Ask more weight */ .btn-row { gap: 8px !important; flex-wrap: nowrap !important; margin-top: 12px; } .btn-row > * { min-width: 0 !important; } /* Buttons */ button.primary, button[variant="primary"], .gr-button.primary { background: var(--ink) !important; color: var(--surface) !important; border: none !important; border-radius: var(--radius-sm) !important; font-family: var(--font) !important; font-weight: 500 !important; font-size: 14px !important; padding: 10px 18px !important; transition: opacity 150ms ease !important; box-shadow: none !important; } button.primary:hover { opacity: 0.85 !important; } button.secondary, button[variant="secondary"] { background: transparent !important; color: var(--ink) !important; border: 1px solid var(--ink-faint) !important; border-radius: var(--radius-sm) !important; font-weight: 500 !important; padding: 10px 18px !important; } /* Conversation */ .turn { margin: 32px 0; } .turn:first-child { margin-top: 16px; } .turn-question { font-size: 16px; color: var(--ink); font-weight: 500; margin-bottom: 14px; letter-spacing: -0.01em; line-height: 1.5; } .turn-progress { display: flex; align-items: center; gap: 10px; font-size: 13px; color: var(--ink-muted); padding: 12px 16px; background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: var(--radius-sm); margin: 6px 0; } .turn-progress::before { content: ""; width: 8px; height: 8px; background: var(--accent); border-radius: 50%; animation: pulse 1.2s ease-in-out infinite; } @keyframes pulse { 0%, 100% { opacity: 0.3; transform: scale(1); } 50% { opacity: 1; transform: scale(1.3); } } /* Pipeline stages — 4 dots with state */ .pipeline { display: flex; flex-direction: column; gap: 6px; padding: 16px 18px; background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: var(--radius-sm); margin: 6px 0; } .pipeline-step { display: flex; align-items: center; gap: 10px; font-size: 13px; color: var(--ink-muted) !important; transition: color 200ms ease; } .pipeline-step.done .pipeline-dot { background: var(--accent); border-color: var(--accent); } .pipeline-step.done .pipeline-label { color: var(--ink) !important; } .pipeline-step.active .pipeline-dot { background: var(--accent); border-color: var(--accent); animation: pulse 1.2s ease-in-out infinite; } .pipeline-step.active .pipeline-label { color: var(--ink) !important; font-weight: 500; } .pipeline-step.pending .pipeline-dot { background: transparent; border-color: var(--ink-faint); } .pipeline-dot { width: 8px; height: 8px; border-radius: 50%; border: 1.5px solid var(--ink-faint); flex-shrink: 0; } .pipeline-label { font-size: 13px; } /* Polished empty state — large icon for big empty panel */ .empty-large { padding: 80px 20px; display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 480px; } .empty-icon { color: var(--ink-muted); margin-bottom: 22px; animation: fadeIn 400ms ease-out; } .empty-large .empty-title { font-size: 16px; margin-bottom: 8px; } .empty-large .empty-sub { max-width: 360px; } /* Subtle entrance animations */ @keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } } @keyframes fadeInUp { from { opacity: 0; transform: translateY(8px); } to { opacity: 1; transform: translateY(0); } } .turn { animation: fadeInUp 250ms ease-out; } .chart-wrap { animation: fadeInUp 350ms ease-out 50ms backwards; } .narration { animation: fadeInUp 350ms ease-out 150ms backwards; } .downloads { animation: fadeInUp 350ms ease-out 200ms backwards; } .suggestion-chip { animation: fadeInUp 250ms ease-out backwards; } .suggestions .suggestion-chip:nth-child(1) { animation-delay: 50ms; } .suggestions .suggestion-chip:nth-child(2) { animation-delay: 100ms; } .suggestions .suggestion-chip:nth-child(3) { animation-delay: 150ms; } .suggestions .suggestion-chip:nth-child(4) { animation-delay: 200ms; } .schema-col { animation: fadeIn 200ms ease-out backwards; } .turn-error { background: var(--accent-soft); border-left: 3px solid var(--accent); color: var(--accent); padding: 12px 14px; border-radius: var(--radius-sm); font-size: 13px; margin: 6px 0; font-family: var(--font-mono); } .chart-wrap { background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: var(--radius); padding: 24px; margin: 8px 0 14px; box-shadow: var(--shadow-sm); transition: box-shadow 250ms ease, transform 250ms ease; position: relative; max-width: 880px; margin-left: auto; margin-right: auto; } .chart-wrap:hover { box-shadow: var(--shadow-md); } .chart-wrap svg { width: 100% !important; height: auto !important; max-height: 460px !important; display: block; margin: 0 auto; } .chart-badge { position: absolute; top: 12px; right: 16px; font-size: 10px; letter-spacing: 0.06em; font-weight: 600; text-transform: uppercase; color: var(--ink-muted); background: var(--surface); padding: 4px 8px; border-radius: 6px; border: 1px solid var(--ink-faint); } /* Code blocks with syntax highlighting + copy button */ .sql-wrap { position: relative; margin: 6px 0 0; } .sql-block { background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: var(--radius-sm); font-family: var(--font-mono); font-size: 12.5px; color: var(--ink); padding: 14px 16px; overflow-x: auto; white-space: pre-wrap; margin: 0; line-height: 1.65; } .sql-kw { color: var(--accent); font-weight: 600; } .sql-fn { color: var(--ink); font-weight: 500; } .sql-str { color: #6b7d4f; } .sql-num { color: #7d6b4f; } .sql-cmt { color: var(--ink-muted); font-style: italic; } .sql-copy { position: absolute; top: 8px; right: 8px; font-size: 11px; padding: 4px 9px; background: var(--surface); border: 1px solid var(--ink-faint); border-radius: 6px; color: var(--ink-muted); cursor: pointer; transition: all 150ms ease; font-family: var(--font); user-select: none; } .sql-copy:hover { color: var(--ink); border-color: var(--accent); } .sql-copy:active { background: var(--accent-soft); } /* Details / collapsibles */ details { margin: 8px 0; border: 1px solid var(--ink-faint); border-radius: var(--radius-sm); background: var(--surface-raised); } details summary { cursor: pointer; padding: 10px 14px; font-size: 12.5px; color: var(--ink-muted); list-style: none; user-select: none; font-weight: 500; } details summary::-webkit-details-marker { display: none; } details summary::before { content: "›"; display: inline-block; width: 12px; margin-right: 4px; transition: transform 150ms ease; color: var(--ink-muted); } details[open] summary::before { transform: rotate(90deg); } details > *:not(summary) { padding: 0 14px 14px; } /* Data table */ .data-table { width: 100%; border-collapse: collapse; font-size: 12px; font-family: var(--font); } .data-table th { text-align: left; font-weight: 600; color: var(--ink); padding: 8px 10px; border-bottom: 1px solid var(--ink-faint); white-space: nowrap; } .data-table td { padding: 7px 10px; color: var(--ink-muted); border-bottom: 1px solid var(--ink-faint); } .data-table tr:last-child td { border-bottom: none; } .data-table-meta { font-size: 11px; color: var(--ink-muted); margin-top: 8px; padding: 0 4px; } /* Empty state */ .empty { padding: 40px 0 8px; text-align: center; } .empty-title { font-size: 15px; color: var(--ink); font-weight: 500; margin-bottom: 6px; } .empty-sub { font-size: 13px; color: var(--ink-muted); margin-bottom: 28px; } .example-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); gap: 10px; max-width: 580px; margin: 0 auto; } .example-card { text-align: left; padding: 14px 16px; background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: var(--radius-sm); cursor: pointer; transition: all 150ms ease; } .example-card:hover { border-color: var(--accent); background: var(--accent-soft); transform: translateY(-1px); } .example-card-title { font-size: 13px; font-weight: 500; color: var(--ink); margin-bottom: 4px; } .example-card-meta { font-size: 11px; color: var(--ink-muted); } /* Suggestions */ .suggestions { display: flex; flex-wrap: wrap; gap: 6px; margin: 14px 0 0; } .suggestion-chip { font-size: 12px; padding: 6px 12px; background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: 999px; cursor: pointer; color: var(--ink-muted); transition: all 150ms ease; } .suggestion-chip:hover { border-color: var(--accent); color: var(--ink); } /* Hide labels Gradio adds */ .gr-form > label, label.svelte-1gfkn6j, .label-wrap { display: none !important; } /* Narration — analyst-style finding under the chart */ .narration { margin: 12px 2px 4px; font-size: 14px; line-height: 1.55; color: var(--ink) !important; letter-spacing: -0.005em; padding-left: 12px; border-left: 2px solid var(--accent); } /* Download links below chart */ .downloads { display: flex; gap: 6px; margin: 8px 0 4px; flex-wrap: wrap; } .download-link { display: inline-flex; align-items: center; gap: 6px; font-size: 12px; padding: 6px 12px; background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: var(--radius-sm); color: var(--ink-muted); text-decoration: none; transition: all 150ms ease; cursor: pointer; } .download-link:hover { border-color: var(--accent); color: var(--ink); background: var(--accent-soft); } .download-link .icon { font-family: var(--font-mono); font-size: 13px; line-height: 1; } /* Schema preview after upload — capped height so it doesn't push the question off screen on wide schemas */ .schema-preview { margin: 14px 0; padding: 14px 16px; background: var(--surface-raised); border: 1px solid var(--ink-faint); border-radius: var(--radius-sm); max-height: 220px; overflow-y: auto; } .schema-preview::-webkit-scrollbar { width: 6px; } .schema-preview::-webkit-scrollbar-thumb { background: var(--ink-faint); border-radius: 3px; } .schema-preview-header { font-size: 12px; font-weight: 600; color: var(--ink); text-transform: uppercase; letter-spacing: 0.04em; margin-bottom: 10px; } .schema-cols { display: flex; flex-wrap: wrap; gap: 6px; } .schema-col { display: inline-flex; align-items: baseline; gap: 5px; padding: 4px 10px; background: var(--surface); border: 1px solid var(--ink-faint); border-radius: 6px; font-size: 12px; } .schema-col-name { color: var(--ink); font-family: var(--font-mono); font-size: 12px; } .schema-col-type { color: var(--ink-muted); font-size: 10px; text-transform: uppercase; letter-spacing: 0.04em; } """ # ===================================================== ORCHESTRATOR (lazy) _AGENT: Optional[SQLAgentOrchestrator] = None def get_agent() -> SQLAgentOrchestrator: global _AGENT if _AGENT is None: _AGENT = SQLAgentOrchestrator(_SQL_GEN, _CHART_REASONER, _SVG_RENDERER) return _AGENT # =================================================== EXAMPLE DATA (built-in) def _make_titanic_csv() -> Path: """Tiny embedded Titanic-like sample so first-time users can play with no upload.""" p = ROOT / "_examples" / "titanic.csv" if p.exists(): return p p.parent.mkdir(parents=True, exist_ok=True) df = pd.DataFrame({ "passenger_id": range(1, 21), "survived": [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1], "pclass": [3, 1, 3, 1, 3, 3, 1, 3, 3, 2, 3, 1, 3, 3, 3, 2, 3, 2, 3, 3], "sex": ["male","female","female","female","male","male","male","male","female","female", "female","female","male","male","female","female","male","female","male","female"], "age": [22,38,26,35,35,None,54,2,27,14,4,58,20,39,14,55,2,None,31,None], "fare": [7.25,71.28,7.92,53.10,8.05,8.46,51.86,21.07,11.13,30.07, 16.70,26.55,8.05,31.27,7.85,16.00,29.13,13.00,18.00,7.23], "embarked": ["S","C","S","S","S","Q","S","S","S","C","S","S","S","S","Q","S","Q","S","S","Q"], }) df.to_csv(p, index=False) return p def _suggest_questions(table: str, schema: list[dict]) -> list[str]: """Generate question suggestions tailored to the loaded dataset's columns.""" if not schema: return [] NUMERIC = {"INTEGER", "BIGINT", "DOUBLE", "FLOAT", "DECIMAL", "NUMERIC", "REAL", "INT", "SMALLINT"} DATE = {"DATE", "TIMESTAMP", "DATETIME", "TIME"} STRING = {"VARCHAR", "STRING", "TEXT", "CHAR"} def kind(t: str) -> str: t = (t or "").upper().split("(")[0] if any(k in t for k in NUMERIC): return "num" if any(k in t for k in DATE): return "date" if any(k in t for k in STRING): return "str" return "other" cols = [(c["name"], kind(c.get("type", ""))) for c in schema] nums = [n for n, k in cols if k == "num"] dates = [n for n, k in cols if k == "date"] strs = [n for n, k in cols if k == "str"] qs: list[str] = [] if nums: qs.append(f"Top 10 rows by {nums[0]}") if strs and nums: qs.append(f"{nums[0].capitalize()} grouped by {strs[0]}") if strs: qs.append(f"Count of rows by {strs[0]}") if dates and nums: qs.append(f"{nums[0].capitalize()} over time ({dates[0]})") if len(nums) >= 2: qs.append(f"Compare {nums[0]} vs {nums[1]}") if not qs: qs.append(f"Show me the first 10 rows of {table}") return qs[:4] # =================================================== HTML render helpers def _file_chip_html(filename: str, rows: int, cols: int) -> str: return ( '
' '' f'{filename}' f'{rows:,} rows · {cols} cols' '
' ) def _schema_preview_html(table: str, schema: list[dict]) -> str: """Render the column names + types of the loaded table.""" if not schema: return "" cols = "".join( f'' f'{c["name"]}' f'{c["type"]}' f'' for c in schema ) return ( '
' f'
{table} · {len(schema)} columns
' f'
{cols}
' '
' ) def _download_links_html(sql: str, results: list[dict], svg: str) -> str: """Inline data-URL download links for CSV (results) and SVG (chart).""" import base64 import csv import io as _io parts = [] # CSV download if results: buf = _io.StringIO() writer = csv.DictWriter(buf, fieldnames=list(results[0].keys())) writer.writeheader() for r in results: writer.writerow({k: ("" if v is None else v) for k, v in r.items()}) csv_b64 = base64.b64encode(buf.getvalue().encode("utf-8")).decode("ascii") parts.append( f'' f' CSV ({len(results):,} rows)' ) # SVG download (standalone version: explicit dims, white bg, XML prolog) if svg and "' f' SVG' ) if not parts: return "" return f'
{"".join(parts)}
' def _suggestions_html(qs: list[str]) -> str: if not qs: return "" chips = "".join( f'{q}' for q in qs ) return f'
{chips}
' def _data_table_html(rows: list[dict], max_rows: int = 10) -> str: if not rows: return '
No rows.
' df = pd.DataFrame(rows[:max_rows]) cols = df.columns.tolist() head = "".join(f"{c}" for c in cols) body = "".join( "" + "".join( f"{('' if r.get(c) is None else r.get(c, ''))}" for c in cols ) + "" for r in rows[:max_rows] ) note = ( f'
Showing {min(max_rows, len(rows))} of {len(rows):,} rows
' if len(rows) > max_rows else "" ) return f'{head}{body}
{note}' PIPELINE_STAGES = [ ("sql", "Generating SQL"), ("execute", "Running query"), ("chart", "Designing chart"), ("render", "Rendering visualization"), ] def _turn_html_progress(question: str, current_stage: str = "sql") -> str: """Render a turn that's still in progress with multi-step pipeline view.""" items = [] seen_active = False for key, label in PIPELINE_STAGES: if key == current_stage: cls = "active" seen_active = True elif seen_active: cls = "pending" else: cls = "done" items.append( f'
' f'' f'{label}' f'
' ) return ( '
' f'
{question}
' f'
{"".join(items)}
' '
' ) SQL_KEYWORDS = { "SELECT", "FROM", "WHERE", "GROUP", "BY", "ORDER", "HAVING", "LIMIT", "OFFSET", "JOIN", "LEFT", "RIGHT", "INNER", "OUTER", "FULL", "ON", "AS", "AND", "OR", "NOT", "IN", "BETWEEN", "LIKE", "IS", "NULL", "DISTINCT", "WITH", "UNION", "ALL", "CASE", "WHEN", "THEN", "ELSE", "END", "ASC", "DESC", "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE", "CREATE", "TABLE", "DROP", "ALTER", } SQL_FUNCTIONS = { "COUNT", "SUM", "AVG", "MIN", "MAX", "DATE", "EXTRACT", "CAST", "COALESCE", "ROUND", "ABS", "LENGTH", "LOWER", "UPPER", "SUBSTR", "TRIM", "REPLACE", "CONCAT", "STRFTIME", "DATEDIFF", "NOW", } def _highlight_sql(sql: str) -> str: """Tokenize and add span highlighting to a SQL query.""" import re as _re, html as _html out = [] i = 0 n = len(sql) while i < n: c = sql[i] # Comments: -- to end of line if c == "-" and i + 1 < n and sql[i + 1] == "-": j = sql.find("\n", i) if j == -1: j = n out.append(f'{_html.escape(sql[i:j])}') i = j # Strings elif c in ("'", '"'): quote = c j = i + 1 while j < n and sql[j] != quote: if sql[j] == "\\" and j + 1 < n: j += 2 else: j += 1 j = min(j + 1, n) out.append(f'{_html.escape(sql[i:j])}') i = j # Numbers elif c.isdigit(): j = i while j < n and (sql[j].isdigit() or sql[j] == "."): j += 1 out.append(f'{sql[i:j]}') i = j # Identifiers / keywords / functions elif c.isalpha() or c == "_": j = i while j < n and (sql[j].isalnum() or sql[j] == "_"): j += 1 word = sql[i:j] upper = word.upper() if upper in SQL_KEYWORDS: out.append(f'{word}') elif upper in SQL_FUNCTIONS: out.append(f'{word}') else: out.append(_html.escape(word)) i = j else: out.append(_html.escape(c)) i += 1 return "".join(out) def _chart_badge(spec: dict, results_count: int) -> str: """Small uppercase badge showing chart type and row count.""" if not spec: return "" chart_type = (spec.get("chart_type") or "?").upper() if results_count > 30: return f'
{chart_type} · top 30 of {results_count}
' return f'
{chart_type} · {results_count} rows
' def _turn_html_complete(result: dict) -> str: """Render a finished turn.""" parts: list[str] = [f'
{result["question"]}
'] if result.get("error"): parts.append(f'
{result["error"]}
') if result.get("svg"): badge = _chart_badge(result.get("chart_spec"), len(result.get("results") or [])) parts.append( f'
{badge}{result["svg"]}
' ) # Narration: 1-2 sentence finding from the analyst persona if result.get("narration"): parts.append(f'
{result["narration"]}
') # Inline download links for CSV + SVG parts.append(_download_links_html( result.get("sql") or "", result.get("results") or [], result.get("svg") or "", )) if result.get("sql"): sql_safe = _highlight_sql(result["sql"]) sql_raw = result["sql"].replace("`", "\\`").replace("\\", "\\\\") copy_js = ( f"navigator.clipboard.writeText(`{sql_raw}`);" "this.textContent='Copied';setTimeout(()=>{this.textContent='Copy'},1200);" ) parts.append( '
SQL query' '
' f'' f'
{sql_safe}
' '
' '
' ) if result.get("results"): parts.append( '
Data' f'{_data_table_html(result["results"])}' '
' ) return f'
{"".join(parts)}
' def _conversation_html(history: list[dict], in_progress: tuple[str, str] | None = None) -> str: """Conversation HTML reacts to: data loaded? history? in-progress? in_progress is a tuple of (question, current_stage_key).""" has_data = bool(get_agent().list_tables()) has_turns = bool(history) or in_progress is not None if not has_data and not has_turns: return _empty_state_html() if has_data and not has_turns: return _ready_state_html() out = "".join(_turn_html_complete(t) for t in history) if in_progress: out += _turn_html_progress(in_progress[0], in_progress[1]) return out def _empty_state_html() -> str: """No data loaded — polished placeholder for the right panel.""" icon = ( '' '' '' '' '' '' '' '' ) return ( '
' f'
{icon}
' '
No data loaded
' '
Upload a CSV, JSON, Parquet or Excel file ' 'on the left, or click Demo to try sample data.
' '
' ) def _ready_state_html() -> str: """Shown when data is loaded but no queries asked yet. Suggestions are derived from the actual loaded table's columns.""" agent = get_agent() tables = agent.list_tables() suggestions: list[str] = [] if tables: schema = agent.executor.get_table_schema(tables[0]) suggestions = _suggest_questions(tables[0], schema) return ( '
' '
Ready
' '
Ask a question above, or try one of these:
' f'{_suggestions_html(suggestions)}' '
' ) # ============================================================ EVENT HANDLERS def _build_schema_html(table: str) -> str: agent = get_agent() schema = agent.executor.get_table_schema(table) return _schema_preview_html(table, schema) def _placeholder_for_table(table: str) -> str: """Return a smart placeholder based on the loaded table's columns.""" try: agent = get_agent() if not agent.list_tables(): return "Ask anything about your data…" schema = agent.executor.get_table_schema(table) suggestions = _suggest_questions(table, schema) if suggestions: return f"Try: {suggestions[0]}" except Exception: pass return "Ask anything about your data…" def on_upload(file): """Returns: (chip, schema_html, conversation, history, question_update).""" if file is None: return "", "", _conversation_html([]), [], gr.update(placeholder="Ask anything about your data…") agent = get_agent() agent.reset() try: path = Path(file.name if hasattr(file, "name") else file) table = agent.load_data(path) rows = agent.executor.con.execute(f'SELECT COUNT(*) FROM "{table}"').fetchone()[0] cols = len(agent.executor.get_table_schema(table)) chip = _file_chip_html(path.name, rows, cols) schema = _build_schema_html(table) placeholder = _placeholder_for_table(table) return chip, schema, _conversation_html([]), [], gr.update(placeholder=placeholder) except Exception as e: logger.exception("upload failed") return "", "", f'
Could not load file: {e}
', [], gr.update() def on_load_demo(): """Returns: (chip, schema_html, conversation, history, question_update).""" agent = get_agent() agent.reset() try: p = _make_titanic_csv() table = agent.load_data(p) rows = agent.executor.con.execute(f'SELECT COUNT(*) FROM "{table}"').fetchone()[0] cols = len(agent.executor.get_table_schema(table)) chip = _file_chip_html("titanic.csv (demo)", rows, cols) schema = _build_schema_html(table) placeholder = _placeholder_for_table(table) return chip, schema, _conversation_html([]), [], gr.update(placeholder=placeholder) except Exception as e: logger.exception("demo load failed") return "", "", f'
Could not load demo: {e}
', [], gr.update() @spaces.GPU(duration=60) def _gpu_process(question: str) -> dict: """Inference only — models already on cuda from module-level loading.""" agent = get_agent() return agent.process(question) def on_ask(question: str, history: list) -> Generator[Tuple[str, str, list], None, None]: """ Generator: yields conversation HTML at each pipeline step so the user sees real-time progress instead of waiting silently. """ # Single-shot mode: each Ask is a fresh query, replaces any previous result question = (question or "").strip() if not question: yield _conversation_html([]), "", [] return if not get_agent().list_tables(): result = { "question": question, "error": "Upload a file first or load the demo dataset.", } yield _conversation_html([result]), "", [result] return # Stream pipeline stages so user sees progress (4 visual steps). # The actual GPU call is one shot (yielding mid-call would break the # @spaces.GPU window) but we still surface the stages around it. yield _conversation_html([], in_progress=(question, "sql")), "", [] try: result = _gpu_process(question) except Exception as e: logger.exception("ask failed") result = {"question": question, "error": str(e)} yield _conversation_html([result]), "", [result] def on_reset(): """Clear ONLY the dataset (file, chip, schema, agent state). Also wipes the displayed result since it's now stale without data. Order: upload, chip_html, schema_html, conversation, history_state """ get_agent().reset() return None, "", "", _conversation_html([]), [] # ====================================================================== APP def build_app() -> gr.Blocks: with gr.Blocks( theme=gr.themes.Base(), css=THEME_CSS, title="SQL Agent", analytics_enabled=False, ) as demo: # Header gr.HTML( '
' '
' '
SQL Agent
' '
Ask anything about your data.
' '
' '
' ) # Two-column layout: LEFT = data/question controls, RIGHT = results with gr.Row(elem_classes=["split-layout"]): # ---------- LEFT panel ---------- with gr.Column(elem_classes=["split-left"], scale=0): gr.HTML('
Dataset
') with gr.Row(elem_classes=["upload-row"]): upload = gr.File( label="", file_types=[".csv", ".json", ".parquet", ".xlsx", ".xls"], show_label=False, container=False, ) chip_html = gr.HTML("") schema_html = gr.HTML("") gr.HTML('
Question
') with gr.Group(elem_classes=["question-row"]): question = gr.Textbox( placeholder="Ask anything about your data…", lines=3, max_lines=10, show_label=False, container=False, autofocus=True, ) gr.HTML('
Press Enter to send · Shift+Enter for newline
') with gr.Row(elem_classes=["btn-row"]): ask_btn = gr.Button("Ask", variant="primary", size="sm", scale=2) reset_btn = gr.Button("Clear", variant="secondary", size="sm", scale=1) demo_btn = gr.Button("Demo", variant="secondary", size="sm", elem_id="load_demo_btn", scale=1) # ---------- RIGHT panel ---------- with gr.Column(elem_classes=["split-right"], scale=1): gr.HTML('
Result
') history_state = gr.State([]) conversation = gr.HTML(_conversation_html([])) # ------------- events ------------- upload.upload( fn=on_upload, inputs=upload, outputs=[chip_html, schema_html, conversation, history_state, question], api_name=False, ) upload.clear( fn=lambda: ("", "", _conversation_html([]), [], gr.update(placeholder="Ask anything about your data…")), outputs=[chip_html, schema_html, conversation, history_state, question], api_name=False, ) demo_btn.click( fn=on_load_demo, outputs=[chip_html, schema_html, conversation, history_state, question], api_name=False, ) ask_btn.click( fn=on_ask, inputs=[question, history_state], outputs=[conversation, question, history_state], api_name=False, ) question.submit( fn=on_ask, inputs=[question, history_state], outputs=[conversation, question, history_state], api_name=False, ) reset_btn.click( fn=on_reset, outputs=[upload, chip_html, schema_html, conversation, history_state], api_name=False, ) return demo if __name__ == "__main__": app = build_app() app.queue(api_open=False).launch( server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), show_api=False, )