import { useMemo, useState } from "react";
import type { RunInfo } from "./train.ts";
import { Button, Card } from "@elvis/ui";

/** Flat, index-addressable buffer of numbers (plain array or typed array). */
type NumericArray = ArrayLike<number>;

/**
 * Four-element shape tuple. The viewers in this file skip element 0 and read
 * the remaining three as (height, width, channel count).
 */
type Shape4D = [number, number, number, number];

interface InputLayer {
  type: "input";
  output: NumericArray;
  shape: Shape4D;
}

interface Conv2dLayer {
  type: "conv2d";
  output: NumericArray;
  kernels: NumericArray;
  outputShape: Shape4D;
  kernelShape: [number, number, number, number];
  stride: number;
  padding: number | "same" | "valid";
  activationType?: string;
}

interface MaxPoolLayer {
  type: "maxpool";
  output: NumericArray;
  shape: Shape4D;
  size: number;
  stride: number;
}

interface FlattenLayer {
  type: "flatten";
}

interface DenseLayer {
  type: "dense";
  inputUnits: number;
  outputUnits: number;
  activationType?: string;
}

interface OutputLayer {
  type: "output";
  output: NumericArray;
}

/** Discriminated union of every layer record this viewer understands, tagged by `type`. */
type LayerInfo =
  | InputLayer
  | Conv2dLayer
  | MaxPoolLayer
  | FlattenLayer
  | DenseLayer
  | OutputLayer;

interface InfoViewerProps {
  info?: RunInfo[];
  onSampleIndexChange: () => void;
}

interface InputViewerProps {
  output: NumericArray;
  shape: Shape4D;
  onSampleIndexChange: () => void;
}

interface Conv2dLayerViewerProps {
  layerIdx: number;
  stride: number;
  padding: number | "same" | "valid";
  activationType?: string;
  kernels: NumericArray;
  output: NumericArray;
  kernelShape: [number, number, number, number];
  outputShape: Shape4D;
}

interface MaxPoolLayerViewerProps {
  layerIdx: number;
  stride: number;
  size: number;
  output: NumericArray;
  shape: Shape4D;
}

interface OutputLayerViewerProps {
  probs: NumericArray;
}

interface DenseLayerViewerProps {
  inputUnits: number;
  outputUnits: number;
  activationType?: string;
}

/**
 * Narrows a `RunInfo` record to `LayerInfo` by looking at its `type` tag only.
 *
 * The remaining fields are NOT validated: a record with a recognised tag is
 * cast as-is, so malformed records pass through unchanged.
 *
 * @returns the same object typed as `LayerInfo`, or `null` when `type` is not
 *   a string or is not one of the known tags.
 */
function asLayerInfo(layer: RunInfo): LayerInfo | null {
  if (typeof layer.type !== "string") {
    return null;
  }
  switch (layer.type) {
    case "input":
    case "conv2d":
    case "maxpool":
    case "flatten":
    case "dense":
    case "output":
      return layer as unknown as LayerInfo;
    default:
      return null;
  }
}

/** Added to the min-max denominator so a constant-valued buffer does not divide by zero. */
const NORMALIZATION_EPSILON = 1e-8;

/** Scans `values` once and returns `[min, max]` (`[Infinity, -Infinity]` when empty). */
function findValueRange(values: NumericArray): [number, number] {
  let minVal = Infinity;
  let maxVal = -Infinity;
  for (let i = 0; i < values.length; ++i) {
    const val = values[i];
    if (val < minVal) minVal = val;
    if (val > maxVal) maxVal = val;
  }
  return [minVal, maxVal];
}

/** Min-max normalizes `val` and writes it as an opaque gray RGBA pixel at `pixelIdx`. */
function writeGrayPixel(
  buffer: Uint8ClampedArray,
  pixelIdx: number,
  val: number,
  minVal: number,
  maxVal: number,
): void {
  const normVal = (val - minVal) / (maxVal - minVal + NORMALIZATION_EPSILON);
  const level = Math.round(normVal * 255);
  const base = pixelIdx * 4;
  buffer[base + 0] = level;
  buffer[base + 1] = level;
  buffer[base + 2] = level;
  buffer[base + 3] = 255;
}

/**
 * Converts one sample stored as row-major (row, column, channel) values into
 * an opaque RGBA `ImageData`.
 *
 * Values are multiplied by 255 and stored in a `Uint8ClampedArray`, so they
 * are presumably expected in [0, 1] (anything outside is clamped).
 *
 * @param output flat buffer indexed as `(i * w + j) * cCount + c`
 * @param h image height in pixels
 * @param w image width in pixels
 * @param cCount channels stored per pixel; with fewer than three, the last
 *   available channel is repeated (so one channel renders as grayscale); with
 *   more than three, only the first three are used; with zero, pixels stay black
 */
function extractImage(output: NumericArray, h: number, w: number, cCount: number): ImageData {
  const buffer = new Uint8ClampedArray(h * w * 4);
  // RGBA offers three colour slots. Capping here keeps extra source channels
  // from being written into the alpha slot or the following pixel.
  const colorChannels = Math.min(cCount, 3);
  for (let i = 0; i < h; ++i) {
    for (let j = 0; j < w; ++j) {
      const src = (i * w + j) * cCount;
      const dst = (i * w + j) * 4;
      for (let c = 0; c < colorChannels; ++c) {
        buffer[dst + c] = output[src + c] * 255;
      }
      // Only replicate when there is a real channel to copy from; otherwise the
      // read would land on the previous pixel's alpha byte.
      if (colorChannels > 0) {
        for (let c = colorChannels; c < 3; ++c) {
          buffer[dst + c] = buffer[dst + colorChannels - 1];
        }
      }
      buffer[dst + 3] = 255;
    }
  }
  return new ImageData(buffer, w, h);
}

/**
 * Renders the kernels feeding one output channel as grayscale images, one
 * `kw` x `kh` image per input channel.
 *
 * Normalization uses the min and max of the WHOLE `kernels` buffer, not just
 * the selected output channel, so images of different channels share a scale.
 *
 * @param kernels flat buffer indexed as `((i * kw + j) * inC + ic) * outC + oc`
 * @param selectedOutputChannel the `oc` index to visualise
 * @param kh kernel height
 * @param kw kernel width
 * @param inC number of input channels (length of the returned array)
 * @param outC number of output channels
 */
function extractKernels(
  kernels: NumericArray,
  selectedOutputChannel: number,
  kh: number,
  kw: number,
  inC: number,
  outC: number,
): ImageData[] {
  const [minVal, maxVal] = findValueRange(kernels);
  const kernelImgDatas: ImageData[] = [];
  for (let ic = 0; ic < inC; ++ic) {
    const buffer = new Uint8ClampedArray(kh * kw * 4);
    for (let i = 0; i < kh; ++i) {
      for (let j = 0; j < kw; ++j) {
        const val = kernels[i * kw * inC * outC + j * inC * outC + ic * outC + selectedOutputChannel];
        writeGrayPixel(buffer, i * kw + j, val, minVal, maxVal);
      }
    }
    kernelImgDatas.push(new ImageData(buffer, kw, kh));
  }
  return kernelImgDatas;
}

/**
 * Splits a row-major (row, column, channel) activation buffer into one
 * grayscale `w` x `h` image per channel.
 *
 * All channels are normalized against the min and max of the whole buffer.
 *
 * @param activations flat buffer indexed as `(i * w + j) * cCount + c`
 * @param h map height
 * @param w map width
 * @param cCount number of channels (length of the returned array)
 */
function extractActivationMaps(activations: NumericArray, h: number, w: number, cCount: number): ImageData[] {
  const [minVal, maxVal] = findValueRange(activations);
  const activationImgDatas: ImageData[] = [];
  for (let c = 0; c < cCount; ++c) {
    const buffer = new Uint8ClampedArray(h * w * 4);
    for (let i = 0; i < h; ++i) {
      for (let j = 0; j < w; ++j) {
        const val = activations[i * w * cCount + j * cCount + c];
        writeGrayPixel(buffer, i * w + j, val, minVal, maxVal);
      }
    }
    activationImgDatas.push(new ImageData(buffer, w, h));
  }
  return activationImgDatas;
}

/**
 * Paints `imgData` onto a detached canvas of the same size and returns the
 * canvas contents as a data URL.
 *
 * @returns the data URL, or `""` when no 2D context could be obtained.
 */
function imgDataToSrc(imgData: ImageData): string {
  const canvas = document.createElement("canvas");
  canvas.width = imgData.width;
  canvas.height = imgData.height;
  const ctx = canvas.getContext("2d");
  if (!ctx) return "";
  ctx.putImageData(imgData, 0, 0);
  return canvas.toDataURL();
}

/**
 * Shows the input layer: its size and the current sample rendered as an image.
 * Pixel conversion and data-URL encoding are memoized on the sample and shape.
 */
function InputViewer({ output, shape, onSampleIndexChange }: InputViewerProps) {
  const [, h, w, cCount] = shape;
  const imgData = useMemo(() => extractImage(output, h, w, cCount), [output, h, w, cCount]);
  const imgSrc = useMemo(() => imgDataToSrc(imgData), [imgData]);
  return (

Input Layer

Input size: {imgData.width} x {imgData.height}

Sample input

{imgSrc && ( Input sample )}
);
}

/**
 * Shows a convolution layer: its hyper-parameters, one min-max normalized
 * activation map per output channel, and the kernels of the currently
 * selected output channel, if any.
 */
function Conv2dLayerViewer({
  layerIdx,
  stride,
  padding,
  activationType,
  kernels,
  output,
  kernelShape,
  outputShape,
}: Conv2dLayerViewerProps) {
  // Index of the activation map whose kernels are displayed; `null` means no
  // selection. The explicit type argument is required: `useState(null)` alone
  // infers the state as `null`, so a numeric index could never be stored.
  const [selectedChannel, setSelectedChannel] = useState<number | null>(null);
  const [kh, kw, inC, outC] = kernelShape;
  const [, h, w, cCount] = outputShape;

  // Kernel images are only computed while a channel is selected.
  const kernelImgDatas = useMemo(() => {
    if (selectedChannel == null) return null;
    return extractKernels(kernels, selectedChannel, kh, kw, inC, outC);
  }, [kernels, selectedChannel, kh, kw, inC, outC]);
  const kernelSrcs = useMemo(() => kernelImgDatas?.map(imgDataToSrc) ?? null, [kernelImgDatas]);

  const activations = useMemo(
    () => extractActivationMaps(output, h, w, cCount),
    [output, h, w, cCount],
  );
  const activationSrcs = useMemo(() => activations.map(imgDataToSrc), [activations]);
  return (

Convolution Layer

Kernel Size: {kh} x {kw}
Stride: {stride}
Padding: {padding}
Activation: {activationType ?? "none"}
Output channels: {cCount}
{selectedChannel != null && kernelSrcs && ( <>

Kernel for output {selectedChannel} (min-max normalized)

{kernelSrcs.map((src, idx) => ( {`Kernel ))}
)}

Activation Maps (min-max normalized)

{activationSrcs.map((src, idx) => ( {`Activation setSelectedChannel(selectedChannel === idx ? null : idx)} className={`h-24 w-24 cursor-pointer rounded border object-contain ${ selectedChannel === idx ? "border-lime-500 ring-2 ring-lime-300" : "border-slate-200" }`} /> ))}
);
}

/**
 * Shows a max-pool layer: pool size, stride, and one min-max normalized
 * grayscale image per output channel.
 */
function MaxPoolLayerViewer({ layerIdx, stride, size, output, shape }: MaxPoolLayerViewerProps) {
  // Element 0 of the shape is skipped; the rest are height, width, channel count.
  const [, h, w, cCount] = shape;
  // Both steps are memoized so images are rebuilt only when the output or shape changes.
  const activations = useMemo(() => extractActivationMaps(output, h, w, cCount), [output, h, w, cCount]);
  const activationSrcs = useMemo(() => activations.map(imgDataToSrc), [activations]);
  return (

MaxPool Layer

Pool Size: {size} x {size}
Stride: {stride}

Outputs (min-max normalized)

{activationSrcs.map((src, idx) => ( {`Output ))}
);
}

/**
 * Shows the output layer: the number of classes and the probability of each
 * class, taken from `probs` by index.
 */
function OutputLayerViewer({ probs }: OutputLayerViewerProps) {
  // One entry of `probs` per class.
  const numClasses = probs.length;
  return (

Output Layer (softmax)

Number of Classes: {numClasses}

Class Probabilities

{Array.from({ length: numClasses }).map((_, i) => (
Class {i}: {Number(probs[i]).toFixed(2)}
))}
);
}

/**
 * Shows a dense layer's static description: input units, output units and
 * activation type (displayed as "none" when absent).
 */
function DenseLayerViewer({ inputUnits, outputUnits, activationType }: DenseLayerViewerProps) {
  return (

Dense Layer

Input Units: {inputUnits}
Output Units: {outputUnits}
Activation: {activationType ?? "none"}
);
}

/**
 * Top-level viewer for a run: renders every recognised layer in order,
 * followed by the output layer's class probabilities when the run ends with
 * an output layer.
 *
 * @param info raw per-layer records; entries with an unknown `type` are dropped
 * @param onSampleIndexChange callback handed down to the input-layer viewer
 */
export default function InfoViewer({ info, onSampleIndexChange }: InfoViewerProps) {
  const layers = useMemo(
    () => (info ?? []).map(asLayerInfo).filter((v): v is LayerInfo => v !== null),
    [info],
  );

  // NOTE(review): the element expressions inside the `return ( … )` statements
  // below are missing in this copy of the file. They are left exactly as found
  // and must be restored from the original source before this compiles.
  function renderLayer(layer: LayerInfo, idx: number) {
    switch (layer.type) {
      case "input":
        return ( );
      case "conv2d":
        return ( );
      case "maxpool":
        return ( );
      case "flatten":
        return null;
      case "dense":
        return ( );
      case "output":
        return null;
      default:
        return ( Unknown Layer Type );
    }
  }

  const lastLayer = layers.length > 0 ? layers[layers.length - 1] : null;
  const outputLayer = lastLayer?.type === "output" ? lastLayer : null;

  // The dense layer directly before the output layer only feeds the output and
  // is not shown. Hide it only when it really is a dense layer, so a run that
  // ends "<other layer> -> output" does not silently lose that other layer.
  const beforeOutput = outputLayer && layers.length > 1 ? layers[layers.length - 2] : null;
  const hiddenTailCount = outputLayer ? (beforeOutput?.type === "dense" ? 2 : 1) : 0;
  const bodyLayers = hiddenTailCount > 0 ? layers.slice(0, -hiddenTailCount) : layers;

  return (
{bodyLayers.map((layer, idx) => renderLayer(layer, idx))} {outputLayer && }
); }