ar9avg commited on
Commit
f0b682f
·
1 Parent(s): 68ebe84

fix: GEPA current_generation, task_id mapping, Connect DB button, remove difficulty from header

Browse files
backend/api/demo.py CHANGED
@@ -27,7 +27,16 @@ from env.database import (
27
  get_schema_info,
28
  get_schema_graph,
29
  execute_query,
 
 
30
  )
 
 
 
 
 
 
 
31
  from env.tasks import TASKS, get_task
32
  from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, _clean_sql
33
  from rl.environment import get_bandit_state
@@ -46,7 +55,22 @@ router = APIRouter()
46
  async def init_db():
47
  seeded = ensure_seeded()
48
  tables = get_table_stats()
49
- return {"tables": tables, "seeded": seeded}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  # ─── /api/execute-query ───────────────────────────────────────────
@@ -60,10 +84,12 @@ class ExecuteQueryRequest(BaseModel):
60
  async def execute_query_stream(req: ExecuteQueryRequest):
61
  async def event_generator() -> AsyncIterator[dict]:
62
  env = get_env()
63
- obs = env.reset(req.task_id)
 
 
64
 
65
  # Pick first question of task matching question text, or default
66
- task = get_task(req.task_id)
67
  question_obj = task.questions[0]
68
  # Override question text
69
  env._episode.question = req.question # type: ignore[union-attr]
@@ -155,7 +181,7 @@ async def execute_query_stream(req: ExecuteQueryRequest):
155
 
156
  from env.tasks import grade_response
157
  task_score = grade_response(
158
- req.task_id, question_obj.id, generated_sql, rows, error, attempt
159
  )
160
  attempt_success = task_score >= 0.8
161
 
@@ -303,7 +329,8 @@ class BenchmarkRequest(BaseModel):
303
  @router.post("/benchmark")
304
  async def run_benchmark(req: BenchmarkRequest):
305
  async def event_generator() -> AsyncIterator[dict]:
306
- task = get_task(req.task_id)
 
307
  scores: list[float] = []
308
 
309
  for question_obj in task.questions:
@@ -315,7 +342,7 @@ async def run_benchmark(req: BenchmarkRequest):
315
 
316
  # Run the question through the env
317
  env = SQLAgentEnv()
318
- obs = env.reset_with_question(req.task_id, question_obj.id)
319
 
320
  attempt = 0
321
  sql = ""
@@ -373,7 +400,7 @@ async def run_benchmark(req: BenchmarkRequest):
373
  rows, error = execute_query(sql)
374
  from env.tasks import grade_response
375
  task_score = grade_response(
376
- req.task_id, question_obj.id, sql, rows, error, attempt
377
  )
378
  success = task_score >= 0.8
379
 
@@ -428,7 +455,7 @@ async def run_benchmark(req: BenchmarkRequest):
428
  yield {"data": json.dumps({
429
  "type": "done",
430
  "overall_score": overall_score,
431
- "task_id": req.task_id,
432
  })}
433
 
434
  return EventSourceResponse(event_generator())
 
27
  get_schema_info,
28
  get_schema_graph,
29
  execute_query,
30
+ connect_external_db,
31
+ get_active_db_label,
32
  )
33
+
34
+ # Map frontend difficulty names → backend task IDs
35
+ _DIFFICULTY_MAP = {
36
+ "easy": "simple_queries",
37
+ "medium": "join_queries",
38
+ "hard": "complex_queries",
39
+ }
40
  from env.tasks import TASKS, get_task
41
  from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, _clean_sql
42
  from rl.environment import get_bandit_state
 
55
  async def init_db():
56
  seeded = ensure_seeded()
57
  tables = get_table_stats()
58
+ return {"tables": tables, "seeded": seeded, "dbLabel": get_active_db_label()}
59
+
60
+
61
+ # ─── /api/connect-db ──────────────────────────────────────────────
62
+
63
+ class ConnectDbRequest(BaseModel):
64
+ path: str # SQLite file path or :memory:
65
+
66
+
67
+ @router.post("/connect-db")
68
+ async def connect_db(req: ConnectDbRequest):
69
+ success, message = connect_external_db(req.path)
70
+ if success:
71
+ tables = get_table_stats()
72
+ return {"success": True, "message": message, "tables": tables, "dbLabel": get_active_db_label()}
73
+ return {"success": False, "message": message, "tables": [], "dbLabel": get_active_db_label()}
74
 
75
 
76
  # ─── /api/execute-query ───────────────────────────────────────────
 
84
  async def execute_query_stream(req: ExecuteQueryRequest):
85
  async def event_generator() -> AsyncIterator[dict]:
86
  env = get_env()
87
+ # Accept difficulty names ('easy'/'medium'/'hard') or direct task IDs
88
+ task_id = _DIFFICULTY_MAP.get(req.task_id, req.task_id)
89
+ obs = env.reset(task_id)
90
 
91
  # Pick first question of task matching question text, or default
92
+ task = get_task(task_id)
93
  question_obj = task.questions[0]
94
  # Override question text
95
  env._episode.question = req.question # type: ignore[union-attr]
 
181
 
182
  from env.tasks import grade_response
183
  task_score = grade_response(
184
+ task_id, question_obj.id, generated_sql, rows, error, attempt
185
  )
186
  attempt_success = task_score >= 0.8
187
 
 
329
  @router.post("/benchmark")
330
  async def run_benchmark(req: BenchmarkRequest):
331
  async def event_generator() -> AsyncIterator[dict]:
332
+ task_id = _DIFFICULTY_MAP.get(req.task_id, req.task_id)
333
+ task = get_task(task_id)
334
  scores: list[float] = []
335
 
336
  for question_obj in task.questions:
 
342
 
343
  # Run the question through the env
344
  env = SQLAgentEnv()
345
+ obs = env.reset_with_question(task_id, question_obj.id)
346
 
347
  attempt = 0
348
  sql = ""
 
400
  rows, error = execute_query(sql)
401
  from env.tasks import grade_response
402
  task_score = grade_response(
403
+ task_id, question_obj.id, sql, rows, error, attempt
404
  )
405
  success = task_score >= 0.8
406
 
 
455
  yield {"data": json.dumps({
456
  "type": "done",
457
  "overall_score": overall_score,
458
+ "task_id": task_id,
459
  })}
460
 
461
  return EventSourceResponse(event_generator())
backend/env/database.py CHANGED
@@ -21,6 +21,34 @@ from typing import Any
21
  _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
22
  DB_PATH = _DATA_DIR / "benchmark.db"
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # ─── Schema ───────────────────────────────────────────────────────
26
 
@@ -350,10 +378,14 @@ def get_schema_info() -> str:
350
  """
351
  Return a concise textual schema summary for use in prompts.
352
  """
353
- conn = sqlite3.connect(str(DB_PATH))
354
  try:
 
 
 
 
355
  lines = []
356
- for table in ["sellers", "users", "products", "orders", "reviews"]:
357
  info = conn.execute(f"PRAGMA table_info({table})").fetchall()
358
  cols = ", ".join(
359
  f"{col[1]} {col[2]}{'(PK)' if col[5] else ''}"
@@ -371,7 +403,7 @@ def execute_query(sql: str) -> tuple[list[dict], str | None]:
371
  Execute a SQL query and return (rows, error_message).
372
  rows is a list of dicts; error_message is None on success.
373
  """
374
- conn = sqlite3.connect(str(DB_PATH))
375
  conn.row_factory = sqlite3.Row
376
  try:
377
  cursor = conn.execute(sql)
@@ -385,9 +417,10 @@ def execute_query(sql: str) -> tuple[list[dict], str | None]:
385
 
386
  def get_table_stats() -> list[dict]:
387
  """Return [{name, rows}, ...] for all tables."""
388
- conn = sqlite3.connect(str(DB_PATH))
389
  try:
390
- tables = ["sellers", "users", "products", "orders", "reviews"]
 
391
  return [
392
  {
393
  "name": t,
 
21
  _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
22
  DB_PATH = _DATA_DIR / "benchmark.db"
23
 
24
+ # Active DB path — can be overridden via connect_external_db()
25
+ _active_db_path: str = str(DB_PATH)
26
+ _active_db_label: str = "benchmark (built-in)"
27
+
28
+
29
+ def connect_external_db(path: str) -> tuple[bool, str]:
30
+ """Switch the active SQLite database. Returns (success, message)."""
31
+ global _active_db_path, _active_db_label
32
+ try:
33
+ conn = sqlite3.connect(path)
34
+ tables = conn.execute(
35
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
36
+ ).fetchall()
37
+ conn.close()
38
+ _active_db_path = path
39
+ _active_db_label = Path(path).name
40
+ return True, f"Connected to {Path(path).name} ({len(tables)} tables)"
41
+ except Exception as e:
42
+ return False, str(e)
43
+
44
+
45
+ def get_active_db_label() -> str:
46
+ return _active_db_label
47
+
48
+
49
+ def _get_db_path() -> str:
50
+ return _active_db_path
51
+
52
 
53
  # ─── Schema ───────────────────────────────────────────────────────
54
 
 
378
  """
379
  Return a concise textual schema summary for use in prompts.
380
  """
381
+ conn = sqlite3.connect(_get_db_path())
382
  try:
383
+ cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
384
+ tables_in_db = [r[0] for r in cur.fetchall()]
385
+ # Fall back to all tables if schema is unknown
386
+ tables = tables_in_db if tables_in_db else ["sellers", "users", "products", "orders", "reviews"]
387
  lines = []
388
+ for table in tables:
389
  info = conn.execute(f"PRAGMA table_info({table})").fetchall()
390
  cols = ", ".join(
391
  f"{col[1]} {col[2]}{'(PK)' if col[5] else ''}"
 
403
  Execute a SQL query and return (rows, error_message).
404
  rows is a list of dicts; error_message is None on success.
405
  """
406
+ conn = sqlite3.connect(_get_db_path())
407
  conn.row_factory = sqlite3.Row
408
  try:
409
  cursor = conn.execute(sql)
 
417
 
418
  def get_table_stats() -> list[dict]:
419
  """Return [{name, rows}, ...] for all tables."""
420
+ conn = sqlite3.connect(_get_db_path())
421
  try:
422
+ cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
423
+ tables = [r[0] for r in cur.fetchall()] or ["sellers", "users", "products", "orders", "reviews"]
424
  return [
425
  {
426
  "name": t,
backend/gepa/optimizer.py CHANGED
@@ -59,7 +59,7 @@ class Candidate(BaseModel):
59
  def _make_client() -> AsyncOpenAI:
60
  return AsyncOpenAI(
61
  api_key=os.environ.get("HF_TOKEN", ""),
62
- base_url=os.environ.get("API_BASE_URL", "https://api.openai.com/v1"),
63
  )
64
 
65
 
@@ -158,6 +158,12 @@ class GEPAOptimizer:
158
  )
159
  self._save()
160
 
 
 
 
 
 
 
161
  def should_optimize(self) -> bool:
162
  return len(self._history) > 0 and len(self._history) % 4 == 0
163
 
 
59
  def _make_client() -> AsyncOpenAI:
60
  return AsyncOpenAI(
61
  api_key=os.environ.get("HF_TOKEN", ""),
62
+ base_url=os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1"),
63
  )
64
 
65
 
 
158
  )
159
  self._save()
160
 
161
+ @property
162
+ def current_generation(self) -> int:
163
+ if not self._pareto_front:
164
+ return 0
165
+ return max(c.generation for c in self._pareto_front)
166
+
167
  def should_optimize(self) -> bool:
168
  return len(self._history) > 0 and len(self._history) % 4 == 0
169
 
frontend/src/App.tsx CHANGED
@@ -9,6 +9,7 @@ import { BenchmarkPanel } from './components/BenchmarkPanel'
9
  import { ERDiagram } from './components/ERDiagram'
10
  import { RightSidebar } from './components/RightSidebar'
11
  import { DemoMode } from './components/DemoMode'
 
12
  import { useStore } from './store/useStore'
13
  import { fetchInit } from './lib/api'
14
 
@@ -25,8 +26,9 @@ export default function App() {
25
  const [leftOpen, setLeftOpen] = useState(false)
26
  const [rightOpen, setRightOpen] = useState(false)
27
  const [demoOpen, setDemoOpen] = useState(false)
 
28
 
29
- const { theme, setDbSeeded, setTables, setSchemaGraph } = useStore()
30
 
31
  // Apply theme on mount / change
32
  useEffect(() => {
@@ -50,6 +52,7 @@ export default function App() {
50
  .then((d) => {
51
  setDbSeeded(true)
52
  setTables(d.tables)
 
53
  // Lazy-load schema graph
54
  fetch('/api/schema-graph')
55
  .then((r) => r.json())
@@ -74,11 +77,15 @@ export default function App() {
74
  onToggleLeft={() => { setLeftOpen((v) => !v); setRightOpen(false) }}
75
  onToggleRight={() => { setRightOpen((v) => !v); setLeftOpen(false) }}
76
  onDemo={() => setDemoOpen(true)}
 
77
  />
78
 
79
  <AnimatePresence>
80
  {demoOpen && <DemoMode onClose={() => setDemoOpen(false)} />}
81
  </AnimatePresence>
 
 
 
82
 
83
  <div className="flex flex-1 overflow-hidden relative">
84
  {/* Overlay backdrop (mobile) */}
 
9
  import { ERDiagram } from './components/ERDiagram'
10
  import { RightSidebar } from './components/RightSidebar'
11
  import { DemoMode } from './components/DemoMode'
12
+ import { ConnectDB } from './components/ConnectDB'
13
  import { useStore } from './store/useStore'
14
  import { fetchInit } from './lib/api'
15
 
 
26
  const [leftOpen, setLeftOpen] = useState(false)
27
  const [rightOpen, setRightOpen] = useState(false)
28
  const [demoOpen, setDemoOpen] = useState(false)
29
+ const [connectDbOpen, setConnectDbOpen] = useState(false)
30
 
31
+ const { theme, setDbSeeded, setTables, setSchemaGraph, setDbLabel } = useStore()
32
 
33
  // Apply theme on mount / change
34
  useEffect(() => {
 
52
  .then((d) => {
53
  setDbSeeded(true)
54
  setTables(d.tables)
55
+ if (d.dbLabel) setDbLabel(d.dbLabel)
56
  // Lazy-load schema graph
57
  fetch('/api/schema-graph')
58
  .then((r) => r.json())
 
77
  onToggleLeft={() => { setLeftOpen((v) => !v); setRightOpen(false) }}
78
  onToggleRight={() => { setRightOpen((v) => !v); setLeftOpen(false) }}
79
  onDemo={() => setDemoOpen(true)}
80
+ onConnectDb={() => setConnectDbOpen(true)}
81
  />
82
 
83
  <AnimatePresence>
84
  {demoOpen && <DemoMode onClose={() => setDemoOpen(false)} />}
85
  </AnimatePresence>
86
+ <AnimatePresence>
87
+ {connectDbOpen && <ConnectDB onClose={() => setConnectDbOpen(false)} />}
88
+ </AnimatePresence>
89
 
90
  <div className="flex flex-1 overflow-hidden relative">
91
  {/* Overlay backdrop (mobile) */}
frontend/src/components/BenchmarkPanel.tsx CHANGED
@@ -251,6 +251,23 @@ export function BenchmarkPanel() {
251
  <div className="flex flex-col h-full">
252
  {/* Header */}
253
  <div className="px-4 py-3 border-b border-white/[0.06] shrink-0">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  <div className="flex items-center justify-between mb-2">
255
  <div className="flex items-center gap-2">
256
  <Target size={14} className="text-violet-400" />
 
251
  <div className="flex flex-col h-full">
252
  {/* Header */}
253
  <div className="px-4 py-3 border-b border-white/[0.06] shrink-0">
254
+ {/* Difficulty tabs */}
255
+ <div className="flex items-center gap-1 mb-2.5 p-0.5 rounded-lg border border-white/[0.06] w-fit">
256
+ {DIFFICULTY_TABS.map((tab) => (
257
+ <button
258
+ key={tab.id}
259
+ onClick={() => setTaskDifficulty(tab.id)}
260
+ disabled={isBenchmarking}
261
+ className={`text-[10px] font-semibold px-2.5 py-1 rounded transition-all disabled:opacity-50 ${
262
+ taskDifficulty === tab.id
263
+ ? 'bg-violet-600/25 text-violet-300 border border-violet-500/30'
264
+ : 'text-gray-500 hover:text-gray-300 border border-transparent'
265
+ }`}
266
+ >
267
+ {tab.label}
268
+ </button>
269
+ ))}
270
+ </div>
271
  <div className="flex items-center justify-between mb-2">
272
  <div className="flex items-center gap-2">
273
  <Target size={14} className="text-violet-400" />
frontend/src/components/ConnectDB.tsx ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState } from 'react'
2
+ import { motion } from 'framer-motion'
3
+ import { X, PlugZap, Database, CheckCircle2, XCircle, Loader2, RotateCcw } from 'lucide-react'
4
+ import { useStore } from '../store/useStore'
5
+ import { connectExternalDb } from '../lib/api'
6
+
7
+ interface ConnectDBProps {
8
+ onClose: () => void
9
+ }
10
+
11
+ const EXAMPLES = [
12
+ { label: 'In-memory (blank)', value: ':memory:' },
13
+ { label: 'Custom path', value: '/path/to/your/database.db' },
14
+ ]
15
+
16
+ export function ConnectDB({ onClose }: ConnectDBProps) {
17
+ const { dbLabel, setDbLabel, setTables, setDbSeeded } = useStore()
18
+ const [path, setPath] = useState('')
19
+ const [status, setStatus] = useState<'idle' | 'connecting' | 'success' | 'error'>('idle')
20
+ const [message, setMessage] = useState('')
21
+
22
+ const handleConnect = async () => {
23
+ if (!path.trim()) return
24
+ setStatus('connecting')
25
+ setMessage('')
26
+ try {
27
+ const res = await connectExternalDb(path.trim())
28
+ if (res.success) {
29
+ setDbLabel(res.dbLabel)
30
+ setTables(res.tables)
31
+ setDbSeeded(true)
32
+ setStatus('success')
33
+ setMessage(res.message)
34
+ } else {
35
+ setStatus('error')
36
+ setMessage(res.message)
37
+ }
38
+ } catch (e) {
39
+ setStatus('error')
40
+ setMessage(e instanceof Error ? e.message : 'Connection failed')
41
+ }
42
+ }
43
+
44
+ const handleReset = async () => {
45
+ setStatus('connecting')
46
+ try {
47
+ const res = await connectExternalDb('/app/backend/data/benchmark.db')
48
+ if (res.success) {
49
+ setDbLabel(res.dbLabel)
50
+ setTables(res.tables)
51
+ setDbSeeded(true)
52
+ setStatus('success')
53
+ setMessage('Reset to built-in benchmark database')
54
+ } else {
55
+ setStatus('error')
56
+ setMessage(res.message)
57
+ }
58
+ } catch (e) {
59
+ setStatus('error')
60
+ setMessage(e instanceof Error ? e.message : 'Reset failed')
61
+ }
62
+ }
63
+
64
+ return (
65
+ <motion.div
66
+ initial={{ opacity: 0 }}
67
+ animate={{ opacity: 1 }}
68
+ exit={{ opacity: 0 }}
69
+ className="fixed inset-0 z-[200] flex items-center justify-center p-4"
70
+ style={{ background: 'rgba(0,0,0,0.6)', backdropFilter: 'blur(4px)' }}
71
+ onClick={(e) => { if (e.target === e.currentTarget) onClose() }}
72
+ >
73
+ <motion.div
74
+ initial={{ scale: 0.95, opacity: 0, y: 8 }}
75
+ animate={{ scale: 1, opacity: 1, y: 0 }}
76
+ exit={{ scale: 0.95, opacity: 0 }}
77
+ transition={{ duration: 0.15 }}
78
+ className="w-full max-w-md rounded-2xl border shadow-2xl overflow-hidden"
79
+ style={{ background: 'var(--bg-secondary)', borderColor: 'var(--border-color)' }}
80
+ >
81
+ {/* Header */}
82
+ <div className="flex items-center justify-between px-5 py-4 border-b" style={{ borderColor: 'var(--border-color)' }}>
83
+ <div className="flex items-center gap-2.5">
84
+ <div className="w-7 h-7 rounded-lg flex items-center justify-center" style={{ background: 'linear-gradient(135deg,#1e3a5f,#2d1b69)' }}>
85
+ <PlugZap size={13} className="text-white" />
86
+ </div>
87
+ <div>
88
+ <h2 className="text-sm font-semibold theme-text-primary">Connect Database</h2>
89
+ <p className="text-[10px] text-gray-500">SQLite file path</p>
90
+ </div>
91
+ </div>
92
+ <button onClick={onClose} className="p-1.5 rounded-lg hover:bg-white/5 text-gray-500 hover:text-gray-300 transition-colors">
93
+ <X size={15} />
94
+ </button>
95
+ </div>
96
+
97
+ {/* Current DB */}
98
+ <div className="px-5 py-3 border-b flex items-center gap-2" style={{ borderColor: 'var(--border-color)', background: 'var(--bg-tertiary)' }}>
99
+ <Database size={11} className="text-violet-400 shrink-0" />
100
+ <span className="text-[11px] text-gray-500">Active:</span>
101
+ <span className="text-[11px] font-semibold text-violet-400 truncate">{dbLabel}</span>
102
+ <button
103
+ onClick={handleReset}
104
+ className="ml-auto flex items-center gap-1 text-[10px] text-gray-600 hover:text-gray-400 transition-colors"
105
+ title="Reset to built-in demo database"
106
+ >
107
+ <RotateCcw size={9} />
108
+ Reset to demo
109
+ </button>
110
+ </div>
111
+
112
+ {/* Body */}
113
+ <div className="px-5 py-4 flex flex-col gap-4">
114
+ <div>
115
+ <label className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider block mb-1.5">
116
+ SQLite File Path
117
+ </label>
118
+ <input
119
+ type="text"
120
+ value={path}
121
+ onChange={(e) => { setPath(e.target.value); setStatus('idle') }}
122
+ onKeyDown={(e) => e.key === 'Enter' && void handleConnect()}
123
+ placeholder="/path/to/database.db"
124
+ className="w-full px-3 py-2.5 text-sm rounded-xl border focus:outline-none transition-all font-mono"
125
+ style={{
126
+ background: 'var(--bg-tertiary)',
127
+ borderColor: 'var(--border-color)',
128
+ color: 'var(--text-primary)',
129
+ }}
130
+ autoFocus
131
+ />
132
+ </div>
133
+
134
+ {/* Quick examples */}
135
+ <div className="flex flex-col gap-1.5">
136
+ <span className="text-[10px] text-gray-600 uppercase tracking-wider">Quick select</span>
137
+ <div className="flex flex-wrap gap-1.5">
138
+ {EXAMPLES.map((ex) => (
139
+ <button
140
+ key={ex.value}
141
+ onClick={() => { setPath(ex.value); setStatus('idle') }}
142
+ className="text-[10px] px-2.5 py-1 rounded-full border transition-all text-gray-500 hover:text-gray-300"
143
+ style={{ borderColor: 'var(--border-color)', background: 'var(--bg-tertiary)' }}
144
+ >
145
+ {ex.label}
146
+ </button>
147
+ ))}
148
+ </div>
149
+ </div>
150
+
151
+ {/* Status message */}
152
+ {status !== 'idle' && (
153
+ <div className={`flex items-start gap-2 rounded-xl px-3 py-2.5 text-xs ${
154
+ status === 'success' ? 'bg-green-500/10 border border-green-500/20 text-green-400' :
155
+ status === 'error' ? 'bg-red-500/10 border border-red-500/20 text-red-400' :
156
+ 'bg-violet-500/10 border border-violet-500/20 text-violet-400'
157
+ }`}>
158
+ {status === 'connecting' && <Loader2 size={12} className="animate-spin shrink-0 mt-0.5" />}
159
+ {status === 'success' && <CheckCircle2 size={12} className="shrink-0 mt-0.5" />}
160
+ {status === 'error' && <XCircle size={12} className="shrink-0 mt-0.5" />}
161
+ <span>{status === 'connecting' ? 'Connecting…' : message}</span>
162
+ </div>
163
+ )}
164
+ </div>
165
+
166
+ {/* Footer */}
167
+ <div className="px-5 pb-5 flex items-center justify-end gap-2">
168
+ <button
169
+ onClick={onClose}
170
+ className="px-4 py-2 rounded-xl text-xs font-medium text-gray-500 hover:text-gray-300 transition-colors"
171
+ >
172
+ {status === 'success' ? 'Close' : 'Cancel'}
173
+ </button>
174
+ <button
175
+ onClick={() => void handleConnect()}
176
+ disabled={!path.trim() || status === 'connecting'}
177
+ className="flex items-center gap-1.5 px-4 py-2 rounded-xl text-xs font-semibold text-white transition-all active:scale-95 disabled:opacity-40 disabled:cursor-not-allowed"
178
+ style={{ background: 'linear-gradient(135deg,#7c3aed,#2563eb)' }}
179
+ >
180
+ {status === 'connecting' ? (
181
+ <><Loader2 size={11} className="animate-spin" /> Connecting…</>
182
+ ) : (
183
+ <><PlugZap size={11} /> Connect</>
184
+ )}
185
+ </button>
186
+ </div>
187
+ </motion.div>
188
+ </motion.div>
189
+ )
190
+ }
frontend/src/components/Header.tsx CHANGED
@@ -1,21 +1,15 @@
1
- import { Database, Sun, Moon, PanelLeftOpen, PanelRightOpen, Cpu, Play } from 'lucide-react'
2
  import { useStore } from '../store/useStore'
3
- import type { Difficulty } from '../lib/types'
4
 
5
  interface HeaderProps {
6
  onToggleLeft: () => void
7
  onToggleRight: () => void
8
  onDemo: () => void
 
9
  }
10
 
11
- const DIFFICULTIES: { id: Difficulty; label: string; color: string }[] = [
12
- { id: 'easy', label: 'Easy', color: 'text-green-400 border-green-500/30 bg-green-500/10' },
13
- { id: 'medium', label: 'Medium', color: 'text-amber-400 border-amber-500/30 bg-amber-500/10' },
14
- { id: 'hard', label: 'Hard', color: 'text-red-400 border-red-500/30 bg-red-500/10' },
15
- ]
16
-
17
- export function Header({ onToggleLeft, onToggleRight, onDemo }: HeaderProps) {
18
- const { theme, toggleTheme, dbSeeded, taskDifficulty, setTaskDifficulty } = useStore()
19
 
20
  return (
21
  <header
@@ -71,22 +65,16 @@ export function Header({ onToggleLeft, onToggleRight, onDemo }: HeaderProps) {
71
  LinUCB Active
72
  </div>
73
 
74
- {/* Difficulty selector */}
75
- <div className="flex items-center gap-1 border border-white/[0.06] rounded-lg p-0.5">
76
- {DIFFICULTIES.map((d) => (
77
- <button
78
- key={d.id}
79
- onClick={() => setTaskDifficulty(d.id)}
80
- className={`text-[10px] font-semibold px-2 py-1 rounded transition-all ${
81
- taskDifficulty === d.id
82
- ? `${d.color} border`
83
- : 'text-gray-500 hover:text-gray-300 border border-transparent'
84
- }`}
85
- >
86
- {d.label}
87
- </button>
88
- ))}
89
- </div>
90
 
91
  {/* Demo button */}
92
  <button
 
1
+ import { Database, Sun, Moon, PanelLeftOpen, PanelRightOpen, Cpu, Play, PlugZap } from 'lucide-react'
2
  import { useStore } from '../store/useStore'
 
3
 
4
  interface HeaderProps {
5
  onToggleLeft: () => void
6
  onToggleRight: () => void
7
  onDemo: () => void
8
+ onConnectDb: () => void
9
  }
10
 
11
+ export function Header({ onToggleLeft, onToggleRight, onDemo, onConnectDb }: HeaderProps) {
12
+ const { theme, toggleTheme, dbSeeded, dbLabel } = useStore()
 
 
 
 
 
 
13
 
14
  return (
15
  <header
 
65
  LinUCB Active
66
  </div>
67
 
68
+ {/* Connect DB button */}
69
+ <button
70
+ onClick={onConnectDb}
71
+ className="hidden sm:flex items-center gap-1.5 px-2.5 py-1.5 rounded-lg text-[11px] font-medium transition-all hover:bg-white/5 theme-border border"
72
+ style={{ color: 'var(--text-muted)' }}
73
+ title={`Active: ${dbLabel}`}
74
+ >
75
+ <PlugZap size={11} />
76
+ <span className="hidden md:inline">Connect DB</span>
77
+ </button>
 
 
 
 
 
 
78
 
79
  {/* Demo button */}
80
  <button
frontend/src/lib/api.ts CHANGED
@@ -95,3 +95,13 @@ export async function fetchPromptHistory() {
95
  if (!res.ok) throw new Error(`HTTP ${res.status}`)
96
  return res.json()
97
  }
 
 
 
 
 
 
 
 
 
 
 
95
  if (!res.ok) throw new Error(`HTTP ${res.status}`)
96
  return res.json()
97
  }
98
+
99
+ export async function connectExternalDb(path: string): Promise<{ success: boolean; message: string; tables: { name: string; rows: number }[]; dbLabel: string }> {
100
+ const res = await fetch(`${BASE_URL}/api/connect-db`, {
101
+ method: 'POST',
102
+ headers: { 'Content-Type': 'application/json' },
103
+ body: JSON.stringify({ path }),
104
+ })
105
+ if (!res.ok) throw new Error(`HTTP ${res.status}`)
106
+ return res.json()
107
+ }
frontend/src/store/useStore.ts CHANGED
@@ -19,6 +19,9 @@ interface Store {
19
  taskDifficulty: Difficulty
20
  setTaskId: (id: string) => void
21
  setTaskDifficulty: (d: Difficulty) => void
 
 
 
22
 
23
  // Init / DB
24
  dbSeeded: boolean
@@ -97,17 +100,22 @@ export const useStore = create<Store>((set) => ({
97
  }),
98
 
99
  // Task
100
- taskId: 'easy',
101
  taskDifficulty: 'easy',
102
  setTaskId: (id) => set({ taskId: id }),
103
- setTaskDifficulty: (d) =>
 
104
  set({
105
  taskDifficulty: d,
106
- taskId: d,
107
  benchmarkResults:
108
  d === 'easy' ? EASY_QUERIES : d === 'medium' ? MEDIUM_QUERIES : HARD_QUERIES,
109
  overallScore: null,
110
- }),
 
 
 
 
111
 
112
  // Init
113
  dbSeeded: false,
 
19
  taskDifficulty: Difficulty
20
  setTaskId: (id: string) => void
21
  setTaskDifficulty: (d: Difficulty) => void
22
+ // DB
23
+ dbLabel: string
24
+ setDbLabel: (label: string) => void
25
 
26
  // Init / DB
27
  dbSeeded: boolean
 
100
  }),
101
 
102
  // Task
103
+ taskId: 'simple_queries',
104
  taskDifficulty: 'easy',
105
  setTaskId: (id) => set({ taskId: id }),
106
+ setTaskDifficulty: (d) => {
107
+ const taskId = d === 'easy' ? 'simple_queries' : d === 'medium' ? 'join_queries' : 'complex_queries'
108
  set({
109
  taskDifficulty: d,
110
+ taskId,
111
  benchmarkResults:
112
  d === 'easy' ? EASY_QUERIES : d === 'medium' ? MEDIUM_QUERIES : HARD_QUERIES,
113
  overallScore: null,
114
+ })
115
+ },
116
+ // DB
117
+ dbLabel: 'benchmark (built-in)',
118
+ setDbLabel: (label) => set({ dbLabel: label }),
119
 
120
  // Init
121
  dbSeeded: false,