cnn_visualizer / src /NetworkVisualizer.tsx
Joel Woodfield
Fix css styling issues
e9d96d4
import { useEffect, useState, useRef } from "react";
import * as tf from "@tensorflow/tfjs";
import Plot from "react-plotly.js";
import { MnistData } from "./mnist.js";
import { Cnn, train } from "./train.ts";
import type { TrainController, RunInfo, OptimizerParams } from "./train.ts";
import { Button, InputField, Tabs, Dropdown, Card } from "@elvis/ui";
import InfoViewer from "./InfoViewer.tsx";
/**
 * Network description pre-loaded into the architecture editor: a sequence of
 * bracketed layer specs of the form `[layerType key=value ...]`. The text is
 * handed verbatim to the `Cnn` constructor by NetworkVisualizer.
 */
const DEFAULT_ARCHITECTURE = `[conv2d filters=8 kernel=11
stride=1 padding=1 activation=relu]
[maxpool size=2 stride=2]
[flatten]
[dense units=10 activation=softmax]`;

/** True when the (lower-cased) user-agent string mentions Firefox. */
const isFirefox = /firefox/.test(navigator.userAgent.toLowerCase());

// Select the TensorFlow.js backend once, at module load. The top-level await
// means importers of this module wait until the backend is settled. Firefox is
// forced onto the CPU backend.
// NOTE(review): the reason is not visible here; presumably WebGL misbehaves on
// Firefox. Confirm.
await tf.setBackend(isFirefox ? "cpu" : "webgl");
await tf.ready();
/**
 * Interactive CNN playground. It:
 * - loads MNIST;
 * - lets the user edit the network architecture and optimizer settings;
 * - trains in the browser;
 * - shows the loss curve plus per-run information.
 *
 * Fast-changing training data (losses, run info, the model and the optimizer)
 * lives in refs. This lets the training callback update it without a React
 * render per batch; `updateTick` forces a re-render when the UI should catch up.
 */
export default function NetworkVisualizer() {
  // ---- dataset -------------------------------------------------------------
  const [dataset, setDataset] = useState<MnistData | null>(null);

  useEffect(() => {
    const loadData = async () => {
      const data = new MnistData();
      await data.load();
      setDataset(data);
      console.log("dataset loaded");
    };
    // Report load failures instead of leaving an unhandled rejection.
    loadData().catch((err: unknown) => {
      console.error("Failed to load dataset:", err);
    });
  }, []);

  // ---- architecture / optimizer settings -----------------------------------
  const [architecture, setArchitecture] = useState(DEFAULT_ARCHITECTURE);
  const [optimizerType, setOptimizerType] = useState<string>('adam');
  const [optimizerParams, setOptimizerParams] = useState<OptimizerParams>({
    learningRate: '0.001',
    beta1: '0.9',
    beta2: '0.999',
    epsilon: '1e-8',
    batchSize: '32',
    epochs: '5',
  });
  const modelRef = useRef<Cnn | null>(null);
  const optimizerRef = useRef<tf.Optimizer | null>(null);

  // ---- training state ------------------------------------------------------
  const [isTraining, setIsTraining] = useState<boolean>(false);
  // Synchronous mirror of `isTraining`. A state value captured by a closure
  // never changes, so async code (waiting for a run to stop) and rapid repeated
  // clicks must read this ref instead of the state variable.
  const isTrainingRef = useRef(false);
  // Promise for the current (or most recent) training run; reset awaits it.
  // It never rejects: handleStartTraining attaches a catch.
  const trainingRunRef = useRef<Promise<void> | null>(null);
  const lossesRef = useRef<Array<number>>([]);
  const trainController = useRef<TrainController>({
    isPaused: false,
    stopRequested: false,
    sampleIndex: 0,
  });
  const infoRef = useRef<Array<RunInfo>>([]);

  // ---- render timing -------------------------------------------------------
  // Bumping this counter forces a re-render so ref-held data is redrawn.
  const [, setTick] = useState<number>(0);
  function updateTick() {
    setTick((tick) => tick + 1);
  }

  /** Updates both the rendered `isTraining` state and its synchronous mirror. */
  function setTraining(value: boolean) {
    isTrainingRef.current = value;
    setIsTraining(value);
  }

  function handleArchitectureChange(newArchitecture: string) {
    if (isTrainingRef.current) {
      alert('Cannot change architecture while training is in progress.');
    } else {
      setArchitecture(newArchitecture);
    }
  }

  function handleOptimizerChange(newOptimizerType: string, newOptimizerParams: OptimizerParams) {
    if (isTrainingRef.current) {
      alert('Cannot change optimizer settings while training is in progress.');
    } else {
      setOptimizerType(newOptimizerType);
      setOptimizerParams(newOptimizerParams);
    }
  }

  function handleSampleIndexChange() {
    trainController.current.sampleIndex += 1;
    updateTick();
  }

  /**
   * Rebuilds the CNN from the current architecture text. If construction
   * throws, the user is alerted and the previous model is kept. This is a
   * no-op until the dataset has loaded.
   */
  function resetModel() {
    if (!dataset) return;
    let cnn: Cnn;
    try {
      cnn = new Cnn(architecture, dataset.numInputChannels);
    } catch (err: unknown) {
      // NOTE(review): assumes Cnn throws on a malformed architecture string.
      const message = err instanceof Error ? err.message : String(err);
      alert(`Invalid architecture: ${message}`);
      return;
    }
    // Build first, then dispose, so a failed build leaves the old model usable.
    modelRef.current?.dispose();
    modelRef.current = cnn;
  }

  /**
   * Builds a fresh optimizer from the current settings and disposes the old
   * one. On invalid settings it alerts and keeps the previous optimizer.
   */
  function resetOptimizer() {
    let opt: tf.Optimizer;
    if (optimizerType === 'adam') {
      const learningRate = parseFloat(optimizerParams.learningRate);
      const beta1 = parseFloat(optimizerParams.beta1 || "0.9");
      const beta2 = parseFloat(optimizerParams.beta2 || "0.999");
      const epsilon = parseFloat(optimizerParams.epsilon || "1e-8");
      if (!Number.isFinite(learningRate) || learningRate <= 0) {
        alert('Invalid learning rate for Adam optimizer.');
        return;
      }
      // Adam's decay rates must lie in [0, 1). At 1, the bias correction
      // (1 - beta^t) is zero and updates become NaN/Infinity.
      if (!Number.isFinite(beta1) || beta1 < 0 || beta1 >= 1) {
        alert('Invalid beta1 for Adam optimizer.');
        return;
      }
      if (!Number.isFinite(beta2) || beta2 < 0 || beta2 >= 1) {
        alert('Invalid beta2 for Adam optimizer.');
        return;
      }
      if (!Number.isFinite(epsilon) || epsilon <= 0) {
        alert('Invalid epsilon for Adam optimizer.');
        return;
      }
      opt = tf.train.adam(learningRate, beta1, beta2, epsilon);
    } else if (optimizerType === 'sgd') {
      const learningRate = parseFloat(optimizerParams.learningRate);
      if (!Number.isFinite(learningRate) || learningRate <= 0) {
        alert('Invalid learning rate for SGD optimizer.');
        return;
      }
      opt = tf.train.sgd(learningRate);
    } else {
      alert(`Unsupported optimizer type: ${optimizerType}`);
      return;
    }
    optimizerRef.current?.dispose();
    optimizerRef.current = opt;
  }

  // Rebuild the network when its definition or the dataset changes. The
  // optimizer is rebuilt too. Optimizers such as Adam keep per-variable
  // state, which would not match a freshly built model's variables.
  useEffect(() => {
    resetModel();
    resetOptimizer();
  }, [architecture, dataset]);

  // Rebuild only the optimizer when its own hyper-parameters change, so
  // tweaking e.g. the learning rate between runs keeps the trained weights.
  // Batch size and epochs are read at training start and need no rebuild.
  useEffect(() => {
    resetOptimizer();
  }, [
    optimizerType,
    optimizerParams.learningRate,
    optimizerParams.beta1,
    optimizerParams.beta2,
    optimizerParams.epsilon,
  ]);

  /**
   * Validates the run settings and trains the current model/optimizer until
   * `train` returns (finished or stop requested). It does nothing if the model,
   * the optimizer or the dataset is not ready, or if a run is already active.
   * `isTrainingRef` is set synchronously, before the first await.
   */
  async function startTraining() {
    const model = modelRef.current;
    const optimizer = optimizerRef.current;
    if (!model || !optimizer || !dataset || isTrainingRef.current) {
      return;
    }

    // Both settings count items, so they must be positive integers.
    const batchSize = Number(optimizerParams.batchSize);
    if (!Number.isInteger(batchSize) || batchSize <= 0) {
      alert('Invalid batch size.');
      return;
    }
    const epochs = Number(optimizerParams.epochs);
    if (!Number.isInteger(epochs) || epochs <= 0) {
      alert('Invalid number of epochs.');
      return;
    }

    setTraining(true);
    trainController.current.isPaused = false;
    trainController.current.stopRequested = false;

    let lastTickUpdate = 0;
    try {
      await train(
        dataset,
        model,
        optimizer,
        batchSize,
        epochs,
        trainController.current,
        (_epoch, _batch, loss, info) => {
          lossesRef.current.push(loss);
          infoRef.current = info;
          // Throttle re-renders to at most one every 50 ms.
          const now = performance.now();
          if (now - lastTickUpdate > 50) {
            lastTickUpdate = now;
            updateTick();
          }
        },
      );
    } finally {
      setTraining(false);
      trainController.current.isPaused = false;
      trainController.current.stopRequested = false;
      alert('Training finished.');
    }
  }

  function handleStartTraining() {
    if (isTrainingRef.current) {
      // Keep trainingRunRef pointing at the active run; reset depends on it.
      return;
    }
    console.log('Starting training...');
    trainingRunRef.current = startTraining().catch((err: unknown) => {
      console.error('Training failed:', err);
    });
  }

  function handlePauseTraining() {
    console.log('Pausing training...');
    trainController.current.isPaused = true;
  }

  function handleContinueTraining() {
    console.log('Continuing training...');
    trainController.current.isPaused = false;
  }

  function handleStopTraining() {
    console.log('Stopping training...');
    trainController.current.stopRequested = true;
    // Unpause so a paused loop can observe the stop request.
    trainController.current.isPaused = false;
  }

  /**
   * Resolves once the current training run (if any) has fully finished.
   * Awaiting `null` resolves immediately when no run was ever started.
   */
  async function waitUntilNotTraining(): Promise<void> {
    await trainingRunRef.current;
  }

  async function handleResetTraining() {
    console.log('Resetting training...');
    handleStopTraining();
    await waitUntilNotTraining();
    console.log('Training stopped. Resetting model.');
    lossesRef.current = [];
    infoRef.current = [];
    resetModel();
    resetOptimizer();
    updateTick();
  }

  return (
    <div className="grid grid-cols-[2fr_1fr] min-h-0 h-full gap-12">
      <TrainingViewer
        isTraining={isTraining}
        lossesRef={lossesRef}
        infoRef={infoRef}
        handleSampleIndexChange={handleSampleIndexChange}
      />
      <Sidebar
        architecture={architecture}
        onArchitectureChange={handleArchitectureChange}
        optimizerType={optimizerType}
        optimizerParams={optimizerParams}
        onOptimizerChange={handleOptimizerChange}
        onStartTraining={handleStartTraining}
        onPauseTraining={handlePauseTraining}
        onContinueTraining={handleContinueTraining}
        onStopTraining={handleStopTraining}
        onResetTraining={handleResetTraining}
      />
    </div>
  );
}
/** Props for the left-hand training panel. */
interface TrainingViewerProps {
  /** Whether a training run is currently active; drives the status line. */
  isTraining: boolean;
  /** Per-step training losses; plotted as y against their index. */
  lossesRef: React.RefObject<Array<number>>;
  /** Run information passed straight through to InfoViewer. */
  infoRef: React.RefObject<Array<RunInfo>>;
  /** Forwarded to InfoViewer as its onSampleIndexChange callback. */
  handleSampleIndexChange: () => void;
}

/**
 * Left-hand panel. A card holds a status line and a training-loss line chart;
 * InfoViewer follows below it. Data is read from the refs at render time, so
 * the panel only reflects new values when its parent re-renders.
 */
function TrainingViewer(props: TrainingViewerProps) {
  const { isTraining, lossesRef, infoRef, handleSampleIndexChange } = props;

  const losses = lossesRef.current;
  // x axis: step numbers 0..n-1, one per recorded loss.
  const steps = losses.map((_, step) => step);
  const status = isTraining ? "in progress" : "not in progress";

  return (
    <div className="flex flex-col h-full overflow-auto gap-4 w-full min-h-0">
      <Card className="flex flex-col gap-4">
        <p>Training {status}</p>
        <Plot
          data={[{ x: steps, y: losses, mode: 'lines', type: 'scatter' }]}
          layout={{
            xaxis: { title: { text: 'Training steps' } },
            yaxis: { title: { text: 'Train loss' } },
            margin: { t: 40, r: 40, b: 40, l: 40 },
          }}
          className="w-full h-[320px]"
          config={{ responsive: true }}
        />
      </Card>
      <InfoViewer info={infoRef.current} onSampleIndexChange={handleSampleIndexChange} />
    </div>
  );
}
/** Props for the settings sidebar: current settings plus the parent's handlers. */
interface SidebarProps {
  /** Architecture text currently applied; seeds the local draft. */
  architecture: string;
  /** Receives the draft text when "Apply architecture" is clicked. */
  onArchitectureChange: (newArchitecture: string) => void;
  /** Currently selected optimizer ("sgd" or "adam" in this UI). */
  optimizerType: string;
  /** Current optimizer and run settings, kept as raw strings. */
  optimizerParams: OptimizerParams;
  /** Receives the full optimizer type and parameter set on every edit. */
  onOptimizerChange: (newOptimizerType: string, newOptimizerParams: OptimizerParams) => void;
  onStartTraining: () => void;
  onPauseTraining: () => void;
  onContinueTraining: () => void;
  onStopTraining: () => void;
  onResetTraining: () => void;
}

/**
 * Settings panel with two tabs:
 * - "Architecture": an editable network description, applied on demand;
 * - "Train": optimizer hyper-parameters, batch size/epochs and run controls.
 * A slowness warning is shown above either tab on Firefox.
 */
function Sidebar(props: SidebarProps) {
  const {
    architecture,
    onArchitectureChange,
    optimizerType,
    optimizerParams,
    onOptimizerChange,
    onStartTraining,
    onPauseTraining,
    onContinueTraining,
    onStopTraining,
    onResetTraining,
  } = props;

  const architectureTab = "Architecture";
  const trainTab = "Train";
  const tabs = [architectureTab, trainTab];
  const [activeTab, setActiveTab] = useState<string>(architectureTab);
  // Edits stay local until "Apply architecture" pushes them to the parent.
  // The draft is seeded from the prop on first render only.
  const [architectureDraft, setArchitectureDraft] = useState<string>(architecture);

  // Returns an onChange handler that swaps one edited value into the params.
  const changeParam = (key: keyof OptimizerParams) => (value: string) =>
    onOptimizerChange(optimizerType, { ...optimizerParams, [key]: value });

  const applyArchitecture = () => onArchitectureChange(architectureDraft);
  const changeOptimizerType = (newOptimizerType: string) =>
    onOptimizerChange(newOptimizerType, optimizerParams);

  return (
    <Card className="flex flex-col h-full p-4 gap-2 overflow-auto">
      <Tabs tabs={tabs} activeTab={activeTab} onChange={setActiveTab} />
      <div className="flex flex-col p-4 gap-4 h-full overflow-auto">
        {isFirefox && (
          <p className="text-red-500">
            Warning: This demo may be quite slow on Firefox.
          </p>
        )}
        {activeTab === architectureTab && (
          <>
            <InputField
              label="Architecture"
              value={architectureDraft}
              onChange={setArchitectureDraft}
              rows={15}
            />
            <Button label="Apply architecture" onClick={applyArchitecture} />
          </>
        )}
        {activeTab === trainTab && (
          <>
            <Dropdown
              label="Optimizer"
              options={["sgd", "adam"]}
              activeOption={optimizerType}
              onChange={changeOptimizerType}
            />
            {optimizerType === 'sgd' && (
              <InputField
                label="Learning Rate"
                value={optimizerParams.learningRate}
                onChange={changeParam("learningRate")}
              />
            )}
            {optimizerType === 'adam' && (
              <>
                <InputField
                  label="Learning Rate"
                  value={optimizerParams.learningRate}
                  onChange={changeParam("learningRate")}
                />
                <InputField
                  label="Beta 1"
                  value={optimizerParams.beta1}
                  onChange={changeParam("beta1")}
                />
                <InputField
                  label="Beta 2"
                  value={optimizerParams.beta2}
                  onChange={changeParam("beta2")}
                />
                <InputField
                  label="Epsilon"
                  value={optimizerParams.epsilon}
                  onChange={changeParam("epsilon")}
                />
              </>
            )}
            <InputField
              label="Batch Size"
              value={optimizerParams.batchSize}
              onChange={changeParam("batchSize")}
            />
            <InputField
              label="Epochs"
              value={optimizerParams.epochs}
              onChange={changeParam("epochs")}
            />
            <Button label="Start training" onClick={onStartTraining} />
            <Button label="Pause training" onClick={onPauseTraining} />
            <Button label="Continue training" onClick={onContinueTraining} />
            <Button label="Stop training" onClick={onStopTraining} />
            <Button label="Reset training" onClick={onResetTraining} />
          </>
        )}
      </div>
    </Card>
  );
}