import { create } from 'zustand'
import type {
ChatMessage,
BenchmarkResult,
RLState,
TableInfo,
SchemaGraph,
PromptSnapshot,
Difficulty,
} from '../lib/types'
import { fetchBenchmarkQuestions } from '../lib/api'
/**
 * Shape of the application store: every shared piece of state together with
 * the actions that change it.
 */
interface Store {
  // ── Theme ────────────────────────────────────────────────────────────────
  /** Currently active colour theme. */
  theme: 'dark' | 'light'
  /** Switches between 'dark' and 'light'. */
  toggleTheme: () => void

  // ── Task ─────────────────────────────────────────────────────────────────
  /** Identifier of the selected task. */
  taskId: string
  /** Difficulty of the selected task. */
  taskDifficulty: Difficulty
  setTaskId: (id: string) => void
  /** Selects a difficulty; the matching task id is derived from it. */
  setTaskDifficulty: (d: Difficulty) => void

  // ── DB ───────────────────────────────────────────────────────────────────
  /** Label describing the database currently in use. */
  dbLabel: string
  setDbLabel: (label: string) => void
  /** Flag marking the current database as a custom one. */
  isCustomDb: boolean
  setIsCustomDb: (v: boolean) => void
  /** Suggestion strings associated with a custom database. */
  customDbSuggestions: string[]
  setCustomDbSuggestions: (qs: string[]) => void
  /** Loading flag for the suggestions above. */
  suggestionsLoading: boolean
  setSuggestionsLoading: (v: boolean) => void

  // ── Init / DB ────────────────────────────────────────────────────────────
  dbSeeded: boolean
  setDbSeeded: (v: boolean) => void
  tables: TableInfo[]
  setTables: (tables: TableInfo[]) => void
  /** Schema graph, or null while none has been set. */
  schemaGraph: SchemaGraph | null
  setSchemaGraph: (g: SchemaGraph) => void

  // ── Chat ─────────────────────────────────────────────────────────────────
  messages: ChatMessage[]
  /** Appends a message to the end of the list. */
  addMessage: (msg: ChatMessage) => void
  /** Merges `update` into the message whose id equals `id`. */
  updateMessage: (id: string, update: Partial<ChatMessage>) => void
  /** Empties the message list. */
  clearMessages: () => void
  isExecuting: boolean
  setIsExecuting: (v: boolean) => void
  optimizingBanner: boolean
  setOptimizingBanner: (v: boolean) => void

  // ── Benchmark ────────────────────────────────────────────────────────────
  benchmarkResults: BenchmarkResult[]
  /** Replaces the whole result list. */
  setBenchmarkResults: (r: BenchmarkResult[]) => void
  /** Replaces the single result that shares `r.id`. */
  updateBenchmarkResult: (r: BenchmarkResult) => void
  /** Puts every result back to 'pending' and clears the overall score. */
  resetBenchmark: () => void
  isBenchmarking: boolean
  setIsBenchmarking: (v: boolean) => void
  /** Id of the active benchmark entry, or null when there is none. */
  activeBenchmarkId: string | null
  setActiveBenchmarkId: (id: string | null) => void
  /** Overall benchmark score, or null when no score is set. */
  overallScore: number | null
  setOverallScore: (s: number) => void

  // ── RL State ─────────────────────────────────────────────────────────────
  rlState: RLState | null
  setRlState: (s: RLState) => void

  // ── GEPA / Prompt ────────────────────────────────────────────────────────
  currentPrompt: string
  promptGeneration: number
  promptHistory: PromptSnapshot[]
  /** Sets the current prompt, its generation number and the history in one update. */
  setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
}
/**
 * Builds a benchmark entry that has not produced an outcome yet.
 *
 * @param id - Identifier stored on the entry.
 * @param question - Question text stored on the entry.
 * @param difficulty - Difficulty stored on the entry.
 * @returns A result whose status is 'pending' and whose outcome fields
 *   (score, sql, reason, attempts, row counts) are all null.
 */
function makePending(id: string, question: string, difficulty: Difficulty): BenchmarkResult {
  const untouched: BenchmarkResult = {
    id,
    question,
    difficulty,
    status: 'pending',
    score: null,
    sql: null,
    reason: null,
    attempts: null,
    refRowCount: null,
    agentRowCount: null,
  }
  return untouched
}
/**
 * Initial content of `benchmarkResults`: one pending stand-in entry with the
 * id 'loading' and the question text 'Loading questions…'.
 */
const PLACEHOLDER_QUERIES: BenchmarkResult[] = [makePending('loading', 'Loading questions…', 'easy')]
/**
 * Application-wide zustand store hook.
 *
 * Holds theme, task selection, database info, chat messages, benchmark
 * results, RL state and prompt data, plus the actions that update them.
 * Most setters write their argument straight into the matching field; the
 * actions with extra behaviour are commented inline.
 */
export const useStore = create<Store>((set, get) => ({
  // Theme
  theme: 'dark',
  toggleTheme: () =>
    set((s) => {
      const next = s.theme === 'dark' ? 'light' : 'dark'
      // Mirror the new theme onto the root element's data-theme attribute and
      // persist it; a failing localStorage write is deliberately ignored.
      document.documentElement.setAttribute('data-theme', next)
      try { localStorage.setItem('theme', next) } catch { /* noop */ }
      return { theme: next }
    }),

  // Task
  taskId: 'simple_queries',
  taskDifficulty: 'easy',
  setTaskId: (id) => set({ taskId: id }),
  /**
   * Selects a difficulty, derives the matching task id, clears the overall
   * score, and reloads the benchmark question list for that difficulty.
   */
  setTaskDifficulty: (d) => {
    const taskId = d === 'easy' ? 'simple_queries' : d === 'medium' ? 'join_queries' : 'complex_queries'
    set({ taskDifficulty: d, taskId, overallScore: null })
    fetchBenchmarkQuestions(d)
      .then(({ questions }) => {
        // The difficulty may have been changed again while this request was
        // in flight. Applying a stale response would show questions for a
        // difficulty that is no longer selected, so drop it.
        if (get().taskDifficulty !== d) return
        set({
          benchmarkResults: questions.map((q) =>
            makePending(q.id, q.question, q.difficulty as Difficulty)
          ),
        })
      })
      .catch(() => { /* keep current list on error */ })
  },

  // DB
  dbLabel: 'benchmark (built-in)',
  setDbLabel: (label) => set({ dbLabel: label }),
  isCustomDb: false,
  setIsCustomDb: (v) => set({ isCustomDb: v }),
  customDbSuggestions: [],
  setCustomDbSuggestions: (qs) => set({ customDbSuggestions: qs }),
  suggestionsLoading: false,
  setSuggestionsLoading: (v) => set({ suggestionsLoading: v }),

  // Init
  dbSeeded: false,
  setDbSeeded: (v) => set({ dbSeeded: v }),
  tables: [],
  setTables: (tables) => set({ tables }),
  schemaGraph: null,
  setSchemaGraph: (g) => set({ schemaGraph: g }),

  // Chat
  messages: [],
  // Appends without mutating the previous array.
  addMessage: (msg) => set((s) => ({ messages: [...s.messages, msg] })),
  // Shallow-merges `update` into the message with the given id; all other
  // messages are kept as they are.
  updateMessage: (id, update) =>
    set((s) => ({
      messages: s.messages.map((m) => (m.id === id ? { ...m, ...update } : m)),
    })),
  clearMessages: () => set({ messages: [] }),
  isExecuting: false,
  setIsExecuting: (v) => set({ isExecuting: v }),
  optimizingBanner: false,
  setOptimizingBanner: (v) => set({ optimizingBanner: v }),

  // Benchmark
  benchmarkResults: PLACEHOLDER_QUERIES,
  setBenchmarkResults: (r) => set({ benchmarkResults: r }),
  // Swaps in `r` for the entry with the same id; a result whose id is not in
  // the list leaves the list unchanged.
  updateBenchmarkResult: (r) =>
    set((s) => ({
      benchmarkResults: s.benchmarkResults.map((br) => (br.id === r.id ? r : br)),
    })),
  // Keeps every entry's other fields but returns it to the pending state with
  // all outcome fields nulled, and clears the overall score.
  resetBenchmark: () =>
    set((s) => ({
      benchmarkResults: s.benchmarkResults.map((r) => ({
        ...r,
        status: 'pending' as const,
        score: null,
        sql: null,
        reason: null,
        attempts: null,
        refRowCount: null,
        agentRowCount: null,
      })),
      overallScore: null,
    })),
  isBenchmarking: false,
  setIsBenchmarking: (v) => set({ isBenchmarking: v }),
  activeBenchmarkId: null,
  setActiveBenchmarkId: (id) => set({ activeBenchmarkId: id }),
  overallScore: null,
  setOverallScore: (s) => set({ overallScore: s }),

  // RL State
  rlState: null,
  setRlState: (s) => set({ rlState: s }),

  // GEPA
  currentPrompt: '',
  promptGeneration: 0,
  promptHistory: [],
  // Writes prompt text, generation number and history in a single update.
  setPromptData: (data) =>
    set({
      currentPrompt: data.prompt,
      promptGeneration: data.generation,
      promptHistory: data.history,
    }),
}))
|