Spaces:
Running
Running
| import { useMemo, useState } from "react"; | |
| import type { RunInfo } from "./train.ts"; | |
| import { Button, Card } from "@elvis/ui"; | |
/** Any indexable numeric buffer; plain arrays and typed arrays both satisfy it. */
type NumericArray = ArrayLike<number>;

/**
 * Shape of a 4-D tensor. The viewers in this file destructure it as
 * `[_, height, width, channels]` and ignore the leading entry
 * (presumably the batch size — TODO confirm against the producer in train.ts).
 */
type Shape4D = [number, number, number, number];

/** The network input together with the sample that was fed in. */
interface InputLayer {
  type: "input";
  /** Flat pixel buffer, read channels-last (`[row][col][channel]`). */
  output: NumericArray;
  shape: Shape4D;
}

/** A 2-D convolution with its weights and the activations it produced. */
interface Conv2dLayer {
  type: "conv2d";
  /** Flat activation buffer, read channels-last (`[row][col][channel]`). */
  output: NumericArray;
  /** Flat weight buffer, read as `[kh][kw][inC][outC]`. */
  kernels: NumericArray;
  outputShape: Shape4D;
  kernelShape: [kh: number, kw: number, inC: number, outC: number];
  stride: number;
  padding: number | "same" | "valid";
  /** Name of the activation function; omitted when the layer has none. */
  activationType?: string;
}

/** A max-pooling step and the activations it produced. */
interface MaxPoolLayer {
  type: "maxpool";
  /** Flat activation buffer, read channels-last (`[row][col][channel]`). */
  output: NumericArray;
  shape: Shape4D;
  /** Pool window edge length; displayed as `size x size`. */
  size: number;
  stride: number;
}

/** Marker for a flatten step; it carries no data and is never rendered. */
interface FlattenLayer {
  type: "flatten";
}

/** A fully connected layer, described only by its dimensions. */
interface DenseLayer {
  type: "dense";
  inputUnits: number;
  outputUnits: number;
  /** Name of the activation function; omitted when the layer has none. */
  activationType?: string;
}

/** The final layer; `output` holds one value per class. */
interface OutputLayer {
  type: "output";
  output: NumericArray;
}

/** Discriminated union (on `type`) of every layer record the viewer understands. */
type LayerInfo =
  | InputLayer
  | Conv2dLayer
  | MaxPoolLayer
  | FlattenLayer
  | DenseLayer
  | OutputLayer;
/** Props of the top-level `InfoViewer`. */
interface InfoViewerProps {
  /** Raw per-layer records; when omitted the viewer renders no layers. */
  info?: RunInfo[];
  /** Invoked when the user presses the "New Sample" button. */
  onSampleIndexChange: () => void;
}

/** Props of `InputViewer`. */
interface InputViewerProps {
  /** Flat pixel buffer, read channels-last (`[row][col][channel]`). */
  output: NumericArray;
  /** Read as `[_, height, width, channels]`; the first entry is ignored. */
  shape: Shape4D;
  /** Invoked when the user presses the "New Sample" button. */
  onSampleIndexChange: () => void;
}

/** Props of `Conv2dLayerViewer`. */
interface Conv2dLayerViewerProps {
  /** Position of the layer in the rendered list; used to build React keys. */
  layerIdx: number;
  stride: number;
  padding: number | "same" | "valid";
  /** Shown as "none" when omitted. */
  activationType?: string;
  /** Flat weight buffer, read as `[kh][kw][inC][outC]`. */
  kernels: NumericArray;
  /** Flat activation buffer, read channels-last. */
  output: NumericArray;
  kernelShape: [kh: number, kw: number, inC: number, outC: number];
  /** Read as `[_, height, width, channels]`; the first entry is ignored. */
  outputShape: Shape4D;
}

/** Props of `MaxPoolLayerViewer`. */
interface MaxPoolLayerViewerProps {
  /** Position of the layer in the rendered list; used to build React keys. */
  layerIdx: number;
  stride: number;
  /** Pool window edge length; displayed as `size x size`. */
  size: number;
  /** Flat activation buffer, read channels-last. */
  output: NumericArray;
  /** Read as `[_, height, width, channels]`; the first entry is ignored. */
  shape: Shape4D;
}

/** Props of `OutputLayerViewer`. */
interface OutputLayerViewerProps {
  /** One value per class, displayed with two decimals. */
  probs: NumericArray;
}

/** Props of `DenseLayerViewer`. */
interface DenseLayerViewerProps {
  inputUnits: number;
  outputUnits: number;
  /** Shown as "none" when omitted. */
  activationType?: string;
}
/**
 * Reinterprets a raw `RunInfo` record as a `LayerInfo` when its `type` tag is
 * one of the layer kinds this file knows how to display.
 *
 * Only the tag is inspected; every other field is trusted as-is, so a record
 * with a recognised tag but missing fields is still returned.
 *
 * @param layer Raw record produced by the training run.
 * @returns The same object typed as `LayerInfo`, or `null` when the tag is not
 *   a string or not a recognised layer kind.
 */
function asLayerInfo(layer: RunInfo): LayerInfo | null {
  const knownTags: readonly string[] = ["input", "conv2d", "maxpool", "flatten", "dense", "output"];
  const tag: unknown = layer.type;
  const recognised = typeof tag === "string" && knownTags.includes(tag);
  return recognised ? (layer as unknown as LayerInfo) : null;
}
/**
 * Builds an opaque RGBA `ImageData` from the first `h * w * cCount` entries of
 * a channels-last buffer (`[row][col][channel]`).
 *
 * Each value is multiplied by 255 and stored in a `Uint8ClampedArray`, which
 * clamps the result to 0–255. At most the first three channels are used as
 * R, G and B; when fewer than three exist, the last available channel is
 * repeated (so a single-channel input becomes grayscale). Alpha is always 255.
 *
 * @param output Flat source buffer.
 * @param h Image height in pixels.
 * @param w Image width in pixels.
 * @param cCount Number of channels stored per pixel in `output`.
 * @returns A `w` x `h` image; pixels are black when `cCount` is below 1.
 */
function extractImage(output: NumericArray, h: number, w: number, cCount: number): ImageData {
  const buffer = new Uint8ClampedArray(h * w * 4);
  // Byte 3 of every pixel is alpha, so never copy more than three channels;
  // copying more would write into the alpha slot and the following pixel.
  const colorChannels = Math.max(0, Math.min(cCount, 3));
  const pixelCount = h * w;
  for (let p = 0; p < pixelCount; ++p) {
    const src = p * cCount;
    const dst = p * 4;
    for (let c = 0; c < colorChannels; ++c) {
      // A read past the end of `output` yields undefined; treat it as 0.
      buffer[dst + c] = (output[src + c] ?? 0) * 255;
    }
    // Repeat the last real channel into the remaining colour slots, or use 0
    // when there is no channel to repeat.
    const fill = colorChannels > 0 ? buffer[dst + colorChannels - 1] : 0;
    for (let c = colorChannels; c < 3; ++c) {
      buffer[dst + c] = fill;
    }
    buffer[dst + 3] = 255;
  }
  return new ImageData(buffer, w, h);
}
/**
 * Renders the kernel slices that feed one output channel as grayscale images,
 * one image per input channel.
 *
 * Weights are min-max normalised against the range of the whole `kernels`
 * buffer (not only the selected channel), so brightness is comparable across
 * output channels. The buffer is read as `[kh][kw][inC][outC]`.
 *
 * @param kernels Flat weight buffer.
 * @param selectedOutputChannel Index along the `outC` axis to visualise.
 * @param kh Kernel height.
 * @param kw Kernel width.
 * @param inC Number of input channels.
 * @param outC Number of output channels.
 * @returns `inC` opaque images, each `kw` x `kh`.
 */
function extractKernels(
  kernels: NumericArray,
  selectedOutputChannel: number,
  kh: number,
  kw: number,
  inC: number,
  outC: number,
): ImageData[] {
  // Value range over every weight in the buffer.
  let lo = Infinity;
  let hi = -Infinity;
  for (let k = 0; k < kernels.length; ++k) {
    const weight = kernels[k];
    if (weight < lo) lo = weight;
    if (weight > hi) hi = weight;
  }
  // The epsilon keeps the division finite when all weights are equal.
  const span = hi - lo + 1e-8;

  const cellCount = kh * kw;
  const cellStride = inC * outC;
  const images: ImageData[] = [];
  for (let inChannel = 0; inChannel < inC; ++inChannel) {
    const rgba = new Uint8ClampedArray(cellCount * 4);
    const channelOffset = inChannel * outC + selectedOutputChannel;
    for (let cell = 0; cell < cellCount; ++cell) {
      const weight = kernels[cell * cellStride + channelOffset];
      const gray = Math.round(((weight - lo) / span) * 255);
      const base = cell * 4;
      rgba.fill(gray, base, base + 3);
      rgba[base + 3] = 255;
    }
    images.push(new ImageData(rgba, kw, kh));
  }
  return images;
}
/**
 * Splits a channels-last activation buffer (`[row][col][channel]`) into one
 * grayscale image per channel.
 *
 * Values are min-max normalised against the range of the whole buffer, so all
 * returned maps share one brightness scale.
 *
 * @param activations Flat activation buffer.
 * @param h Map height in pixels.
 * @param w Map width in pixels.
 * @param cCount Number of channels stored per pixel.
 * @returns `cCount` opaque images, each `w` x `h`.
 */
function extractActivationMaps(activations: NumericArray, h: number, w: number, cCount: number): ImageData[] {
  // Value range over every activation in the buffer.
  let lo = Infinity;
  let hi = -Infinity;
  for (let k = 0; k < activations.length; ++k) {
    const activation = activations[k];
    if (activation < lo) lo = activation;
    if (activation > hi) hi = activation;
  }
  // The epsilon keeps the division finite when all activations are equal.
  const span = hi - lo + 1e-8;

  const pixelCount = h * w;
  const maps: ImageData[] = [];
  for (let channel = 0; channel < cCount; ++channel) {
    const rgba = new Uint8ClampedArray(pixelCount * 4);
    for (let p = 0; p < pixelCount; ++p) {
      const activation = activations[p * cCount + channel];
      const gray = Math.round(((activation - lo) / span) * 255);
      const base = p * 4;
      rgba.fill(gray, base, base + 3);
      rgba[base + 3] = 255;
    }
    maps.push(new ImageData(rgba, w, h));
  }
  return maps;
}
/**
 * Paints an `ImageData` onto a temporary off-document canvas of the same size
 * and returns the canvas contents as a data URL usable as an `<img src>`.
 *
 * @param imgData Pixels to encode.
 * @returns The data URL, or an empty string when no 2D context is available.
 */
function imgDataToSrc(imgData: ImageData): string {
  const { width, height } = imgData;
  const canvas = Object.assign(document.createElement("canvas"), { width, height });
  const ctx = canvas.getContext("2d");
  if (ctx) {
    ctx.putImageData(imgData, 0, 0);
    return canvas.toDataURL();
  }
  return "";
}
/**
 * Card showing the current input sample, its dimensions and a button that
 * asks the parent for a new sample.
 *
 * The image is rebuilt only when the buffer or its dimensions change. The
 * preview is omitted when the image could not be encoded (empty source).
 */
function InputViewer({ output, shape, onSampleIndexChange }: InputViewerProps) {
  const height = shape[1];
  const width = shape[2];
  const channels = shape[3];

  const imgData = useMemo(
    () => extractImage(output, height, width, channels),
    [output, height, width, channels],
  );
  const imgSrc = useMemo(() => imgDataToSrc(imgData), [imgData]);

  const preview = imgSrc && (
    <img
      src={imgSrc}
      alt="Input sample"
      className="h-24 w-24 rounded border border-slate-200 object-contain"
    />
  );

  return (
    <Card>
      <h3 className="text-lg font-semibold text-slate-900">Input Layer</h3>
      <div className="mt-2 text-sm text-slate-700">
        <strong>Input size:</strong> {imgData.width} x {imgData.height}
      </div>
      <h4 className="mt-3 text-sm font-medium text-slate-800">Sample input</h4>
      <div className="mt-2 flex flex-wrap items-center gap-3">
        {preview}
        <Button label="New Sample" onClick={onSampleIndexChange} />
      </div>
    </Card>
  );
}
/**
 * Card for a convolution layer: its hyper-parameters, one activation map per
 * output channel and, once a map is clicked, the kernels that produce that
 * channel (one image per input channel). Clicking the selected map again
 * clears the selection.
 *
 * The selection lives in component state and therefore survives prop changes.
 * It is treated as "no selection" whenever it falls outside the current
 * channel count, so a stale index never reaches `extractKernels`.
 */
function Conv2dLayerViewer({
  layerIdx,
  stride,
  padding,
  activationType,
  kernels,
  output,
  kernelShape,
  outputShape,
}: Conv2dLayerViewerProps) {
  const [selectedChannel, setSelectedChannel] = useState<number | null>(null);
  const [kh, kw, inC, outC] = kernelShape;
  const [, h, w, cCount] = outputShape;

  // Ignore a selection that no longer exists in either the kernel or the
  // activation tensor (e.g. after the layer at this position changed shape).
  const activeChannel =
    selectedChannel != null && selectedChannel < Math.min(outC, cCount) ? selectedChannel : null;

  const kernelImgDatas = useMemo(() => {
    if (activeChannel == null) return null;
    return extractKernels(kernels, activeChannel, kh, kw, inC, outC);
  }, [kernels, activeChannel, kh, kw, inC, outC]);
  const kernelSrcs = useMemo(() => kernelImgDatas?.map(imgDataToSrc) ?? null, [kernelImgDatas]);

  const activations = useMemo(
    () => extractActivationMaps(output, h, w, cCount),
    [output, h, w, cCount],
  );
  const activationSrcs = useMemo(() => activations.map(imgDataToSrc), [activations]);

  // Functional update so the toggle always acts on the latest state.
  const toggleChannel = (idx: number) => setSelectedChannel((prev) => (prev === idx ? null : idx));

  return (
    <Card>
      <h3 className="text-lg font-semibold text-slate-900">Convolution Layer</h3>
      <div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
        <div>
          <strong>Kernel Size:</strong> {kh} x {kw}
        </div>
        <div>
          <strong>Stride:</strong> {stride}
        </div>
        <div>
          <strong>Padding:</strong> {padding}
        </div>
        <div>
          <strong>Activation:</strong> {activationType ?? "none"}
        </div>
        <div>
          <strong>Output channels:</strong> {cCount}
        </div>
      </div>
      {activeChannel != null && kernelSrcs && (
        <>
          <h4 className="mt-3 text-sm font-medium text-slate-800">
            Kernel for output {activeChannel} (min-max normalized)
          </h4>
          <div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">
            {kernelSrcs.map((src, idx) => (
              <img
                key={`${layerIdx}-${idx}-kernel`}
                src={src}
                alt={`Kernel ${idx}`}
                className="h-24 w-24 rounded object-contain"
              />
            ))}
          </div>
        </>
      )}
      <h4 className="mt-3 text-sm font-medium text-slate-800">Activation Maps (min-max normalized)</h4>
      <div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">
        {activationSrcs.map((src, idx) => (
          <img
            key={`${layerIdx}-${idx}-activation`}
            src={src}
            alt={`Activation Map ${idx}`}
            onClick={() => toggleChannel(idx)}
            className={`h-24 w-24 cursor-pointer rounded border object-contain ${
              activeChannel === idx ? "border-lime-500 ring-2 ring-lime-300" : "border-slate-200"
            }`}
          />
        ))}
      </div>
    </Card>
  );
}
/**
 * Card for a max-pooling layer: pool size, stride and one grayscale image per
 * output channel. Images are rebuilt only when the buffer or its dimensions
 * change.
 */
function MaxPoolLayerViewer({ layerIdx, stride, size, output, shape }: MaxPoolLayerViewerProps) {
  const height = shape[1];
  const width = shape[2];
  const channels = shape[3];

  const pooledMaps = useMemo(
    () => extractActivationMaps(output, height, width, channels),
    [output, height, width, channels],
  );
  const pooledSrcs = useMemo(() => pooledMaps.map(imgDataToSrc), [pooledMaps]);

  const thumbnails = pooledSrcs.map((src, channel) => (
    <img
      key={`${layerIdx}-${channel}-output`}
      src={src}
      alt={`Output ${channel}`}
      className="h-24 w-24 rounded border border-slate-200 object-contain"
    />
  ));

  return (
    <Card>
      <h3 className="text-lg font-semibold text-slate-900">MaxPool Layer</h3>
      <div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
        <div>
          <strong>Pool Size:</strong> {size} x {size}
        </div>
        <div>
          <strong>Stride:</strong> {stride}
        </div>
      </div>
      <h4 className="mt-3 text-sm font-medium text-slate-800">Outputs (min-max normalized)</h4>
      <div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">{thumbnails}</div>
    </Card>
  );
}
/**
 * Card for the final layer: the number of classes and each class value
 * formatted with two decimals.
 */
function OutputLayerViewer({ probs }: OutputLayerViewerProps) {
  // `probs` is only array-like, so go through Array.from rather than `.map`.
  const classCells = Array.from(probs, (prob, classIdx) => (
    <div key={classIdx} className="rounded border border-slate-200 bg-slate-50 px-2 py-1">
      <strong>Class {classIdx}:</strong> {Number(prob).toFixed(2)}
    </div>
  ));

  return (
    <Card>
      <h3 className="text-lg font-semibold text-slate-900">Output Layer (softmax)</h3>
      <div className="mt-2 text-sm text-slate-700">
        <strong>Number of Classes:</strong> {probs.length}
      </div>
      <h4 className="mt-3 text-sm font-medium text-slate-800">Class Probabilities</h4>
      <div className="mt-2 grid grid-cols-2 gap-2 text-sm sm:grid-cols-3">{classCells}</div>
    </Card>
  );
}
/**
 * Card for a fully connected layer: input/output unit counts and the
 * activation name ("none" when the layer declares no activation).
 */
function DenseLayerViewer(props: DenseLayerViewerProps) {
  const { inputUnits, outputUnits } = props;
  const activationLabel = props.activationType ?? "none";

  return (
    <Card>
      <h3 className="text-lg font-semibold text-slate-900">Dense Layer</h3>
      <div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
        <div>
          <strong>Input Units:</strong> {inputUnits}
        </div>
        <div>
          <strong>Output Units:</strong> {outputUnits}
        </div>
        <div>
          <strong>Activation:</strong> {activationLabel}
        </div>
      </div>
    </Card>
  );
}
/**
 * Renders one card per recognised layer of a run, top to bottom.
 *
 * Records whose `type` tag is not recognised are dropped. Flatten layers
 * render nothing. When the last layer is an output layer it is shown with the
 * dedicated class-probability card, and the dense layer directly in front of
 * it (the one that feeds the output) is hidden.
 *
 * @param info Raw per-layer records; when omitted nothing is rendered.
 * @param onSampleIndexChange Forwarded to the input card's "New Sample" button.
 */
export default function InfoViewer({ info, onSampleIndexChange }: InfoViewerProps) {
  const layers = useMemo(
    () => (info ?? []).map(asLayerInfo).filter((v): v is LayerInfo => v !== null),
    [info],
  );

  /** Maps a single layer record to its card; returns null for layers with no card of their own. */
  function renderLayer(layer: LayerInfo, idx: number) {
    switch (layer.type) {
      case "input":
        return (
          <InputViewer
            key={idx}
            output={layer.output}
            shape={layer.shape}
            onSampleIndexChange={onSampleIndexChange}
          />
        );
      case "conv2d":
        return (
          <Conv2dLayerViewer
            key={idx}
            layerIdx={idx}
            stride={layer.stride}
            padding={layer.padding}
            activationType={layer.activationType}
            kernels={layer.kernels}
            output={layer.output}
            kernelShape={layer.kernelShape}
            outputShape={layer.outputShape}
          />
        );
      case "maxpool":
        return (
          <MaxPoolLayerViewer
            key={idx}
            layerIdx={idx}
            stride={layer.stride}
            size={layer.size}
            output={layer.output}
            shape={layer.shape}
          />
        );
      case "flatten":
        return null;
      case "dense":
        return (
          <DenseLayerViewer
            key={idx}
            inputUnits={layer.inputUnits}
            outputUnits={layer.outputUnits}
            activationType={layer.activationType}
          />
        );
      case "output":
        // The trailing output layer is rendered separately below.
        return null;
      default:
        return (
          <Card key={idx} className="p-4 text-sm text-slate-700 shadow-sm">
            Unknown Layer Type
          </Card>
        );
    }
  }

  const lastLayer = layers.length > 0 ? layers[layers.length - 1] : null;
  const outputLayer = lastLayer?.type === "output" ? lastLayer : null;

  // The layer in front of the output is normally the dense layer that feeds
  // it; hide that one, but never hide a layer of any other kind.
  let hiddenTailCount = 0;
  if (outputLayer) {
    const beforeOutput = layers.length > 1 ? layers[layers.length - 2] : null;
    hiddenTailCount = beforeOutput?.type === "dense" ? 2 : 1;
  }
  const bodyLayers = hiddenTailCount > 0 ? layers.slice(0, -hiddenTailCount) : layers;

  return (
    <div className="flex flex-col gap-4 min-h-0">
      {bodyLayers.map((layer, idx) => renderLayer(layer, idx))}
      {outputLayer && <OutputLayerViewer probs={outputLayer.output} />}
    </div>
  );
}