ar9avg's picture
fix
17e7bd7
import { create } from 'zustand'
import type {
ChatMessage,
BenchmarkResult,
RLState,
TableInfo,
SchemaGraph,
PromptSnapshot,
Difficulty,
} from '../lib/types'
import { fetchBenchmarkQuestions } from '../lib/api'
/**
 * Shape of the global app store. Each piece of state is paired with the
 * action(s) that update it, grouped by feature area.
 *
 * Setters for nullable fields accept `null` so callers can clear them
 * (consistent with `setActiveBenchmarkId`).
 */
interface Store {
  // Theme
  theme: 'dark' | 'light'
  /** Flips between dark and light and persists the choice. */
  toggleTheme: () => void
  // Task
  /** Task identifier; `setTaskDifficulty` keeps it in sync with `taskDifficulty`. */
  taskId: string
  taskDifficulty: Difficulty
  /** Sets only `taskId`; does not change difficulty or reload questions. */
  setTaskId: (id: string) => void
  /** Selects a difficulty, derives `taskId`, clears `overallScore` and reloads the benchmark questions. */
  setTaskDifficulty: (d: Difficulty) => void
  // DB
  /** Display label for the active database. */
  dbLabel: string
  setDbLabel: (label: string) => void
  /** Whether a custom (non built-in) database is active. */
  isCustomDb: boolean
  setIsCustomDb: (v: boolean) => void
  /** Suggestion strings for the custom database (presumably example questions — confirm with callers). */
  customDbSuggestions: string[]
  setCustomDbSuggestions: (qs: string[]) => void
  /** Loading flag for `customDbSuggestions`. */
  suggestionsLoading: boolean
  setSuggestionsLoading: (v: boolean) => void
  // Init / DB
  dbSeeded: boolean
  setDbSeeded: (v: boolean) => void
  tables: TableInfo[]
  setTables: (tables: TableInfo[]) => void
  schemaGraph: SchemaGraph | null
  /** Replaces the schema graph; pass `null` to clear it. */
  setSchemaGraph: (g: SchemaGraph | null) => void
  // Chat
  messages: ChatMessage[]
  /** Appends a message to the end of the transcript. */
  addMessage: (msg: ChatMessage) => void
  /** Shallow-merges `update` into the message whose `id` matches; other messages are untouched. */
  updateMessage: (id: string, update: Partial<ChatMessage>) => void
  clearMessages: () => void
  isExecuting: boolean
  setIsExecuting: (v: boolean) => void
  optimizingBanner: boolean
  setOptimizingBanner: (v: boolean) => void
  // Benchmark
  benchmarkResults: BenchmarkResult[]
  setBenchmarkResults: (r: BenchmarkResult[]) => void
  /** Replaces the entry whose `id` matches `r.id`; no-op if none matches. */
  updateBenchmarkResult: (r: BenchmarkResult) => void
  /** Resets every entry to `'pending'` with null result fields and clears `overallScore`. */
  resetBenchmark: () => void
  isBenchmarking: boolean
  setIsBenchmarking: (v: boolean) => void
  activeBenchmarkId: string | null
  setActiveBenchmarkId: (id: string | null) => void
  overallScore: number | null
  /** Sets the overall score; pass `null` to clear it. */
  setOverallScore: (s: number | null) => void
  // RL State
  rlState: RLState | null
  /** Replaces the RL state; pass `null` to clear it. */
  setRlState: (s: RLState | null) => void
  // GEPA / Prompt
  currentPrompt: string
  promptGeneration: number
  promptHistory: PromptSnapshot[]
  /** Replaces the current prompt, generation counter and history in one update. */
  setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
}
/**
 * Creates a benchmark entry that has not been run yet.
 *
 * Only the question metadata is filled in. The status is `'pending'` and every
 * result field starts as `null`.
 */
function makePending(id: string, question: string, difficulty: Difficulty): BenchmarkResult {
  const pending: BenchmarkResult = {
    id,
    question,
    difficulty,
    status: 'pending',
    score: null,
    sql: null,
    reason: null,
    attempts: null,
    refRowCount: null,
    agentRowCount: null
  }
  return pending
}
/**
 * Initial `benchmarkResults` value: a single pending stand-in entry, shown
 * until a real question list is stored (e.g. via `setTaskDifficulty` or
 * `setBenchmarkResults`).
 */
const PLACEHOLDER_QUERIES: BenchmarkResult[] = [makePending('loading', 'Loading questions…', 'easy')]
/**
 * Global client-side store: theme, task selection, database metadata, chat
 * transcript, benchmark run state, RL state and GEPA prompt data.
 */
export const useStore = create<Store>((set) => ({
  // Theme
  theme: 'dark',
  /**
   * Flips the theme, mirrors it onto the root element's `data-theme`
   * attribute and persists it to localStorage on a best-effort basis.
   */
  toggleTheme: () =>
    set((s) => {
      const next = s.theme === 'dark' ? 'light' : 'dark'
      document.documentElement.setAttribute('data-theme', next)
      // Storage access may throw in some environments; persisting is optional.
      try { localStorage.setItem('theme', next) } catch { /* noop */ }
      return { theme: next }
    }),
  // Task
  taskId: 'simple_queries',
  taskDifficulty: 'easy',
  setTaskId: (id) => set({ taskId: id }),
  /**
   * Selects a difficulty and derives the matching task id. It clears the
   * overall score and asynchronously loads that difficulty's questions as
   * pending entries. On fetch failure the current list is kept.
   */
  setTaskDifficulty: (d) => {
    const taskId = d === 'easy' ? 'simple_queries' : d === 'medium' ? 'join_queries' : 'complex_queries'
    set({ taskDifficulty: d, taskId, overallScore: null })
    void fetchBenchmarkQuestions(d)
      .then(({ questions }) => {
        // Fetches can resolve out of order when the difficulty is changed
        // quickly. Ignore a response whose difficulty is no longer selected so
        // a stale list cannot overwrite the current one. Returning the
        // unchanged state object makes the update a no-op.
        set((s) =>
          s.taskDifficulty !== d
            ? s
            : {
                benchmarkResults: questions.map((q) =>
                  makePending(q.id, q.question, q.difficulty as Difficulty)
                ),
              }
        )
      })
      .catch((err: unknown) => {
        // Keep the current list on error, but surface the failure.
        console.warn('Failed to load benchmark questions:', err)
      })
  },
  // DB
  dbLabel: 'benchmark (built-in)',
  setDbLabel: (label) => set({ dbLabel: label }),
  isCustomDb: false,
  setIsCustomDb: (v) => set({ isCustomDb: v }),
  customDbSuggestions: [],
  setCustomDbSuggestions: (qs) => set({ customDbSuggestions: qs }),
  suggestionsLoading: false,
  setSuggestionsLoading: (v) => set({ suggestionsLoading: v }),
  // Init
  dbSeeded: false,
  setDbSeeded: (v) => set({ dbSeeded: v }),
  tables: [],
  setTables: (tables) => set({ tables }),
  schemaGraph: null,
  setSchemaGraph: (g) => set({ schemaGraph: g }),
  // Chat
  messages: [],
  addMessage: (msg) => set((s) => ({ messages: [...s.messages, msg] })),
  /** Shallow-merges `update` into the message with the matching id. */
  updateMessage: (id, update) =>
    set((s) => ({
      messages: s.messages.map((m) => (m.id === id ? { ...m, ...update } : m)),
    })),
  clearMessages: () => set({ messages: [] }),
  isExecuting: false,
  setIsExecuting: (v) => set({ isExecuting: v }),
  optimizingBanner: false,
  setOptimizingBanner: (v) => set({ optimizingBanner: v }),
  // Benchmark
  benchmarkResults: PLACEHOLDER_QUERIES,
  setBenchmarkResults: (r) => set({ benchmarkResults: r }),
  /** Replaces the entry whose id matches `r.id`; other entries are untouched. */
  updateBenchmarkResult: (r) =>
    set((s) => ({
      benchmarkResults: s.benchmarkResults.map((br) => (br.id === r.id ? r : br)),
    })),
  /** Returns every entry to the pending state (keeping other fields) and clears the overall score. */
  resetBenchmark: () =>
    set((s) => ({
      benchmarkResults: s.benchmarkResults.map((r) => ({
        ...r,
        status: 'pending' as const,
        score: null,
        sql: null,
        reason: null,
        attempts: null,
        refRowCount: null,
        agentRowCount: null,
      })),
      overallScore: null,
    })),
  isBenchmarking: false,
  setIsBenchmarking: (v) => set({ isBenchmarking: v }),
  activeBenchmarkId: null,
  setActiveBenchmarkId: (id) => set({ activeBenchmarkId: id }),
  overallScore: null,
  setOverallScore: (s) => set({ overallScore: s }),
  // RL State
  rlState: null,
  setRlState: (s) => set({ rlState: s }),
  // GEPA
  currentPrompt: '',
  promptGeneration: 0,
  promptHistory: [],
  /** Replaces prompt, generation and history together in a single update. */
  setPromptData: (data) =>
    set({
      currentPrompt: data.prompt,
      promptGeneration: data.generation,
      promptHistory: data.history,
    }),
}))