// MNIST CNN training playground.
//
// Loads the MNIST dataset, builds a `Cnn` from a user-editable text
// architecture spec, and drives training via `train()` with start / pause /
// continue / stop / reset controls. Losses are plotted with react-plotly.js
// and per-layer run info is shown through `InfoViewer`.
//
// NOTE(review): this file appears to have been mangled by an extraction step —
// JSX markup and generic type arguments (`<...>`) look stripped (e.g.
// `useRef>([])`, the empty `return (` bodies, the dangling `<Plot ...` props on
// the loss chart). All code tokens below are preserved exactly as found; the
// garbled spots are flagged inline. Restore from version control if possible.
import { useEffect, useState, useRef } from "react";
import * as tf from "@tensorflow/tfjs";
import Plot from "react-plotly.js";
import { MnistData } from "./mnist.js";
import { Cnn, train } from "./train.ts";
import type { TrainController, RunInfo, OptimizerParams } from "./train.ts";
import { Button, InputField, Tabs, Dropdown, Card } from "@elvis/ui";
import InfoViewer from "./InfoViewer.tsx";

// Default layer spec parsed by `Cnn`: conv2d -> maxpool -> flatten -> dense softmax.
const DEFAULT_ARCHITECTURE = `[conv2d filters=8 kernel=11 stride=1 padding=1 activation=relu] [maxpool size=2 stride=2] [flatten] [dense units=10 activation=softmax]`;

// Firefox falls back to the CPU backend; everything else uses WebGL
// (presumably a Firefox WebGL performance issue — see the Sidebar warning).
const isFirefox = navigator.userAgent.toLowerCase().includes("firefox");
// Top-level await: backend selection completes before this module finishes
// loading (requires an ESM/bundler environment that supports top-level await).
await tf.setBackend(isFirefox ? "cpu" : "webgl");
await tf.ready();

export default function NetworkVisualizer() {
  // MNIST dataset; null until the async load in the effect below completes.
  const [dataset, setDataset] = useState(null);
  useEffect(() => {
    const loadData = async () => {
      const data = new MnistData();
      await data.load();
      setDataset(data);
      console.log("dataset loaded");
    }
    loadData();
  }, []);

  // Architecture / optimizer configuration. All numeric params are kept as
  // strings (they back text inputs) and validated with parseFloat at
  // reset/start time.
  const [architecture, setArchitecture] = useState(DEFAULT_ARCHITECTURE);
  const [optimizerType, setOptimizerType] = useState('adam');
  const [optimizerParams, setOptimizerParams] = useState({
    learningRate: '0.001',
    beta1: '0.9',
    beta2: '0.999',
    epsilon: '1e-8',
    batchSize: '32',
    epochs: '5',
  });

  // Live tf model / optimizer are held in refs so re-renders never recreate
  // them; both are explicitly disposed before replacement (WebGL memory).
  const modelRef = useRef(null);
  const optimizerRef = useRef(null);

  // Reject architecture edits mid-training (the config effect below would
  // rebuild the model out from under train()).
  function handleArchitectureChange(newArchitecture: string) {
    if (isTraining) {
      alert('Cannot change architecture while training is in progress.');
    } else {
      setArchitecture(newArchitecture);
    }
  }

  // Same guard for optimizer settings.
  function handleOptimizerChange(newOptimizerType: string, newOptimizerParams: OptimizerParams) {
    if (isTraining) {
      alert('Cannot change optimizer settings while training is in progress.');
    } else {
      setOptimizerType(newOptimizerType);
      setOptimizerParams(newOptimizerParams);
    }
  }

  // Advance the sample index shown by InfoViewer and force a re-render
  // (the index lives on a ref, so a manual tick is needed).
  function handleSampleIndexChange() {
    trainController.current.sampleIndex += 1;
    updateTick();
  }

  // Dispose any existing model and build a fresh Cnn from the current
  // architecture string. No-op until the dataset has loaded.
  function resetModel() {
    if (!dataset) return;
    if (modelRef.current) {
      modelRef.current.dispose();
    }
    const cnn = new Cnn(architecture, dataset.numInputChannels);
    modelRef.current = cnn;
  }

  // Validate the string params and (re)build the tf optimizer. Alerts and
  // bails on the first invalid value, leaving the previous optimizer in place.
  function resetOptimizer() {
    if (optimizerType === 'adam') {
      const learningRate = parseFloat(optimizerParams.learningRate);
      // Adam hyperparameters fall back to tf.js's documented defaults when blank.
      const beta1 = parseFloat(optimizerParams.beta1 || "0.9");
      const beta2 = parseFloat(optimizerParams.beta2 || "0.999");
      const epsilon = parseFloat(optimizerParams.epsilon || "1e-8");
      if (Number.isNaN(learningRate) || learningRate <= 0) {
        alert('Invalid learning rate for Adam optimizer.');
        return;
      }
      if (Number.isNaN(beta1) || beta1 < 0) {
        alert('Invalid beta1 for Adam optimizer.');
        return;
      }
      if (Number.isNaN(beta2) || beta2 < 0) {
        alert('Invalid beta2 for Adam optimizer.');
        return;
      }
      if (Number.isNaN(epsilon) || epsilon <= 0) {
        alert('Invalid epsilon for Adam optimizer.');
        return;
      }
      const opt = tf.train.adam(learningRate, beta1, beta2, epsilon);
      // Dispose the old optimizer only after the new one is built successfully.
      if (optimizerRef.current) {
        optimizerRef.current.dispose();
      }
      optimizerRef.current = opt;
    } else if (optimizerType === 'sgd') {
      const learningRate = parseFloat(optimizerParams.learningRate);
      if (Number.isNaN(learningRate) || learningRate <= 0) {
        alert('Invalid learning rate for SGD optimizer.');
        return;
      }
      const opt = tf.train.sgd(learningRate);
      if (optimizerRef.current) {
        optimizerRef.current.dispose();
      }
      optimizerRef.current = opt;
    } else {
      alert(`Unsupported optimizer type: ${optimizerType}`);
    }
  }

  // Reset & init model and optimizer whenever the configuration or dataset
  // changes (also performs the initial build once the dataset loads).
  useEffect(() => {
    resetModel();
    resetOptimizer();
  }, [architecture, optimizerType, optimizerParams, dataset]);

  // Training state. isTraining gates the UI; the mutable flags below are
  // polled by the train() loop without causing re-renders.
  const [isTraining, setIsTraining] = useState(false);
  // Per-step loss history for the plot.
  // NOTE(review): generic argument stripped by extraction (`useRef>` —
  // presumably something like useRef<Array<number>>). Preserved as found.
  const lossesRef = useRef>([]);
  const trainController = useRef({
    isPaused: false,
    stopRequested: false,
    sampleIndex: 0,
  });
  // Latest per-layer RunInfo reported by train(). Same stripped-generic artifact.
  const infoRef = useRef>([]);

  // Dummy counter used purely to trigger re-renders (ref mutations don't).
  const [, setTick] = useState(0);
  function updateTick() {
    setTick((tick) => tick + 1);
  }

  // Kick off a training run; validates batchSize/epochs, then hands control
  // to train() which reports progress through the callback below.
  async function startTraining() {
    // NOTE(review): modelRef/optimizerRef are ref objects and therefore always
    // truthy — the intended guards are probably modelRef.current /
    // optimizerRef.current (which are checked again further down).
    if (!modelRef || !dataset || !optimizerRef || isTraining) {
      return;
    }
    setIsTraining(true);
    trainController.current.isPaused = false;
    trainController.current.stopRequested = false;
    const batchSize = parseFloat(optimizerParams.batchSize);
    if (Number.isNaN(batchSize) || batchSize <= 0) {
      alert('Invalid batch size.');
      setIsTraining(false);
      return;
    }
    const epochs = parseFloat(optimizerParams.epochs);
    if (Number.isNaN(epochs) || epochs <= 0) {
      alert('Invalid number of epochs.');
      setIsTraining(false);
      return;
    }
    let lastTickUpdate = 0;
    // NOTE(review): these early returns leave isTraining stuck at true
    // (setIsTraining(false) only runs in the finally block below) — verify.
    if (!modelRef.current) return;
    if (!optimizerRef.current) return;
    try {
      await train(
        dataset,
        modelRef.current,
        optimizerRef.current,
        batchSize,
        epochs,
        trainController.current,
        // Per-batch progress callback from train().
        (_epoch, _batch, loss, info) => {
          // lossesRef.current.push({ epoch, batch, loss });
          lossesRef.current.push(loss);
          console.log(loss);
          infoRef.current = info;
          // Throttle re-renders: update tick at most every 50ms.
          const now = performance.now();
          if (now - lastTickUpdate > 50) {
            lastTickUpdate = now;
            updateTick();
          }
        },
      );
    } finally {
      // Runs whether train() completes, is stopped, or throws.
      setIsTraining(false);
      trainController.current.isPaused = false;
      trainController.current.stopRequested = false;
      alert('Training finished.');
    }
  }

  function handleStartTraining() {
    console.log('Starting training...');
    // trainController flags are reset inside startTraining.
    startTraining();
  }

  // Pause/continue/stop all work by flipping flags that train() polls.
  function handlePauseTraining() {
    console.log('Pausing training...');
    trainController.current.isPaused = true;
  }

  function handleContinueTraining() {
    console.log('Continuing training...');
    trainController.current.isPaused = false;
  }

  function handleStopTraining() {
    console.log('Stopping training...');
    trainController.current.stopRequested = true;
    trainController.current.isPaused = false;
  }

  // Resolve once isTraining reads false, polling once per animation frame.
  // NOTE(review): `isTraining` is captured by this closure, so a given call
  // polls a possibly-stale snapshot — confirm this resolves as intended
  // across re-renders. Also note `new Promise(` likely lost a `<void>` type
  // argument to extraction.
  async function waitUntilNotTraining() {
    return new Promise((resolve) => {
      function check() {
        if (!isTraining) {
          resolve();
        } else {
          requestAnimationFrame(check);
        }
      }
      check();
    });
  }

  // Stop training, wait for the loop to wind down, then clear history and
  // rebuild model + optimizer from the current configuration.
  async function handleResetTraining() {
    console.log('Resetting training...');
    handleStopTraining();
    await waitUntilNotTraining();
    // NOTE(review): the string literal below contains a raw line break in the
    // extracted source (invalid inside a single-quoted string); preserved as
    // found — likely originally 'Training stopped. Resetting model.'.
    console.log('Training stopped. 
Resetting model.');
    lossesRef.current = [];
    infoRef.current = [];
    resetModel();
    resetOptimizer();
    updateTick();
  }

  // NOTE(review): the JSX returned here was stripped by extraction — the
  // component's entire markup is missing from this copy of the file.
  return (
  );
}

// Props for the training status/plot panel.
interface TrainingViewerProps {
  isTraining: boolean;
  // NOTE(review): RefObject generic arguments stripped (`React.RefObject>`).
  lossesRef: React.RefObject>;
  infoRef: React.RefObject>;
  handleSampleIndexChange: () => void;
}

// Shows training status, the loss curve, and the InfoViewer sample controls.
// NOTE(review): most of this JSX was stripped by extraction; only embedded
// expressions, text nodes, and trailing Plot props survive. Preserved as found.
function TrainingViewer({
  isTraining,
  lossesRef,
  infoRef,
  handleSampleIndexChange,
}: TrainingViewerProps) {
  return (

Training { isTraining ? "in progress" : "not in progress" }

i), y: lossesRef.current, mode: 'lines', type: 'scatter', }, ]} layout={{ xaxis: { title: { text: 'Training steps' } }, yaxis: { title: { text: 'Train loss' } }, margin: { t: 40, r: 40, b: 40, l: 40 }, }} className="w-full h-[320px]" config={{ responsive: true }} />

  )
}

// Props for the configuration sidebar; all state mutations are delegated to
// callbacks owned by NetworkVisualizer.
interface SidebarProps {
  architecture: string;
  onArchitectureChange: (newArchitecture: string) => void;
  optimizerType: string;
  optimizerParams: OptimizerParams;
  onOptimizerChange: (newOptimizerType: string, newOptimizerParams: OptimizerParams) => void;
  onStartTraining: () => void;
  onPauseTraining: () => void;
  onContinueTraining: () => void;
  onStopTraining: () => void;
  onResetTraining: () => void;
}

// Tabbed sidebar: "Architecture" editor and "Train" controls.
function Sidebar({
  architecture,
  onArchitectureChange,
  optimizerType,
  optimizerParams,
  onOptimizerChange,
  onStartTraining,
  onPauseTraining,
  onContinueTraining,
  onStopTraining,
  onResetTraining,
}: SidebarProps) {
  const tabs = ["Architecture", "Train"];
  const [activeTab, setActiveTab] = useState(tabs[0]);
  // Local draft of the architecture text; applied via onArchitectureChange.
  const [architectureDraft, setArchitectureDraft] = useState(architecture);
  // NOTE(review): the JSX below was largely stripped by extraction; only the
  // embedded expressions and text nodes survive. Preserved as found.
  return (

{ isFirefox && (

Warning: This demo may be quite slow on Firefox.

)} { activeTab === "Architecture" && ( <>

) }