| | import { create } from 'zustand'; |
| | import { toast } from 'react-hot-toast'; |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
// Prefix shared by every backend request issued from the stores in this module.
// Read from `import.meta.env.VITE_API_URL`; a missing or empty value falls back to '/api'.
const configuredApiUrl = import.meta.env.VITE_API_URL;
const API_BASE = configuredApiUrl || '/api';
| |
|
| | |
/**
 * Zustand store for the system information reported by the backend.
 *
 * State:
 *   systemInfo - last payload from `/system/info` (or its cached copy); null until known
 *   isLoading  - true while `fetchSystemInfo` is in flight
 *   error      - message of the last failed `fetchSystemInfo`, otherwise null
 */
export const useSystemStore = create((set) => ({
  systemInfo: null,
  isLoading: false,
  error: null,

  /**
   * Fetches `/system/info`, publishing the cached copy from localStorage first so
   * consumers have something to render while the request is pending.
   *
   * @returns {Promise<object|null>} the fresh payload, or null when the request failed
   *   (in which case `error` holds the message).
   */
  fetchSystemInfo: async () => {
    set({ isLoading: true, error: null });

    // The cache is best-effort: unreadable storage or a corrupt entry must not
    // prevent the network request below. A successful fetch overwrites the entry.
    try {
      const cachedSystem = localStorage.getItem('system_info');
      if (cachedSystem) {
        set({ systemInfo: JSON.parse(cachedSystem) });
      }
    } catch (cacheError) {
      console.warn('Ignoring unreadable system_info cache:', cacheError);
    }

    try {
      const response = await fetch(`${API_BASE}/system/info`);
      if (!response.ok) throw new Error('Failed to fetch system info');
      const data = await response.json();
      set({ systemInfo: data, isLoading: false });

      // A storage failure (quota, blocked storage) must not turn a successful
      // fetch into a reported error, so it is handled separately.
      try {
        localStorage.setItem('system_info', JSON.stringify(data));
      } catch (cacheError) {
        console.warn('Could not cache system_info:', cacheError);
      }
      return data;
    } catch (error) {
      set({ error: error.message, isLoading: false });
      return null;
    }
  },

  /**
   * Asks the backend whether a model of the given size can be loaded.
   *
   * @param {number|string} paramsB - sent as the `model_params_billions` query parameter
   * @param {string} [dtype='fp16'] - sent as the `dtype` query parameter
   * @returns {Promise<object>} the parsed response body, or
   *   `{ can_load: false, error }` when the request or JSON parsing throws.
   */
  checkModelRequirements: async (paramsB, dtype = 'fp16') => {
    try {
      // URLSearchParams percent-encodes the values instead of splicing them in raw.
      const query = new URLSearchParams({
        model_params_billions: String(paramsB),
        dtype: String(dtype)
      });
      const response = await fetch(
        `${API_BASE}/system/check-model?${query.toString()}`,
        { method: 'POST' }
      );
      return await response.json();
    } catch (error) {
      return { can_load: false, error: error.message };
    }
  }
}));
| |
|
| | |
/**
 * Zustand store describing the model handled by the backend.
 *
 * State:
 *   modelInfo       - info object for the loaded model, or null
 *   layers          - layer names returned by `/models/layers`
 *   isLoading       - true while `loadModel` is in flight
 *   loadingProgress - 0 when a load starts, 100 once it succeeded
 *   error           - last error message, or null
 */
export const useModelStore = create((set, get) => ({
  modelInfo: null,
  layers: [],
  isLoading: false,
  loadingProgress: 0,
  error: null,

  /**
   * POSTs a load request to `/models/load` and mirrors the outcome into the store
   * and a toast.
   *
   * @param {string} modelName - sent as `model_name`
   * @param {object} [options]
   * @param {string} [options.modelType='generic']
   * @param {string} [options.dtype='auto']
   * @param {string} [options.device='auto']
   * @param {boolean} [options.lowMemory=false]
   * @param {boolean} [options.trustRemoteCode=true] - sent as `trust_remote_code`
   * @returns {Promise<object>} the parsed response body, or
   *   `{ success: false, error }` when the request itself throws.
   */
  loadModel: async (modelName, options = {}) => {
    set({ isLoading: true, loadingProgress: 0, error: null });

    try {
      const response = await fetch(`${API_BASE}/models/load`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({
          model_name: modelName,
          model_type: options.modelType || 'generic',
          dtype: options.dtype || 'auto',
          device: options.device || 'auto',
          low_memory: options.lowMemory || false,
          // NOTE(review): this flag presumably lets the backend run code shipped with
          // the model. The default stays `true` for backward compatibility; callers
          // loading untrusted models can now pass `trustRemoteCode: false`.
          trust_remote_code: options.trustRemoteCode ?? true
        })
      });

      const data = await response.json();

      if (data.success) {
        set({
          modelInfo: data.model_info || { name: data.name || data.model_id },
          isLoading: false,
          loadingProgress: 100
        });
        toast.success(`Model loaded: ${data.model_info?.name || modelName}`);
        return data;
      }

      const errMsg = data.error || 'Failed to load model';
      set({ error: errMsg, isLoading: false });
      toast.error(errMsg);
      return data;
    } catch (error) {
      set({ error: error.message, isLoading: false });
      toast.error(`Connection failed: ${error.message}`);
      return { success: false, error: error.message };
    }
  },

  /**
   * Placeholder kept for interface compatibility; it issues no request.
   *
   * It only clears `error`. It deliberately does not raise `isLoading`, because
   * nothing in this method would ever lower it again and the store would stay
   * in a loading state.
   *
   * @param {string} modelName - currently unused
   * @returns {Promise<void>}
   */
  fetchModelInfo: async (modelName) => {
    set({ error: null });
  },

  /**
   * Reads `/models/info`. If the reply carries a `name`, stores the model info and
   * refreshes the layer list; otherwise clears `modelInfo` and `layers`.
   * Failures are only logged.
   *
   * @returns {Promise<void>}
   */
  checkLoadedModel: async () => {
    try {
      const response = await fetch(`${API_BASE}/models/info`);
      const data = await response.json();
      if (data && data.name) {
        set({
          modelInfo: {
            name: data.name,
            num_params: data.num_params,
            memory_mb: data.memory_mb,
            device: data.device,
            dtype: data.dtype
          },
          error: null
        });
        // Not awaited on purpose: fetchLayers handles its own errors.
        get().fetchLayers();
      } else {
        set({ modelInfo: null, layers: [] });
      }
    } catch (error) {
      console.error('Failed to check loaded model:', error);
    }
  },

  /**
   * Reads `/models/layers`. `quantizable_layers` is stored as-is when present;
   * otherwise the entries of `layers` are reduced to their `name`.
   * Failures are only logged.
   *
   * @returns {Promise<void>}
   */
  fetchLayers: async () => {
    try {
      const response = await fetch(`${API_BASE}/models/layers`);
      const data = await response.json();

      if (data.quantizable_layers) {
        set({ layers: data.quantizable_layers });
      } else if (data.layers) {
        set({ layers: data.layers.map((layer) => layer.name) });
      }
    } catch (error) {
      console.error('Failed to fetch layers:', error);
    }
  },

  /**
   * POSTs to `/models/unload`, then clears `modelInfo`, `layers` and `error`.
   * A thrown request is only logged and leaves the state untouched.
   *
   * @returns {Promise<void>}
   */
  unloadModel: async () => {
    try {
      await fetch(`${API_BASE}/models/unload`, { method: 'POST' });
      set({ modelInfo: null, layers: [], error: null });
      toast.success('Model unloaded');
    } catch (error) {
      console.error('Failed to unload model:', error);
    }
  },

  /** Resets `error` to null. */
  clearError: () => set({ error: null })
}));
| |
|
| | |
/**
 * Zustand store driving the quantization endpoints.
 *
 * State:
 *   result       - body of the last successful quantization response, or null
 *   isQuantizing - true while a quantize* call is in flight
 *   progress     - 0 when a call starts, 100 once it succeeded
 *   error        - last error message, or null
 *   history      - entries appended by successful `quantizeWeights` calls
 */
export const useQuantizationStore = create((set) => {
  // POSTs `payload` as JSON to `${API_BASE}${path}` and resolves with the parsed reply.
  const postJson = async (path, payload) => {
    const response = await fetch(`${API_BASE}${path}`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(payload)
    });
    return response.json();
  };

  // Shared handling for a thrown request: record it, toast it, return a failure object.
  const reportThrown = (error) => {
    set({ error: error.message, isQuantizing: false });
    toast.error(error.message);
    return { success: false, error: error.message };
  };

  return {
    result: null,
    isQuantizing: false,
    progress: 0,
    error: null,
    history: [],

    /**
     * Quantizes weights described by `config` via `/quantize/weights`.
     * On success the response is stored and a history entry is appended.
     *
     * @param {object} config - reads inFeatures (64), outFeatures (128), bits (8),
     *   method ('int8'), mode ('symmetric'), groupSize (null), pattern ('random'),
     *   dtype ('float32'); the value in parentheses replaces a falsy field.
     * @returns {Promise<object>} the parsed response body, or
     *   `{ success: false, error }` when the request throws.
     */
    quantizeWeights: async (config) => {
      set({ isQuantizing: true, progress: 0, error: null });

      try {
        const data = await postJson('/quantize/weights', {
          in_features: config.inFeatures || 64,
          out_features: config.outFeatures || 128,
          bits: config.bits || 8,
          method: config.method || 'int8',
          mode: config.mode || 'symmetric',
          group_size: config.groupSize || null,
          weight_pattern: config.pattern || 'random',
          dtype: config.dtype || 'float32'
        });

        if (data.success) {
          const entry = {
            timestamp: new Date().toISOString(),
            config,
            stats: data.stats
          };
          set((state) => ({
            result: data,
            isQuantizing: false,
            progress: 100,
            history: [...state.history, entry]
          }));
          toast.success('Custom weights quantized');
          return data;
        }

        set({ error: data.error || 'Quantization failed', isQuantizing: false });
        toast.error(data.error || 'Quantization failed');
        return data;
      } catch (error) {
        return reportThrown(error);
      }
    },

    /**
     * Quantizes one layer via `/quantize/layer`.
     *
     * @param {string} layerName - sent as `layer_name`
     * @param {object} config - reads bits (8), method ('int8'), mode ('symmetric'),
     *   groupSize (null); the value in parentheses replaces a falsy field.
     * @returns {Promise<object>} the parsed response body, or
     *   `{ success: false, error }` when the request throws.
     */
    quantizeLayer: async (layerName, config) => {
      set({ isQuantizing: true, progress: 0, error: null });

      try {
        const data = await postJson('/quantize/layer', {
          layer_name: layerName,
          bits: config.bits || 8,
          method: config.method || 'int8',
          mode: config.mode || 'symmetric',
          group_size: config.groupSize || null
        });

        if (data.success) {
          set({ result: data, error: null, isQuantizing: false, progress: 100 });
          toast.success(`Layer ${layerName} quantized`);
        } else {
          const errMsg = data.error || 'Quantization failed';
          set({ result: null, error: errMsg, isQuantizing: false });
          toast.error(errMsg);
        }
        return data;
      } catch (error) {
        return reportThrown(error);
      }
    },

    /**
     * Quantizes the whole model via `/quantize/model`.
     *
     * @param {object} config - reads bits (8), method ('int8'), mode ('symmetric'),
     *   groupSize (null); the value in parentheses replaces a falsy field.
     * @returns {Promise<object>} the parsed response body, or
     *   `{ success: false, error }` when the request throws.
     */
    quantizeModel: async (config) => {
      set({ isQuantizing: true, progress: 0, error: null });

      try {
        // The response body must be parsed before it is inspected; without this
        // binding every call ended in the catch branch with a ReferenceError.
        const data = await postJson('/quantize/model', {
          bits: config.bits || 8,
          method: config.method || 'int8',
          mode: config.mode || 'symmetric',
          group_size: config.groupSize || null
        });

        if (data.success) {
          set({ result: data, error: null, isQuantizing: false, progress: 100 });
          // Only mention the saving when the backend actually reported a number,
          // so the toast never reads "Saved undefined MB".
          const savedMb = data.summary?.total_memory_saved_mb;
          const savedText = typeof savedMb === 'number' ? ` Saved ${savedMb.toFixed(2)} MB` : '';
          toast.success(`Full Model quantized!${savedText}`);
        } else {
          const errMsg = data.error || 'Optimization interrupted';
          set({ result: null, error: errMsg, isQuantizing: false });
          toast.error(errMsg);
        }
        return data;
      } catch (error) {
        return reportThrown(error);
      }
    },

    /**
     * Requests a comparison of quantization methods from `/analysis/compare`.
     * Does not touch store state.
     *
     * @param {string[]} [methods=['int8', 'int4', 'nf4']]
     * @param {string|null} [layerName=null] - sent as `layer_name` only when truthy
     * @returns {Promise<object>} the parsed response body, or `{ error }` when
     *   the request throws.
     */
    compareMethod: async (methods = ['int8', 'int4', 'nf4'], layerName = null) => {
      try {
        const body = { methods };
        if (layerName) {
          body.layer_name = layerName;
        }
        return await postJson('/analysis/compare', body);
      } catch (error) {
        toast.error(error.message);
        return { error: error.message };
      }
    },

    /** Clears `result` and `error`. */
    clearResult: () => set({ result: null, error: null }),

    /** Empties `history`. */
    clearHistory: () => set({ history: [] })
  };
});
| |
|
| | |
/**
 * Zustand store for purely visual state: sidebar visibility, the active tab and
 * the colour theme. The theme starts from the `theme` entry in localStorage and
 * defaults to 'dark' when that entry is missing or empty.
 */
export const useUIStore = create((set) => ({
  sidebarOpen: true,
  activeTab: 'quantize',
  theme: localStorage.getItem('theme') || 'dark',

  /** Flips `sidebarOpen`. */
  toggleSidebar: () => set(({ sidebarOpen }) => ({ sidebarOpen: !sidebarOpen })),

  /** Makes `tab` the active tab. */
  setActiveTab: (tab) => set({ activeTab: tab }),

  /**
   * Switches between 'dark' and 'light', writing the new value to the
   * `data-theme` attribute of the root element and to localStorage.
   */
  toggleTheme: () => set(({ theme: currentTheme }) => {
    let upcomingTheme = 'dark';
    if (currentTheme === 'dark') {
      upcomingTheme = 'light';
    }
    document.documentElement.setAttribute('data-theme', upcomingTheme);
    localStorage.setItem('theme', upcomingTheme);
    return { theme: upcomingTheme };
  })
}));
| |
|