Spaces:
Running on Zero
Running on Zero
Cap chart size: max-width 880px and max-height 460px so charts have natural proportions
180e51f verified | """ | |
| SQL Agent — Gradio app for Hugging Face Spaces (ZeroGPU). | |
| Apple x Claude minimalist design with progressive feedback during the | |
| multi-step pipeline. | |
| """ | |
| import logging | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Generator, Optional, Tuple | |
| import pandas as pd | |
# Directory that contains this file. Used below as the import root and as the
# base directory for the bundled example CSV.
ROOT = Path(__file__).parent
# Prepend ROOT to the import path ahead of the `src.*` imports further down
# (presumably so they resolve regardless of the launch directory — confirm).
sys.path.insert(0, str(ROOT))
# Root logger configuration: "timestamp | LEVEL | message" at INFO and above.
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)
| import gradio as gr # noqa: E402 | |
try:
    import spaces  # type: ignore
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class _SpacesShim:
        """No-op stand-in used when the `spaces` package is not installed.

        Only the ``GPU`` decorator is provided; it returns the decorated
        function unchanged so the same code runs without the package.
        """

        @staticmethod
        def GPU(duration=60):  # noqa: N802
            """Return the wrapped function untouched.

            Declared as a staticmethod because the shim is used through an
            instance (``spaces.GPU``); without it the instance would be bound
            to ``duration`` and ``spaces.GPU(duration=...)`` would raise
            ``TypeError``.

            Supports both decorator spellings:
              * ``@spaces.GPU`` — the function arrives as the first argument.
              * ``@spaces.GPU(duration=120)`` / ``@spaces.GPU()`` — a
                decorator is returned; ``duration`` is ignored.
            """
            if callable(duration):
                # Bare ``@spaces.GPU`` usage: the "duration" is the function.
                return duration

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()  # type: ignore
# CRITICAL: load all 3 models on cuda at module level per ZeroGPU best
# practice. PyTorch CUDA emulation handles this when no real GPU is present;
# inside @spaces.GPU calls, the real GPU is used and inference is fast.
# NOTE(review): device placement happens inside the model classes, which are
# not visible from this file — confirm they target cuda at construction time.
logger.info("Loading models at module level...")
# Project imports come after the sys.path tweak above, hence the E402
# suppressions.
from src.models.sql_generator import SQLGenerator  # noqa: E402
from src.models.chart_reasoner import ChartReasoner  # noqa: E402
from src.models.svg_renderer import SVGRenderer  # noqa: E402
from src.orchestrator.pipeline import SQLAgentOrchestrator  # noqa: E402
# Module-level singletons: constructed once at import time and later passed
# to the orchestrator by get_agent().
_SQL_GEN = SQLGenerator()
_CHART_REASONER = ChartReasoner()
_SVG_RENDERER = SVGRenderer()
logger.info("All models loaded")
| # ============================================================ THEME / CSS | |
| THEME_CSS = """ | |
| :root { | |
| --ink: #0E0E0E; | |
| --ink-muted: #5A5A5A; | |
| --ink-faint: #E5E5E5; | |
| --surface: #FAFAF9; | |
| --surface-raised: #FFFFFF; | |
| --accent: #C96442; | |
| --accent-soft: rgba(201, 100, 66, 0.08); | |
| --radius: 16px; | |
| --radius-sm: 10px; | |
| --shadow-sm: 0 1px 2px rgba(0,0,0,0.04); | |
| --shadow-md: 0 6px 24px rgba(0,0,0,0.08); | |
| --font: -apple-system, BlinkMacSystemFont, "SF Pro Text", "SF Pro Display", | |
| "Helvetica Neue", Arial, sans-serif; | |
| --font-mono: "SF Mono", ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; | |
| } | |
| @media (prefers-color-scheme: dark) { | |
| :root { | |
| --ink: #F4F4F2; | |
| --ink-muted: #8A8A8A; | |
| --ink-faint: #2A2A2A; | |
| --surface: #0E0E0E; | |
| --surface-raised: #161616; | |
| --accent: #E8866A; | |
| --accent-soft: rgba(232, 134, 106, 0.10); | |
| --shadow-sm: 0 1px 2px rgba(0,0,0,0.25); | |
| --shadow-md: 0 6px 24px rgba(0,0,0,0.45); | |
| } | |
| } | |
| /* Gradio container reset (Gradio 5 selectors) */ | |
| html, body, gradio-app, .gradio-container, .main, .app, .contain, .wrap, | |
| gradio-app > .main, gradio-app .gradio-container { | |
| background: var(--surface) !important; | |
| color: var(--ink) !important; | |
| font-family: var(--font) !important; | |
| } | |
| /* Fluid container — no max-width, fills the entire iframe to remove | |
| the dark side bars that appeared when content was capped at 1600px. */ | |
| .gradio-container { | |
| max-width: none !important; | |
| width: 100% !important; | |
| margin: 0 !important; | |
| padding: 28px 48px 48px !important; | |
| min-height: 100vh !important; | |
| box-sizing: border-box; | |
| } | |
| /* Two-column rectangular layout */ | |
| .split-layout { | |
| display: grid; | |
| grid-template-columns: minmax(480px, 580px) 1fr; | |
| gap: 36px; | |
| align-items: stretch; | |
| min-height: 720px; | |
| } | |
| @media (max-width: 900px) { | |
| .split-layout { grid-template-columns: 1fr; gap: 22px; min-height: auto; } | |
| .gradio-container { padding: 24px 18px 60px !important; min-height: auto !important; } | |
| } | |
| .split-left { position: sticky; top: 24px; } | |
| .split-right { min-height: 720px; display: flex; flex-direction: column; } | |
| .split-right > .panel-label { flex-shrink: 0; } | |
| .panel-label { | |
| font-size: 11px; | |
| font-weight: 600; | |
| text-transform: uppercase; | |
| letter-spacing: 0.06em; | |
| color: var(--ink-muted) !important; | |
| margin: 0 0 10px 2px; | |
| } | |
| footer { display: none !important; } | |
| .show-api { display: none !important; } | |
| /* Header */ | |
| .app-header { | |
| margin-bottom: 28px; | |
| padding-bottom: 20px; | |
| border-bottom: 1px solid var(--ink-faint); | |
| display: flex; | |
| align-items: baseline; | |
| justify-content: space-between; | |
| gap: 16px; | |
| } | |
| .app-title { font-size: 18px; font-weight: 600; letter-spacing: -0.015em; color: var(--ink) !important; } | |
| .app-subtitle { font-size: 13px; color: var(--ink-muted) !important; } | |
| /* HF Login button — make it Apple-style */ | |
| [data-testid="login-button"], .login-button, button[aria-label*="login"], button[aria-label*="Login"] { | |
| background: var(--ink) !important; | |
| color: var(--surface) !important; | |
| border: none !important; | |
| border-radius: var(--radius-sm) !important; | |
| font-weight: 500 !important; | |
| padding: 8px 14px !important; | |
| font-size: 13px !important; | |
| margin-bottom: 18px !important; | |
| } | |
| /* File upload — compact, Apple-style */ | |
| .upload-row { margin-bottom: 14px; } | |
| .upload-row .gr-file, .upload-row .file-preview { background: transparent !important; } | |
| .upload-row [data-testid="file"] { | |
| border: 1.5px dashed var(--ink-faint) !important; | |
| border-radius: var(--radius) !important; | |
| padding: 20px 16px !important; | |
| background: transparent !important; | |
| transition: all 200ms ease !important; | |
| min-height: 90px !important; | |
| } | |
| .upload-row [data-testid="file"]:hover, | |
| .upload-row [data-testid="file"].drag-over, | |
| .upload-row [data-testid="file"][data-drag-over="true"] { | |
| border-color: var(--accent) !important; | |
| background: var(--accent-soft) !important; | |
| transform: scale(1.005); | |
| } | |
| .upload-row [data-testid="file"] { | |
| cursor: pointer !important; | |
| } | |
| .upload-row [data-testid="file"] *, | |
| .upload-row .upload-text, | |
| .upload-row svg { color: var(--ink-muted) !important; fill: var(--ink-muted) !important; } | |
| .upload-row .file-preview, .upload-row [class*="FilePreview"] { display: none !important; } | |
| /* Hide the gr.File "X clear" button (we have our own Clear button) */ | |
| .upload-row [class*="Remove"], .upload-row button[aria-label*="lear"], | |
| .upload-row button[aria-label*="emove"], .upload-row .remove { display: none !important; } | |
| .upload-row .file { padding: 0 !important; background: transparent !important; } | |
| .upload-row [data-testid="file"] .download-link { display: none !important; } | |
| /* File chip (after upload) */ | |
| .file-chip { | |
| display: inline-flex; | |
| align-items: center; | |
| gap: 10px; | |
| padding: 8px 14px 8px 12px; | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: 999px; | |
| font-size: 13px; | |
| color: var(--ink) !important; | |
| } | |
| .file-chip span { color: var(--ink) !important; } | |
| .file-chip-meta { color: var(--ink-muted) !important; font-size: 12px; } | |
| .file-chip-dot { | |
| width: 6px; height: 6px; | |
| background: var(--accent); | |
| border-radius: 50%; | |
| flex-shrink: 0; | |
| } | |
| /* Hide the giant gr.File "uploaded file" display — we have our own chip */ | |
| .upload-row [data-testid="file"] .file-preview, | |
| .upload-row .file-preview-holder { | |
| display: none !important; | |
| } | |
| /* Question input */ | |
| .question-row { margin: 14px 0 8px; } | |
| textarea, .gr-text-input textarea, [data-testid="textbox"] textarea { | |
| background: var(--surface-raised) !important; | |
| border: 1px solid var(--ink-faint) !important; | |
| border-radius: var(--radius-sm) !important; | |
| color: var(--ink) !important; | |
| font-family: var(--font) !important; | |
| font-size: 15px !important; | |
| padding: 14px 16px !important; | |
| box-shadow: none !important; | |
| transition: border-color 150ms ease !important; | |
| line-height: 1.5 !important; | |
| } | |
| textarea:focus, [data-testid="textbox"] textarea:focus { | |
| border-color: var(--accent) !important; | |
| outline: none !important; | |
| box-shadow: 0 0 0 3px var(--accent-soft) !important; | |
| } | |
| textarea::placeholder { color: var(--ink-muted) !important; } | |
| .kb-hint { | |
| font-size: 11px; | |
| color: var(--ink-muted) !important; | |
| margin: 4px 4px 0; | |
| background: transparent !important; | |
| padding: 0 !important; | |
| } | |
| /* Question group container — make sure it's not boxed/dark */ | |
| .question-row, .question-row > div, .question-row .gr-block, .question-row .gr-form { | |
| background: transparent !important; | |
| border: none !important; | |
| box-shadow: none !important; | |
| padding: 0 !important; | |
| } | |
| /* Button row: keep all buttons in a single line, give Ask more weight */ | |
| .btn-row { gap: 8px !important; flex-wrap: nowrap !important; margin-top: 12px; } | |
| .btn-row > * { min-width: 0 !important; } | |
| /* Buttons */ | |
| button.primary, button[variant="primary"], .gr-button.primary { | |
| background: var(--ink) !important; | |
| color: var(--surface) !important; | |
| border: none !important; | |
| border-radius: var(--radius-sm) !important; | |
| font-family: var(--font) !important; | |
| font-weight: 500 !important; | |
| font-size: 14px !important; | |
| padding: 10px 18px !important; | |
| transition: opacity 150ms ease !important; | |
| box-shadow: none !important; | |
| } | |
| button.primary:hover { opacity: 0.85 !important; } | |
| button.secondary, button[variant="secondary"] { | |
| background: transparent !important; | |
| color: var(--ink) !important; | |
| border: 1px solid var(--ink-faint) !important; | |
| border-radius: var(--radius-sm) !important; | |
| font-weight: 500 !important; | |
| padding: 10px 18px !important; | |
| } | |
| /* Conversation */ | |
| .turn { margin: 32px 0; } | |
| .turn:first-child { margin-top: 16px; } | |
| .turn-question { | |
| font-size: 16px; | |
| color: var(--ink); | |
| font-weight: 500; | |
| margin-bottom: 14px; | |
| letter-spacing: -0.01em; | |
| line-height: 1.5; | |
| } | |
| .turn-progress { | |
| display: flex; | |
| align-items: center; | |
| gap: 10px; | |
| font-size: 13px; | |
| color: var(--ink-muted); | |
| padding: 12px 16px; | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: var(--radius-sm); | |
| margin: 6px 0; | |
| } | |
| .turn-progress::before { | |
| content: ""; | |
| width: 8px; height: 8px; | |
| background: var(--accent); | |
| border-radius: 50%; | |
| animation: pulse 1.2s ease-in-out infinite; | |
| } | |
| @keyframes pulse { | |
| 0%, 100% { opacity: 0.3; transform: scale(1); } | |
| 50% { opacity: 1; transform: scale(1.3); } | |
| } | |
| /* Pipeline stages — 4 dots with state */ | |
| .pipeline { | |
| display: flex; | |
| flex-direction: column; | |
| gap: 6px; | |
| padding: 16px 18px; | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: var(--radius-sm); | |
| margin: 6px 0; | |
| } | |
| .pipeline-step { | |
| display: flex; | |
| align-items: center; | |
| gap: 10px; | |
| font-size: 13px; | |
| color: var(--ink-muted) !important; | |
| transition: color 200ms ease; | |
| } | |
| .pipeline-step.done .pipeline-dot { | |
| background: var(--accent); | |
| border-color: var(--accent); | |
| } | |
| .pipeline-step.done .pipeline-label { | |
| color: var(--ink) !important; | |
| } | |
| .pipeline-step.active .pipeline-dot { | |
| background: var(--accent); | |
| border-color: var(--accent); | |
| animation: pulse 1.2s ease-in-out infinite; | |
| } | |
| .pipeline-step.active .pipeline-label { | |
| color: var(--ink) !important; | |
| font-weight: 500; | |
| } | |
| .pipeline-step.pending .pipeline-dot { | |
| background: transparent; | |
| border-color: var(--ink-faint); | |
| } | |
| .pipeline-dot { | |
| width: 8px; height: 8px; | |
| border-radius: 50%; | |
| border: 1.5px solid var(--ink-faint); | |
| flex-shrink: 0; | |
| } | |
| .pipeline-label { | |
| font-size: 13px; | |
| } | |
| /* Polished empty state — large icon for big empty panel */ | |
| .empty-large { | |
| padding: 80px 20px; | |
| display: flex; | |
| flex-direction: column; | |
| align-items: center; | |
| justify-content: center; | |
| min-height: 480px; | |
| } | |
| .empty-icon { | |
| color: var(--ink-muted); | |
| margin-bottom: 22px; | |
| animation: fadeIn 400ms ease-out; | |
| } | |
| .empty-large .empty-title { font-size: 16px; margin-bottom: 8px; } | |
| .empty-large .empty-sub { max-width: 360px; } | |
| /* Subtle entrance animations */ | |
| @keyframes fadeIn { | |
| from { opacity: 0; } | |
| to { opacity: 1; } | |
| } | |
| @keyframes fadeInUp { | |
| from { opacity: 0; transform: translateY(8px); } | |
| to { opacity: 1; transform: translateY(0); } | |
| } | |
| .turn { animation: fadeInUp 250ms ease-out; } | |
| .chart-wrap { animation: fadeInUp 350ms ease-out 50ms backwards; } | |
| .narration { animation: fadeInUp 350ms ease-out 150ms backwards; } | |
| .downloads { animation: fadeInUp 350ms ease-out 200ms backwards; } | |
| .suggestion-chip { | |
| animation: fadeInUp 250ms ease-out backwards; | |
| } | |
| .suggestions .suggestion-chip:nth-child(1) { animation-delay: 50ms; } | |
| .suggestions .suggestion-chip:nth-child(2) { animation-delay: 100ms; } | |
| .suggestions .suggestion-chip:nth-child(3) { animation-delay: 150ms; } | |
| .suggestions .suggestion-chip:nth-child(4) { animation-delay: 200ms; } | |
| .schema-col { animation: fadeIn 200ms ease-out backwards; } | |
| .turn-error { | |
| background: var(--accent-soft); | |
| border-left: 3px solid var(--accent); | |
| color: var(--accent); | |
| padding: 12px 14px; | |
| border-radius: var(--radius-sm); | |
| font-size: 13px; | |
| margin: 6px 0; | |
| font-family: var(--font-mono); | |
| } | |
| .chart-wrap { | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: var(--radius); | |
| padding: 24px; | |
| margin: 8px 0 14px; | |
| box-shadow: var(--shadow-sm); | |
| transition: box-shadow 250ms ease, transform 250ms ease; | |
| position: relative; | |
| max-width: 880px; | |
| margin-left: auto; | |
| margin-right: auto; | |
| } | |
| .chart-wrap:hover { box-shadow: var(--shadow-md); } | |
| .chart-wrap svg { | |
| width: 100% !important; | |
| height: auto !important; | |
| max-height: 460px !important; | |
| display: block; | |
| margin: 0 auto; | |
| } | |
| .chart-badge { | |
| position: absolute; | |
| top: 12px; | |
| right: 16px; | |
| font-size: 10px; | |
| letter-spacing: 0.06em; | |
| font-weight: 600; | |
| text-transform: uppercase; | |
| color: var(--ink-muted); | |
| background: var(--surface); | |
| padding: 4px 8px; | |
| border-radius: 6px; | |
| border: 1px solid var(--ink-faint); | |
| } | |
| /* Code blocks with syntax highlighting + copy button */ | |
| .sql-wrap { position: relative; margin: 6px 0 0; } | |
| .sql-block { | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: var(--radius-sm); | |
| font-family: var(--font-mono); | |
| font-size: 12.5px; | |
| color: var(--ink); | |
| padding: 14px 16px; | |
| overflow-x: auto; | |
| white-space: pre-wrap; | |
| margin: 0; | |
| line-height: 1.65; | |
| } | |
| .sql-kw { color: var(--accent); font-weight: 600; } | |
| .sql-fn { color: var(--ink); font-weight: 500; } | |
| .sql-str { color: #6b7d4f; } | |
| .sql-num { color: #7d6b4f; } | |
| .sql-cmt { color: var(--ink-muted); font-style: italic; } | |
| .sql-copy { | |
| position: absolute; | |
| top: 8px; | |
| right: 8px; | |
| font-size: 11px; | |
| padding: 4px 9px; | |
| background: var(--surface); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: 6px; | |
| color: var(--ink-muted); | |
| cursor: pointer; | |
| transition: all 150ms ease; | |
| font-family: var(--font); | |
| user-select: none; | |
| } | |
| .sql-copy:hover { color: var(--ink); border-color: var(--accent); } | |
| .sql-copy:active { background: var(--accent-soft); } | |
| /* Details / collapsibles */ | |
| details { | |
| margin: 8px 0; | |
| border: 1px solid var(--ink-faint); | |
| border-radius: var(--radius-sm); | |
| background: var(--surface-raised); | |
| } | |
| details summary { | |
| cursor: pointer; | |
| padding: 10px 14px; | |
| font-size: 12.5px; | |
| color: var(--ink-muted); | |
| list-style: none; | |
| user-select: none; | |
| font-weight: 500; | |
| } | |
| details summary::-webkit-details-marker { display: none; } | |
| details summary::before { | |
| content: "›"; | |
| display: inline-block; | |
| width: 12px; | |
| margin-right: 4px; | |
| transition: transform 150ms ease; | |
| color: var(--ink-muted); | |
| } | |
| details[open] summary::before { transform: rotate(90deg); } | |
| details > *:not(summary) { padding: 0 14px 14px; } | |
| /* Data table */ | |
| .data-table { | |
| width: 100%; | |
| border-collapse: collapse; | |
| font-size: 12px; | |
| font-family: var(--font); | |
| } | |
| .data-table th { | |
| text-align: left; | |
| font-weight: 600; | |
| color: var(--ink); | |
| padding: 8px 10px; | |
| border-bottom: 1px solid var(--ink-faint); | |
| white-space: nowrap; | |
| } | |
| .data-table td { | |
| padding: 7px 10px; | |
| color: var(--ink-muted); | |
| border-bottom: 1px solid var(--ink-faint); | |
| } | |
| .data-table tr:last-child td { border-bottom: none; } | |
| .data-table-meta { font-size: 11px; color: var(--ink-muted); margin-top: 8px; padding: 0 4px; } | |
| /* Empty state */ | |
| .empty { | |
| padding: 40px 0 8px; | |
| text-align: center; | |
| } | |
| .empty-title { | |
| font-size: 15px; | |
| color: var(--ink); | |
| font-weight: 500; | |
| margin-bottom: 6px; | |
| } | |
| .empty-sub { | |
| font-size: 13px; | |
| color: var(--ink-muted); | |
| margin-bottom: 28px; | |
| } | |
| .example-grid { | |
| display: grid; | |
| grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); | |
| gap: 10px; | |
| max-width: 580px; | |
| margin: 0 auto; | |
| } | |
| .example-card { | |
| text-align: left; | |
| padding: 14px 16px; | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: var(--radius-sm); | |
| cursor: pointer; | |
| transition: all 150ms ease; | |
| } | |
| .example-card:hover { | |
| border-color: var(--accent); | |
| background: var(--accent-soft); | |
| transform: translateY(-1px); | |
| } | |
| .example-card-title { | |
| font-size: 13px; | |
| font-weight: 500; | |
| color: var(--ink); | |
| margin-bottom: 4px; | |
| } | |
| .example-card-meta { | |
| font-size: 11px; | |
| color: var(--ink-muted); | |
| } | |
| /* Suggestions */ | |
| .suggestions { | |
| display: flex; | |
| flex-wrap: wrap; | |
| gap: 6px; | |
| margin: 14px 0 0; | |
| } | |
| .suggestion-chip { | |
| font-size: 12px; | |
| padding: 6px 12px; | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: 999px; | |
| cursor: pointer; | |
| color: var(--ink-muted); | |
| transition: all 150ms ease; | |
| } | |
| .suggestion-chip:hover { | |
| border-color: var(--accent); | |
| color: var(--ink); | |
| } | |
| /* Hide labels Gradio adds */ | |
| .gr-form > label, label.svelte-1gfkn6j, .label-wrap { display: none !important; } | |
| /* Narration — analyst-style finding under the chart */ | |
| .narration { | |
| margin: 12px 2px 4px; | |
| font-size: 14px; | |
| line-height: 1.55; | |
| color: var(--ink) !important; | |
| letter-spacing: -0.005em; | |
| padding-left: 12px; | |
| border-left: 2px solid var(--accent); | |
| } | |
| /* Download links below chart */ | |
| .downloads { | |
| display: flex; | |
| gap: 6px; | |
| margin: 8px 0 4px; | |
| flex-wrap: wrap; | |
| } | |
| .download-link { | |
| display: inline-flex; | |
| align-items: center; | |
| gap: 6px; | |
| font-size: 12px; | |
| padding: 6px 12px; | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: var(--radius-sm); | |
| color: var(--ink-muted); | |
| text-decoration: none; | |
| transition: all 150ms ease; | |
| cursor: pointer; | |
| } | |
| .download-link:hover { | |
| border-color: var(--accent); | |
| color: var(--ink); | |
| background: var(--accent-soft); | |
| } | |
| .download-link .icon { | |
| font-family: var(--font-mono); | |
| font-size: 13px; | |
| line-height: 1; | |
| } | |
| /* Schema preview after upload — capped height so it doesn't push | |
| the question off screen on wide schemas */ | |
| .schema-preview { | |
| margin: 14px 0; | |
| padding: 14px 16px; | |
| background: var(--surface-raised); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: var(--radius-sm); | |
| max-height: 220px; | |
| overflow-y: auto; | |
| } | |
| .schema-preview::-webkit-scrollbar { width: 6px; } | |
| .schema-preview::-webkit-scrollbar-thumb { | |
| background: var(--ink-faint); | |
| border-radius: 3px; | |
| } | |
| .schema-preview-header { | |
| font-size: 12px; | |
| font-weight: 600; | |
| color: var(--ink); | |
| text-transform: uppercase; | |
| letter-spacing: 0.04em; | |
| margin-bottom: 10px; | |
| } | |
| .schema-cols { | |
| display: flex; | |
| flex-wrap: wrap; | |
| gap: 6px; | |
| } | |
| .schema-col { | |
| display: inline-flex; | |
| align-items: baseline; | |
| gap: 5px; | |
| padding: 4px 10px; | |
| background: var(--surface); | |
| border: 1px solid var(--ink-faint); | |
| border-radius: 6px; | |
| font-size: 12px; | |
| } | |
| .schema-col-name { color: var(--ink); font-family: var(--font-mono); font-size: 12px; } | |
| .schema-col-type { color: var(--ink-muted); font-size: 10px; text-transform: uppercase; letter-spacing: 0.04em; } | |
| """ | |
| # ===================================================== ORCHESTRATOR (lazy) | |
# Cached orchestrator instance; populated on the first get_agent() call.
_AGENT: Optional[SQLAgentOrchestrator] = None


def get_agent() -> SQLAgentOrchestrator:
    """Return the shared orchestrator, creating it on first use.

    The orchestrator is built from the three module-level model singletons
    and then reused for every subsequent call.
    """
    global _AGENT
    if _AGENT is not None:
        return _AGENT
    _AGENT = SQLAgentOrchestrator(_SQL_GEN, _CHART_REASONER, _SVG_RENDERER)
    return _AGENT
| # =================================================== EXAMPLE DATA (built-in) | |
def _make_titanic_csv() -> Path:
    """Return the path of the bundled Titanic-like sample CSV.

    The 20-row file lives at ``ROOT/_examples/titanic.csv``. It is written
    (creating the directory if needed) only when it does not already exist,
    so first-time users can try the app without uploading anything.
    """
    sample_path = ROOT / "_examples" / "titanic.csv"
    if not sample_path.exists():
        sample_path.parent.mkdir(parents=True, exist_ok=True)
        sample_columns = {
            "passenger_id": range(1, 21),
            "survived": [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1],
            "pclass": [3, 1, 3, 1, 3, 3, 1, 3, 3, 2, 3, 1, 3, 3, 3, 2, 3, 2, 3, 3],
            "sex": [
                "male", "female", "female", "female", "male",
                "male", "male", "male", "female", "female",
                "female", "female", "male", "male", "female",
                "female", "male", "female", "male", "female",
            ],
            "age": [22, 38, 26, 35, 35, None, 54, 2, 27, 14, 4, 58, 20, 39, 14, 55, 2, None, 31, None],
            "fare": [
                7.25, 71.28, 7.92, 53.10, 8.05, 8.46, 51.86, 21.07, 11.13, 30.07,
                16.70, 26.55, 8.05, 31.27, 7.85, 16.00, 29.13, 13.00, 18.00, 7.23,
            ],
            "embarked": [
                "S", "C", "S", "S", "S", "Q", "S", "S", "S", "C",
                "S", "S", "S", "S", "Q", "S", "Q", "S", "S", "Q",
            ],
        }
        pd.DataFrame(sample_columns).to_csv(sample_path, index=False)
    return sample_path
| def _suggest_questions(table: str, schema: list[dict]) -> list[str]: | |
| """Generate question suggestions tailored to the loaded dataset's columns.""" | |
| if not schema: | |
| return [] | |
| NUMERIC = {"INTEGER", "BIGINT", "DOUBLE", "FLOAT", "DECIMAL", "NUMERIC", "REAL", "INT", "SMALLINT"} | |
| DATE = {"DATE", "TIMESTAMP", "DATETIME", "TIME"} | |
| STRING = {"VARCHAR", "STRING", "TEXT", "CHAR"} | |
| def kind(t: str) -> str: | |
| t = (t or "").upper().split("(")[0] | |
| if any(k in t for k in NUMERIC): return "num" | |
| if any(k in t for k in DATE): return "date" | |
| if any(k in t for k in STRING): return "str" | |
| return "other" | |
| cols = [(c["name"], kind(c.get("type", ""))) for c in schema] | |
| nums = [n for n, k in cols if k == "num"] | |
| dates = [n for n, k in cols if k == "date"] | |
| strs = [n for n, k in cols if k == "str"] | |
| qs: list[str] = [] | |
| if nums: | |
| qs.append(f"Top 10 rows by {nums[0]}") | |
| if strs and nums: | |
| qs.append(f"{nums[0].capitalize()} grouped by {strs[0]}") | |
| if strs: | |
| qs.append(f"Count of rows by {strs[0]}") | |
| if dates and nums: | |
| qs.append(f"{nums[0].capitalize()} over time ({dates[0]})") | |
| if len(nums) >= 2: | |
| qs.append(f"Compare {nums[0]} vs {nums[1]}") | |
| if not qs: | |
| qs.append(f"Show me the first 10 rows of {table}") | |
| return qs[:4] | |
| # =================================================== HTML render helpers | |
| def _file_chip_html(filename: str, rows: int, cols: int) -> str: | |
| return ( | |
| '<div class="file-chip">' | |
| '<span class="file-chip-dot"></span>' | |
| f'<span>{filename}</span>' | |
| f'<span class="file-chip-meta">{rows:,} rows · {cols} cols</span>' | |
| '</div>' | |
| ) | |
| def _schema_preview_html(table: str, schema: list[dict]) -> str: | |
| """Render the column names + types of the loaded table.""" | |
| if not schema: | |
| return "" | |
| cols = "".join( | |
| f'<span class="schema-col">' | |
| f'<span class="schema-col-name">{c["name"]}</span>' | |
| f'<span class="schema-col-type">{c["type"]}</span>' | |
| f'</span>' | |
| for c in schema | |
| ) | |
| return ( | |
| '<div class="schema-preview">' | |
| f'<div class="schema-preview-header">{table} · {len(schema)} columns</div>' | |
| f'<div class="schema-cols">{cols}</div>' | |
| '</div>' | |
| ) | |
| def _download_links_html(sql: str, results: list[dict], svg: str) -> str: | |
| """Inline data-URL download links for CSV (results) and SVG (chart).""" | |
| import base64 | |
| import csv | |
| import io as _io | |
| parts = [] | |
| # CSV download | |
| if results: | |
| buf = _io.StringIO() | |
| writer = csv.DictWriter(buf, fieldnames=list(results[0].keys())) | |
| writer.writeheader() | |
| for r in results: | |
| writer.writerow({k: ("" if v is None else v) for k, v in r.items()}) | |
| csv_b64 = base64.b64encode(buf.getvalue().encode("utf-8")).decode("ascii") | |
| parts.append( | |
| f'<a class="download-link" ' | |
| f'href="data:text/csv;base64,{csv_b64}" ' | |
| f'download="query-result.csv">' | |
| f'<span class="icon">↓</span> CSV ({len(results):,} rows)</a>' | |
| ) | |
| # SVG download (standalone version: explicit dims, white bg, XML prolog) | |
| if svg and "<svg" in svg.lower(): | |
| from src.visualization.svg_theme import to_standalone_svg | |
| standalone = to_standalone_svg(svg) | |
| svg_b64 = base64.b64encode(standalone.encode("utf-8")).decode("ascii") | |
| parts.append( | |
| f'<a class="download-link" ' | |
| f'href="data:image/svg+xml;base64,{svg_b64}" ' | |
| f'download="chart.svg">' | |
| f'<span class="icon">↓</span> SVG</a>' | |
| ) | |
| if not parts: | |
| return "" | |
| return f'<div class="downloads">{"".join(parts)}</div>' | |
| def _suggestions_html(qs: list[str]) -> str: | |
| if not qs: | |
| return "" | |
| chips = "".join( | |
| f'<span class="suggestion-chip" onclick="document.querySelector(\'textarea\').value=this.textContent;document.querySelector(\'textarea\').focus();">{q}</span>' | |
| for q in qs | |
| ) | |
| return f'<div class="suggestions">{chips}</div>' | |
| def _data_table_html(rows: list[dict], max_rows: int = 10) -> str: | |
| if not rows: | |
| return '<div class="empty-sub" style="padding:8px 0">No rows.</div>' | |
| df = pd.DataFrame(rows[:max_rows]) | |
| cols = df.columns.tolist() | |
| head = "".join(f"<th>{c}</th>" for c in cols) | |
| body = "".join( | |
| "<tr>" + "".join( | |
| f"<td>{('' if r.get(c) is None else r.get(c, ''))}</td>" for c in cols | |
| ) + "</tr>" | |
| for r in rows[:max_rows] | |
| ) | |
| note = ( | |
| f'<div class="data-table-meta">Showing {min(max_rows, len(rows))} of {len(rows):,} rows</div>' | |
| if len(rows) > max_rows else "" | |
| ) | |
| return f'<table class="data-table"><thead><tr>{head}</tr></thead><tbody>{body}</tbody></table>{note}' | |
| PIPELINE_STAGES = [ | |
| ("sql", "Generating SQL"), | |
| ("execute", "Running query"), | |
| ("chart", "Designing chart"), | |
| ("render", "Rendering visualization"), | |
| ] | |
| def _turn_html_progress(question: str, current_stage: str = "sql") -> str: | |
| """Render a turn that's still in progress with multi-step pipeline view.""" | |
| items = [] | |
| seen_active = False | |
| for key, label in PIPELINE_STAGES: | |
| if key == current_stage: | |
| cls = "active" | |
| seen_active = True | |
| elif seen_active: | |
| cls = "pending" | |
| else: | |
| cls = "done" | |
| items.append( | |
| f'<div class="pipeline-step {cls}">' | |
| f'<span class="pipeline-dot"></span>' | |
| f'<span class="pipeline-label">{label}</span>' | |
| f'</div>' | |
| ) | |
| return ( | |
| '<div class="turn">' | |
| f'<div class="turn-question">{question}</div>' | |
| f'<div class="pipeline">{"".join(items)}</div>' | |
| '</div>' | |
| ) | |
| SQL_KEYWORDS = { | |
| "SELECT", "FROM", "WHERE", "GROUP", "BY", "ORDER", "HAVING", "LIMIT", | |
| "OFFSET", "JOIN", "LEFT", "RIGHT", "INNER", "OUTER", "FULL", "ON", | |
| "AS", "AND", "OR", "NOT", "IN", "BETWEEN", "LIKE", "IS", "NULL", | |
| "DISTINCT", "WITH", "UNION", "ALL", "CASE", "WHEN", "THEN", "ELSE", | |
| "END", "ASC", "DESC", "INSERT", "INTO", "VALUES", "UPDATE", "SET", | |
| "DELETE", "CREATE", "TABLE", "DROP", "ALTER", | |
| } | |
| SQL_FUNCTIONS = { | |
| "COUNT", "SUM", "AVG", "MIN", "MAX", "DATE", "EXTRACT", "CAST", | |
| "COALESCE", "ROUND", "ABS", "LENGTH", "LOWER", "UPPER", "SUBSTR", | |
| "TRIM", "REPLACE", "CONCAT", "STRFTIME", "DATEDIFF", "NOW", | |
| } | |
| def _highlight_sql(sql: str) -> str: | |
| """Tokenize and add span highlighting to a SQL query.""" | |
| import re as _re, html as _html | |
| out = [] | |
| i = 0 | |
| n = len(sql) | |
| while i < n: | |
| c = sql[i] | |
| # Comments: -- to end of line | |
| if c == "-" and i + 1 < n and sql[i + 1] == "-": | |
| j = sql.find("\n", i) | |
| if j == -1: j = n | |
| out.append(f'<span class="sql-cmt">{_html.escape(sql[i:j])}</span>') | |
| i = j | |
| # Strings | |
| elif c in ("'", '"'): | |
| quote = c | |
| j = i + 1 | |
| while j < n and sql[j] != quote: | |
| if sql[j] == "\\" and j + 1 < n: | |
| j += 2 | |
| else: | |
| j += 1 | |
| j = min(j + 1, n) | |
| out.append(f'<span class="sql-str">{_html.escape(sql[i:j])}</span>') | |
| i = j | |
| # Numbers | |
| elif c.isdigit(): | |
| j = i | |
| while j < n and (sql[j].isdigit() or sql[j] == "."): | |
| j += 1 | |
| out.append(f'<span class="sql-num">{sql[i:j]}</span>') | |
| i = j | |
| # Identifiers / keywords / functions | |
| elif c.isalpha() or c == "_": | |
| j = i | |
| while j < n and (sql[j].isalnum() or sql[j] == "_"): | |
| j += 1 | |
| word = sql[i:j] | |
| upper = word.upper() | |
| if upper in SQL_KEYWORDS: | |
| out.append(f'<span class="sql-kw">{word}</span>') | |
| elif upper in SQL_FUNCTIONS: | |
| out.append(f'<span class="sql-fn">{word}</span>') | |
| else: | |
| out.append(_html.escape(word)) | |
| i = j | |
| else: | |
| out.append(_html.escape(c)) | |
| i += 1 | |
| return "".join(out) | |
| def _chart_badge(spec: dict, results_count: int) -> str: | |
| """Small uppercase badge showing chart type and row count.""" | |
| if not spec: | |
| return "" | |
| chart_type = (spec.get("chart_type") or "?").upper() | |
| if results_count > 30: | |
| return f'<div class="chart-badge">{chart_type} · top 30 of {results_count}</div>' | |
| return f'<div class="chart-badge">{chart_type} · {results_count} rows</div>' | |
def _turn_html_complete(result: dict) -> str:
    """Render a finished turn.

    Sections are emitted in this order, each only when its data is present:
    question, error, chart (with badge), narration, download links,
    collapsible SQL (with copy button) and collapsible data table.

    Args:
        result: Turn record. ``"question"`` is required; ``"error"``,
            ``"svg"``, ``"chart_spec"``, ``"results"``, ``"narration"`` and
            ``"sql"`` are optional.

    Returns:
        The ``turn`` HTML fragment.
    """
    import html as _html
    import json as _json

    # Question and error are free text (user input / exception messages), so
    # they are escaped before being embedded in markup.
    parts: list[str] = [
        f'<div class="turn-question">{_html.escape(str(result["question"]))}</div>'
    ]
    if result.get("error"):
        parts.append(f'<div class="turn-error">{_html.escape(str(result["error"]))}</div>')

    if result.get("svg"):
        badge = _chart_badge(result.get("chart_spec"), len(result.get("results") or []))
        # The SVG is markup by design and is embedded as-is.
        parts.append(f'<div class="chart-wrap">{badge}{result["svg"]}</div>')

    # Narration: 1-2 sentence finding from the analyst persona.
    # NOTE(review): embedded unescaped; it is unclear from this file whether
    # it may legitimately contain markup — confirm and escape if it is plain
    # text.
    if result.get("narration"):
        parts.append(f'<div class="narration">{result["narration"]}</div>')

    # Inline download links for CSV + SVG.
    parts.append(_download_links_html(
        result.get("sql") or "",
        result.get("results") or [],
        result.get("svg") or "",
    ))

    if result.get("sql"):
        sql_safe = _highlight_sql(result["sql"])
        # Emit the SQL as a JSON string literal (a valid JS string), then
        # escape the whole handler for the HTML attribute. The previous
        # hand-rolled template-literal escaping doubled the backslash it had
        # just added in front of each backtick, and any double quote in the
        # SQL terminated the onclick attribute early.
        copy_js = (
            f"navigator.clipboard.writeText({_json.dumps(result['sql'])});"
            "this.textContent='Copied';setTimeout(()=>{this.textContent='Copy'},1200);"
        )
        parts.append(
            '<details open><summary>SQL query</summary>'
            '<div class="sql-wrap">'
            f'<button class="sql-copy" onclick="{_html.escape(copy_js, quote=True)}">Copy</button>'
            f'<pre class="sql-block">{sql_safe}</pre>'
            '</div>'
            '</details>'
        )

    if result.get("results"):
        parts.append(
            '<details><summary>Data</summary>'
            f'{_data_table_html(result["results"])}'
            '</details>'
        )

    return f'<div class="turn">{"".join(parts)}</div>'
def _conversation_html(history: list[dict], in_progress: tuple[str, str] | None = None) -> str:
    """Build the result-panel HTML for the current application state.

    Args:
        history: Finished turns, each rendered with ``_turn_html_complete``.
        in_progress: Optional ``(question, current_stage_key)`` pair for a
            turn that is still running.

    Returns:
        The empty-state placeholder when there is neither data nor a turn,
        the ready-state panel when data is loaded but nothing has been
        asked, otherwise the rendered turns followed by the in-progress turn
        (if any).
    """
    tables_loaded = bool(get_agent().list_tables())
    nothing_to_show = not history and in_progress is None
    if nothing_to_show:
        return _ready_state_html() if tables_loaded else _empty_state_html()

    rendered = [_turn_html_complete(turn) for turn in history]
    if in_progress:
        pending_question = in_progress[0]
        stage_key = in_progress[1]
        rendered.append(_turn_html_progress(pending_question, stage_key))
    return "".join(rendered)
| def _empty_state_html() -> str: | |
| """No data loaded — polished placeholder for the right panel.""" | |
| icon = ( | |
| '<svg viewBox="0 0 64 64" width="64" height="64" fill="none" ' | |
| 'stroke="currentColor" stroke-width="1.25" stroke-linecap="round" ' | |
| 'stroke-linejoin="round" style="opacity:0.35">' | |
| '<rect x="8" y="14" width="48" height="38" rx="4"/>' | |
| '<line x1="8" y1="24" x2="56" y2="24"/>' | |
| '<line x1="22" y1="14" x2="22" y2="52"/>' | |
| '<line x1="36" y1="14" x2="36" y2="52"/>' | |
| '<circle cx="46" cy="36" r="6"/>' | |
| '<line x1="50" y1="40" x2="56" y2="46"/>' | |
| '</svg>' | |
| ) | |
| return ( | |
| '<div class="empty empty-large">' | |
| f'<div class="empty-icon">{icon}</div>' | |
| '<div class="empty-title">No data loaded</div>' | |
| '<div class="empty-sub">Upload a CSV, JSON, Parquet or Excel file ' | |
| 'on the left, or click <strong>Demo</strong> to try sample data.</div>' | |
| '</div>' | |
| ) | |
def _ready_state_html() -> str:
    """Return the panel shown when data is loaded but nothing has been asked.

    Suggested questions come from ``_suggest_questions`` applied to the
    first table the agent lists and that table's schema; with no tables the
    suggestion list is empty.
    """
    agent = get_agent()
    loaded_tables = agent.list_tables()
    if loaded_tables:
        first_table = loaded_tables[0]
        first_schema = agent.executor.get_table_schema(first_table)
        ideas: list[str] = _suggest_questions(first_table, first_schema)
    else:
        ideas = []
    return (
        '<div class="empty">'
        '<div class="empty-title">Ready</div>'
        '<div class="empty-sub">Ask a question above, or try one of these:</div>'
        f'{_suggestions_html(ideas)}'
        '</div>'
    )
| # ============================================================ EVENT HANDLERS | |
def _build_schema_html(table: str) -> str:
    """Return the schema-preview HTML for ``table``.

    Looks the table's schema up through the agent's executor and hands it to
    ``_schema_preview_html``.
    """
    table_schema = get_agent().executor.get_table_schema(table)
    return _schema_preview_html(table, table_schema)
def _placeholder_for_table(table: str) -> str:
    """Return a question-box placeholder tailored to the loaded table.

    Args:
        table: Name of the table whose schema drives the suggestion.

    Returns:
        ``"Try: <first suggestion>"`` when a suggestion can be derived,
        otherwise the generic placeholder. This is best-effort: any failure
        while inspecting the table falls back to the generic text and is
        logged at debug level rather than silently discarded.
    """
    default_placeholder = "Ask anything about your data…"
    try:
        agent = get_agent()
        if not agent.list_tables():
            return default_placeholder
        schema = agent.executor.get_table_schema(table)
        suggestions = _suggest_questions(table, schema)
        if suggestions:
            return f"Try: {suggestions[0]}"
    except Exception:
        # A nicer placeholder is cosmetic, so never let it break the caller;
        # keep a trace so the failure can still be diagnosed.
        logger.debug("Could not derive placeholder for table %r", table, exc_info=True)
    return default_placeholder
def on_upload(file):
    """Load an uploaded file into the agent and refresh the left panel.

    Args:
        file: Value delivered by the upload component: ``None``, an object
            with a ``name`` attribute holding the path, or the path itself.

    Returns:
        A 5-tuple ``(chip, schema_html, conversation, history,
        question_update)``. On failure the chip and schema are empty, the
        conversation slot carries an error box, and the question box is left
        untouched.
    """
    import html as _html

    if file is None:
        return "", "", _conversation_html([]), [], gr.update(placeholder="Ask anything about your data…")
    agent = get_agent()
    # Drop any previously loaded dataset before loading the new one.
    agent.reset()
    try:
        path = Path(file.name if hasattr(file, "name") else file)
        table = agent.load_data(path)
        # Quote the identifier, doubling embedded double quotes so an unusual
        # table name cannot terminate the quoted identifier early.
        quoted_table = '"' + str(table).replace('"', '""') + '"'
        rows = agent.executor.con.execute(f"SELECT COUNT(*) FROM {quoted_table}").fetchone()[0]
        cols = len(agent.executor.get_table_schema(table))
        chip = _file_chip_html(path.name, rows, cols)
        schema = _build_schema_html(table)
        placeholder = _placeholder_for_table(table)
        return chip, schema, _conversation_html([]), [], gr.update(placeholder=placeholder)
    except Exception as e:
        logger.exception("upload failed")
        # Exception text may echo file contents or names; escape it before
        # placing it in markup.
        message = _html.escape(str(e))
        return "", "", f'<div class="turn-error">Could not load file: {message}</div>', [], gr.update()
def on_load_demo():
    """Load the bundled Titanic demo dataset and refresh the left panel.

    Returns:
        A 5-tuple ``(chip, schema_html, conversation, history,
        question_update)``. On failure the chip and schema are empty, the
        conversation slot carries an error box, and the question box is left
        untouched.
    """
    import html as _html

    agent = get_agent()
    # Drop any previously loaded dataset before loading the demo.
    agent.reset()
    try:
        p = _make_titanic_csv()
        table = agent.load_data(p)
        # Quote the identifier, doubling embedded double quotes, so the
        # statement stays well-formed whatever name the loader picked.
        quoted_table = '"' + str(table).replace('"', '""') + '"'
        rows = agent.executor.con.execute(f"SELECT COUNT(*) FROM {quoted_table}").fetchone()[0]
        cols = len(agent.executor.get_table_schema(table))
        chip = _file_chip_html("titanic.csv (demo)", rows, cols)
        schema = _build_schema_html(table)
        placeholder = _placeholder_for_table(table)
        return chip, schema, _conversation_html([]), [], gr.update(placeholder=placeholder)
    except Exception as e:
        logger.exception("demo load failed")
        # Escape the exception text before placing it in markup.
        message = _html.escape(str(e))
        return "", "", f'<div class="turn-error">Could not load demo: {message}</div>', [], gr.update()
def _gpu_process(question: str) -> dict:
    """Run the agent pipeline for one question and return its result dict.

    Inference only: the models are loaded at module level, so nothing is
    loaded here.

    NOTE(review): no ``@spaces.GPU`` decorator is visible on this function in
    this part of the file — confirm the GPU window is requested elsewhere.
    """
    return get_agent().process(question)
def on_ask(question: str, history: list) -> Generator[Tuple[str, str, list], None, None]:
    """Answer one question, yielding UI updates as the work proceeds.

    Each yielded tuple is ``(conversation_html, question_box_value,
    history)``; the question box value is always ``""``. Every Ask is
    single-shot: the incoming ``history`` is not read, and the yielded
    history holds at most the new turn.

    Yields:
        * For a blank question: one state with no turns.
        * With no table loaded: one error turn asking for data.
        * Otherwise: an in-progress state at the ``"sql"`` stage, then the
          finished turn (or an error turn if processing raised).
    """
    asked = (question or "").strip()
    if not asked:
        yield _conversation_html([]), "", []
        return

    if not get_agent().list_tables():
        no_data_turn = {
            "question": asked,
            "error": "Upload a file first or load the demo dataset.",
        }
        yield _conversation_html([no_data_turn]), "", [no_data_turn]
        return

    # Show the progress view before the blocking call. The GPU work runs in
    # one shot (yielding mid-call would break the @spaces.GPU window), so the
    # stages are surfaced around it rather than from inside it.
    yield _conversation_html([], in_progress=(asked, "sql")), "", []

    try:
        outcome = _gpu_process(asked)
    except Exception as exc:
        logger.exception("ask failed")
        outcome = {"question": asked, "error": str(exc)}

    yield _conversation_html([outcome]), "", [outcome]
def on_reset():
    """Clear the loaded dataset and everything derived from it.

    Resets the agent, then returns blank values for, in order: upload,
    chip_html, schema_html, conversation, history_state. The displayed
    result is wiped too because it would be stale without its data.
    """
    agent = get_agent()
    agent.reset()
    # Rendered after the reset so it reflects the now-empty agent.
    cleared_conversation = _conversation_html([])
    return None, "", "", cleared_conversation, []
| # ====================================================================== APP | |
def build_app() -> gr.Blocks:
    """Assemble the two-column Gradio UI and wire its event handlers.

    The left column holds the dataset upload, file chip, schema preview,
    question box and buttons; the right column holds the result panel.

    NOTE(review): the container nesting below was reconstructed from a copy
    of this file whose indentation had been lost — confirm where the chip,
    schema preview and keyboard hint sit relative to their rows/groups.

    Returns:
        The constructed ``gr.Blocks`` app; launching is left to the caller.
    """
    header_markup = (
        '<div class="app-header">'
        '<div>'
        '<div class="app-title">SQL Agent</div>'
        '<div class="app-subtitle">Ask anything about your data.</div>'
        '</div>'
        '</div>'
    )

    with gr.Blocks(
        theme=gr.themes.Base(),
        css=THEME_CSS,
        title="SQL Agent",
        analytics_enabled=False,
    ) as demo:
        gr.HTML(header_markup)

        # Two columns: LEFT = data/question controls, RIGHT = results.
        with gr.Row(elem_classes=["split-layout"]):
            # ---------- LEFT panel ----------
            with gr.Column(elem_classes=["split-left"], scale=0):
                gr.HTML('<div class="panel-label">Dataset</div>')
                with gr.Row(elem_classes=["upload-row"]):
                    upload = gr.File(
                        label="",
                        file_types=[".csv", ".json", ".parquet", ".xlsx", ".xls"],
                        show_label=False,
                        container=False,
                    )
                chip_html = gr.HTML("")
                schema_html = gr.HTML("")

                gr.HTML('<div class="panel-label" style="margin-top:18px">Question</div>')
                with gr.Group(elem_classes=["question-row"]):
                    question = gr.Textbox(
                        placeholder="Ask anything about your data…",
                        lines=3,
                        max_lines=10,
                        show_label=False,
                        container=False,
                        autofocus=True,
                    )
                gr.HTML('<div class="kb-hint">Press Enter to send · Shift+Enter for newline</div>')
                with gr.Row(elem_classes=["btn-row"]):
                    ask_btn = gr.Button("Ask", variant="primary", size="sm", scale=2)
                    reset_btn = gr.Button("Clear", variant="secondary", size="sm", scale=1)
                    demo_btn = gr.Button(
                        "Demo",
                        variant="secondary",
                        size="sm",
                        elem_id="load_demo_btn",
                        scale=1,
                    )

            # ---------- RIGHT panel ----------
            with gr.Column(elem_classes=["split-right"], scale=1):
                gr.HTML('<div class="panel-label">Result</div>')
                history_state = gr.State([])
                conversation = gr.HTML(_conversation_html([]))

        # ------------- events -------------
        # Components refreshed whenever the dataset changes; copied per
        # registration so no two listeners share one list object.
        dataset_outputs = (chip_html, schema_html, conversation, history_state, question)

        def _on_upload_cleared():
            # Removing the file blanks the left panel and restores the
            # generic placeholder.
            return (
                "",
                "",
                _conversation_html([]),
                [],
                gr.update(placeholder="Ask anything about your data…"),
            )

        upload.upload(
            fn=on_upload,
            inputs=upload,
            outputs=list(dataset_outputs),
            api_name=False,
        )
        upload.clear(
            fn=_on_upload_cleared,
            outputs=list(dataset_outputs),
            api_name=False,
        )
        demo_btn.click(
            fn=on_load_demo,
            outputs=list(dataset_outputs),
            api_name=False,
        )

        # The Ask button and submitting the textbox run the same handler.
        for register_ask in (ask_btn.click, question.submit):
            register_ask(
                fn=on_ask,
                inputs=[question, history_state],
                outputs=[conversation, question, history_state],
                api_name=False,
            )

        reset_btn.click(
            fn=on_reset,
            outputs=[upload, chip_html, schema_html, conversation, history_state],
            api_name=False,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, enable the queue with the API closed,
    # and serve on all interfaces. The port comes from $PORT, default 7860.
    app = build_app()
    queued_app = app.queue(api_open=False)
    listen_port = int(os.environ.get("PORT", 7860))
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        show_api=False,
    )