Spaces:
Sleeping
Sleeping
| import { create } from 'zustand' | |
| import type { | |
| ChatMessage, | |
| BenchmarkResult, | |
| RLState, | |
| TableInfo, | |
| SchemaGraph, | |
| PromptSnapshot, | |
| Difficulty, | |
| } from '../lib/types' | |
| import { fetchBenchmarkQuestions } from '../lib/api' | |
/** Colour theme currently applied to the UI. */
interface ThemeSlice {
  theme: 'dark' | 'light'
  /** Flips between the two theme values. */
  toggleTheme: () => void
}

/** Selected task and its difficulty. */
interface TaskSlice {
  taskId: string
  taskDifficulty: Difficulty
  setTaskId: (id: string) => void
  setTaskDifficulty: (d: Difficulty) => void
}

/** Database selection and the question suggestions attached to a custom DB. */
interface DbSlice {
  dbLabel: string
  setDbLabel: (label: string) => void
  isCustomDb: boolean
  setIsCustomDb: (v: boolean) => void
  customDbSuggestions: string[]
  setCustomDbSuggestions: (qs: string[]) => void
  suggestionsLoading: boolean
  setSuggestionsLoading: (v: boolean) => void
}

/** Initialisation state: seeding flag, table list and schema graph. */
interface InitSlice {
  dbSeeded: boolean
  setDbSeeded: (v: boolean) => void
  tables: TableInfo[]
  setTables: (tables: TableInfo[]) => void
  /** `null` until a graph has been stored. */
  schemaGraph: SchemaGraph | null
  setSchemaGraph: (g: SchemaGraph) => void
}

/** Chat transcript plus execution / banner flags. */
interface ChatSlice {
  messages: ChatMessage[]
  addMessage: (msg: ChatMessage) => void
  /** Applies a partial update to the message identified by `id`. */
  updateMessage: (id: string, update: Partial<ChatMessage>) => void
  clearMessages: () => void
  isExecuting: boolean
  setIsExecuting: (v: boolean) => void
  optimizingBanner: boolean
  setOptimizingBanner: (v: boolean) => void
}

/** Benchmark question list, run status and aggregate score. */
interface BenchmarkSlice {
  benchmarkResults: BenchmarkResult[]
  setBenchmarkResults: (r: BenchmarkResult[]) => void
  /** Stores `r` in place of the existing result that has the same id. */
  updateBenchmarkResult: (r: BenchmarkResult) => void
  resetBenchmark: () => void
  isBenchmarking: boolean
  setIsBenchmarking: (v: boolean) => void
  activeBenchmarkId: string | null
  setActiveBenchmarkId: (id: string | null) => void
  /** `null` while no score is available. */
  overallScore: number | null
  setOverallScore: (s: number) => void
}

/** RL state snapshot. */
interface RlSlice {
  /** `null` until a state has been stored. */
  rlState: RLState | null
  setRlState: (s: RLState) => void
}

/** GEPA / prompt data: current prompt text, its generation and the history. */
interface PromptSlice {
  currentPrompt: string
  promptGeneration: number
  promptHistory: PromptSnapshot[]
  /** Replaces prompt text, generation and history in a single update. */
  setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
}

/** Full shape of the application store: the union of every feature slice. */
interface Store
  extends ThemeSlice,
    TaskSlice,
    DbSlice,
    InitSlice,
    ChatSlice,
    BenchmarkSlice,
    RlSlice,
    PromptSlice {}
/**
 * Builds a benchmark entry that has not been run yet.
 *
 * @param id - Identifier of the benchmark question.
 * @param question - Question text.
 * @param difficulty - Difficulty assigned to the question.
 * @returns A result with status `'pending'` and every outcome field set to `null`.
 */
function makePending(id: string, question: string, difficulty: Difficulty): BenchmarkResult {
  const blank: BenchmarkResult = {
    id,
    question,
    difficulty,
    status: 'pending',
    score: null,
    sql: null,
    reason: null,
    attempts: null,
    refRowCount: null,
    agentRowCount: null,
  }
  return blank
}
/** Initial benchmark list: a single pending stand-in entry used as the store's starting value. */
const PLACEHOLDER_QUERIES: BenchmarkResult[] = [makePending('loading', 'Loading questions…', 'easy')]
/**
 * Application-wide zustand store.
 *
 * Holds theme, task selection, database info, chat transcript, benchmark
 * results, RL state and prompt data, together with the setters that update
 * them. Most setters are plain field assignments; the ones with extra
 * behaviour (`toggleTheme`, `setTaskDifficulty`, `updateMessage`,
 * `updateBenchmarkResult`, `resetBenchmark`) are commented inline.
 */
export const useStore = create<Store>((set, get) => ({
  // Theme
  theme: 'dark',
  toggleTheme: () => {
    const next = get().theme === 'dark' ? 'light' : 'dark'
    // DOM / storage side effects run outside of a `set` updater so the
    // state update itself stays a pure assignment.
    document.documentElement.setAttribute('data-theme', next)
    try {
      localStorage.setItem('theme', next)
    } catch {
      /* best effort: persisting the theme is optional */
    }
    set({ theme: next })
  },

  // Task
  taskId: 'simple_queries',
  taskDifficulty: 'easy',
  setTaskId: (id) => set({ taskId: id }),
  setTaskDifficulty: (d) => {
    const taskId =
      d === 'easy' ? 'simple_queries' : d === 'medium' ? 'join_queries' : 'complex_queries'
    // Changing difficulty also invalidates the previous aggregate score.
    set({ taskDifficulty: d, taskId, overallScore: null })

    // Fire-and-forget reload of the question list for the new difficulty.
    void fetchBenchmarkQuestions(d)
      .then(({ questions }) => {
        // Another difficulty may have been selected while this request was
        // in flight; dropping the stale response keeps the list in sync with
        // the current selection.
        if (get().taskDifficulty !== d) return
        set({
          benchmarkResults: questions.map((q) =>
            makePending(q.id, q.question, q.difficulty as Difficulty)
          ),
        })
      })
      .catch((err: unknown) => {
        // Keep the current list on error, but leave a trace for debugging.
        console.warn('Failed to load benchmark questions; keeping current list', err)
      })
  },

  // DB
  dbLabel: 'benchmark (built-in)',
  setDbLabel: (label) => set({ dbLabel: label }),
  isCustomDb: false,
  setIsCustomDb: (v) => set({ isCustomDb: v }),
  customDbSuggestions: [],
  setCustomDbSuggestions: (qs) => set({ customDbSuggestions: qs }),
  suggestionsLoading: false,
  setSuggestionsLoading: (v) => set({ suggestionsLoading: v }),

  // Init
  dbSeeded: false,
  setDbSeeded: (v) => set({ dbSeeded: v }),
  tables: [],
  setTables: (tables) => set({ tables }),
  schemaGraph: null,
  setSchemaGraph: (g) => set({ schemaGraph: g }),

  // Chat
  messages: [],
  addMessage: (msg) => set((s) => ({ messages: [...s.messages, msg] })),
  // Shallow-merges `update` into the message whose id matches; others are untouched.
  updateMessage: (id, update) =>
    set((s) => ({
      messages: s.messages.map((m) => (m.id === id ? { ...m, ...update } : m)),
    })),
  clearMessages: () => set({ messages: [] }),
  isExecuting: false,
  setIsExecuting: (v) => set({ isExecuting: v }),
  optimizingBanner: false,
  setOptimizingBanner: (v) => set({ optimizingBanner: v }),

  // Benchmark
  benchmarkResults: PLACEHOLDER_QUERIES,
  setBenchmarkResults: (r) => set({ benchmarkResults: r }),
  // Replaces the entry with the same id; a result with an unknown id is ignored.
  updateBenchmarkResult: (r) =>
    set((s) => ({
      benchmarkResults: s.benchmarkResults.map((br) => (br.id === r.id ? r : br)),
    })),
  // Puts every entry back to 'pending' with cleared outcome fields and drops the overall score.
  resetBenchmark: () =>
    set((s) => ({
      benchmarkResults: s.benchmarkResults.map((r) => ({
        ...r,
        status: 'pending' as const,
        score: null,
        sql: null,
        reason: null,
        attempts: null,
        refRowCount: null,
        agentRowCount: null,
      })),
      overallScore: null,
    })),
  isBenchmarking: false,
  setIsBenchmarking: (v) => set({ isBenchmarking: v }),
  activeBenchmarkId: null,
  setActiveBenchmarkId: (id) => set({ activeBenchmarkId: id }),
  overallScore: null,
  setOverallScore: (s) => set({ overallScore: s }),

  // RL State
  rlState: null,
  setRlState: (s) => set({ rlState: s }),

  // GEPA
  currentPrompt: '',
  promptGeneration: 0,
  promptHistory: [],
  setPromptData: (data) =>
    set({
      currentPrompt: data.prompt,
      promptGeneration: data.generation,
      promptHistory: data.history,
    }),
}))