import { create } from 'zustand'
import type {
  ChatMessage,
  BenchmarkResult,
  RLState,
  TableInfo,
  SchemaGraph,
  PromptSnapshot,
  Difficulty,
} from '../lib/types'
import { fetchBenchmarkQuestions } from '../lib/api'

/** Shape of the global UI store: state fields paired with their setters. */
interface Store {
  // Theme
  theme: 'dark' | 'light'
  /** Flips dark/light, mirrors it onto `<html data-theme>` and persists it to localStorage. */
  toggleTheme: () => void

  // Task
  taskId: string
  taskDifficulty: Difficulty
  setTaskId: (id: string) => void
  /**
   * Switches difficulty, derives the matching task id, clears the overall score
   * and asynchronously reloads the benchmark question list for that difficulty.
   */
  setTaskDifficulty: (d: Difficulty) => void

  // DB
  dbLabel: string
  setDbLabel: (label: string) => void
  isCustomDb: boolean
  setIsCustomDb: (v: boolean) => void
  customDbSuggestions: string[]
  setCustomDbSuggestions: (qs: string[]) => void
  suggestionsLoading: boolean
  setSuggestionsLoading: (v: boolean) => void

  // Init / DB
  dbSeeded: boolean
  setDbSeeded: (v: boolean) => void
  tables: TableInfo[]
  setTables: (tables: TableInfo[]) => void
  schemaGraph: SchemaGraph | null
  setSchemaGraph: (g: SchemaGraph) => void

  // Chat
  messages: ChatMessage[]
  addMessage: (msg: ChatMessage) => void
  /** Shallow-merges `update` into the message whose `id` matches; no-op if none matches. */
  updateMessage: (id: string, update: Partial<ChatMessage>) => void
  clearMessages: () => void
  isExecuting: boolean
  setIsExecuting: (v: boolean) => void
  optimizingBanner: boolean
  setOptimizingBanner: (v: boolean) => void

  // Benchmark
  benchmarkResults: BenchmarkResult[]
  setBenchmarkResults: (r: BenchmarkResult[]) => void
  /** Replaces the result whose `id` matches `r.id`; no-op if none matches. */
  updateBenchmarkResult: (r: BenchmarkResult) => void
  /** Returns every result to the pending state (keeping id/question/difficulty) and clears the overall score. */
  resetBenchmark: () => void
  isBenchmarking: boolean
  setIsBenchmarking: (v: boolean) => void
  activeBenchmarkId: string | null
  setActiveBenchmarkId: (id: string | null) => void
  overallScore: number | null
  setOverallScore: (s: number) => void

  // RL State
  rlState: RLState | null
  setRlState: (s: RLState) => void

  // GEPA / Prompt
  currentPrompt: string
  promptGeneration: number
  promptHistory: PromptSnapshot[]
  setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
}

/**
 * Per-run result fields that are cleared whenever a benchmark question is
 * (re)created or reset to pending. Shared so `makePending` and
 * `resetBenchmark` cannot drift apart.
 */
const PENDING_FIELDS = {
  status: 'pending',
  score: null,
  sql: null,
  reason: null,
  attempts: null,
  refRowCount: null,
  agentRowCount: null,
} as const

/** Builds a benchmark entry that has not been run yet (all per-run fields null). */
function makePending(id: string, question: string, difficulty: Difficulty): BenchmarkResult {
  return { id, question, difficulty, ...PENDING_FIELDS }
}

/**
 * Maps a difficulty tier to its benchmark task id.
 * 'easy' -> simple_queries, 'medium' -> join_queries, anything else -> complex_queries.
 */
function taskIdForDifficulty(d: Difficulty): string {
  if (d === 'easy') return 'simple_queries'
  if (d === 'medium') return 'join_queries'
  return 'complex_queries'
}

/** Initial benchmark list, shown until a real question list is stored. */
const PLACEHOLDER_QUERIES: BenchmarkResult[] = [
  makePending('loading', 'Loading questions…', 'easy'),
]

/** Global application store (zustand hook). */
export const useStore = create<Store>()((set, get) => ({
  // Theme
  theme: 'dark',
  toggleTheme: () => {
    const next = get().theme === 'dark' ? 'light' : 'dark'
    set({ theme: next })
    // Side effects live outside the state updater so the updater stays pure.
    document.documentElement.setAttribute('data-theme', next)
    try {
      localStorage.setItem('theme', next)
    } catch {
      /* storage may be unavailable (e.g. privacy mode); theme still applies for this session */
    }
  },

  // Task
  taskId: 'simple_queries',
  taskDifficulty: 'easy',
  setTaskId: (id) => set({ taskId: id }),
  setTaskDifficulty: (d) => {
    set({ taskDifficulty: d, taskId: taskIdForDifficulty(d), overallScore: null })
    void fetchBenchmarkQuestions(d)
      .then(({ questions }) => {
        // Drop stale responses: if the difficulty changed again while this
        // request was in flight, an earlier (slower) reply must not overwrite
        // the question list for the newer selection.
        if (get().taskDifficulty !== d) return
        set({
          benchmarkResults: questions.map((q) =>
            // NOTE(review): API difficulty is trusted as a Difficulty without validation.
            makePending(q.id, q.question, q.difficulty as Difficulty),
          ),
        })
      })
      .catch((err: unknown) => {
        // Best-effort: keep the current list, but leave a trace for debugging.
        console.warn('Failed to load benchmark questions for difficulty', d, err)
      })
  },

  // DB
  dbLabel: 'benchmark (built-in)',
  setDbLabel: (label) => set({ dbLabel: label }),
  isCustomDb: false,
  setIsCustomDb: (v) => set({ isCustomDb: v }),
  customDbSuggestions: [],
  setCustomDbSuggestions: (qs) => set({ customDbSuggestions: qs }),
  suggestionsLoading: false,
  setSuggestionsLoading: (v) => set({ suggestionsLoading: v }),

  // Init
  dbSeeded: false,
  setDbSeeded: (v) => set({ dbSeeded: v }),
  tables: [],
  setTables: (tables) => set({ tables }),
  schemaGraph: null,
  setSchemaGraph: (g) => set({ schemaGraph: g }),

  // Chat
  messages: [],
  addMessage: (msg) => set((s) => ({ messages: [...s.messages, msg] })),
  updateMessage: (id, update) =>
    set((s) => ({
      messages: s.messages.map((m) => (m.id === id ? { ...m, ...update } : m)),
    })),
  clearMessages: () => set({ messages: [] }),
  isExecuting: false,
  setIsExecuting: (v) => set({ isExecuting: v }),
  optimizingBanner: false,
  setOptimizingBanner: (v) => set({ optimizingBanner: v }),

  // Benchmark
  benchmarkResults: PLACEHOLDER_QUERIES,
  setBenchmarkResults: (r) => set({ benchmarkResults: r }),
  updateBenchmarkResult: (r) =>
    set((s) => ({
      benchmarkResults: s.benchmarkResults.map((br) => (br.id === r.id ? r : br)),
    })),
  resetBenchmark: () =>
    set((s) => ({
      // Spread the original first so any extra fields survive; then clear per-run fields.
      benchmarkResults: s.benchmarkResults.map((r) => ({ ...r, ...PENDING_FIELDS })),
      overallScore: null,
    })),
  isBenchmarking: false,
  setIsBenchmarking: (v) => set({ isBenchmarking: v }),
  activeBenchmarkId: null,
  setActiveBenchmarkId: (id) => set({ activeBenchmarkId: id }),
  overallScore: null,
  setOverallScore: (s) => set({ overallScore: s }),

  // RL State
  rlState: null,
  setRlState: (s) => set({ rlState: s }),

  // GEPA
  currentPrompt: '',
  promptGeneration: 0,
  promptHistory: [],
  setPromptData: (data) =>
    set({
      currentPrompt: data.prompt,
      promptGeneration: data.generation,
      promptHistory: data.history,
    }),
}))