// Neural-MRI / frontend/src/store/useModelStore.ts
// Uploaded by Hiconcep via huggingface_hub (commit 0ce9643, verified)
import { create } from 'zustand';
import type { ModelInfo } from '../types/model';
import { api } from '../api/client';
import type { ModelListEntry } from '../api/client';
import { useSAEStore } from './useSAEStore';
import { useScanStore } from './useScanStore';
import { useCompareStore } from './useCompareStore';
import { useCrossModelStore } from './useCrossModelStore';
import { useSettingsStore } from './useSettingsStore';
import { useModelSearchStore } from './useModelSearchStore';
/**
 * Shape of the model store: the observable state plus the actions that
 * update it.
 */
interface ModelState {
  // ---- state ----

  /** Info for the currently loaded model, or null until a load/info call succeeds. */
  modelInfo: ModelInfo | null;
  /** Entries as returned by the backend model-list call (`api.model.list()`). */
  availableModels: ModelListEntry[];
  /** True while a `loadModel` call is awaiting the backend. */
  isLoading: boolean;
  /** Message from the last failed load; reset to null when a load starts or info is fetched successfully. */
  error: string | null;

  // ---- actions ----

  /** Ask the backend to load `modelId`, then reset model-dependent state in other stores. */
  loadModel: (modelId: string) => Promise<void>;
  /** Fetch info for whichever model the backend currently has loaded. */
  fetchModelInfo: () => Promise<void>;
  /** Refresh `availableModels` from the backend. */
  fetchModels: () => Promise<void>;
}
/**
 * Extract a human-readable message from an arbitrary thrown value.
 * `catch` bindings are `unknown`; a bare `(e as Error).message` would yield
 * `undefined` for non-Error throws and break the `string | null` contract.
 */
function errorMessage(e: unknown): string {
  return e instanceof Error ? e.message : String(e);
}

/**
 * Global store for the loaded model and the list of models the backend offers.
 *
 * A successful `loadModel` also resets model-dependent state held in other
 * stores (scan, compare, cross-model, SAE) so views never mix data from the
 * previous model with the new one.
 */
export const useModelStore = create<ModelState>((set, get) => ({
  modelInfo: null,
  isLoading: false,
  error: null,
  availableModels: [],

  loadModel: async (modelId) => {
    // Guard against concurrent loads: a request made while another is in
    // flight is ignored, not queued.
    if (get().isLoading) return;
    set({ isLoading: true, error: null });

    try {
      const device = useSettingsStore.getState().devicePreference;
      const info = await api.model.load(modelId, device);

      // Clear stale data from the previous model BEFORE publishing the new
      // modelInfo, so subscribers never see the new model paired with the old
      // model's scan / compare / cross-model / SAE state.
      useScanStore.getState().clearScanData();
      useCompareStore.getState().clear();
      useCrossModelStore.getState().clear();
      useSAEStore.getState().reset();

      set({ modelInfo: info, isLoading: false });

      // Register as recent model
      useModelSearchStore.getState().addRecentModel(modelId);
    } catch (e: unknown) {
      set({ error: errorMessage(e), isLoading: false });
      return;
    }

    // Kick off SAE availability for the new model right away rather than
    // after the (unrelated) model-list refresh below.
    void useSAEStore.getState().fetchInfo();

    // Refresh model list to update is_loaded flags; fetchModels already
    // treats failures as non-critical.
    await get().fetchModels();
  },

  fetchModelInfo: async () => {
    try {
      const info = await api.model.info();
      set({ modelInfo: info, error: null });
      // Fetch SAE availability for the reported model
      void useSAEStore.getState().fetchInfo();
    } catch {
      // Model not yet loaded — that's ok on initial load
    }
  },

  fetchModels: async () => {
    try {
      const models = await api.model.list();
      set({ availableModels: models });
    } catch {
      // Server may not be ready yet; keep the previous list.
    }
  },
}));