cnn_visualizer / src /InfoViewer.tsx
Joel Woodfield
Add support for extra dense layers after flattening
bc03848
import { useMemo, useState } from "react";
import type { RunInfo } from "./train.ts";
import { Button, Card } from "@elvis/ui";
type NumericArray = ArrayLike<number>;
type Shape4D = [number, number, number, number];
interface InputLayer {
type: "input";
output: NumericArray;
shape: Shape4D;
}
interface Conv2dLayer {
type: "conv2d";
output: NumericArray;
kernels: NumericArray;
outputShape: Shape4D;
kernelShape: [number, number, number, number];
stride: number;
padding: number | "same" | "valid";
activationType?: string;
}
interface MaxPoolLayer {
type: "maxpool";
output: NumericArray;
shape: Shape4D;
size: number;
stride: number;
}
interface FlattenLayer {
type: "flatten";
}
interface DenseLayer {
type: "dense";
inputUnits: number;
outputUnits: number;
activationType?: string;
}
interface OutputLayer {
type: "output";
output: NumericArray;
}
type LayerInfo =
| InputLayer
| Conv2dLayer
| MaxPoolLayer
| FlattenLayer
| DenseLayer
| OutputLayer;
interface InfoViewerProps {
info?: RunInfo[];
onSampleIndexChange: () => void;
}
interface InputViewerProps {
output: NumericArray;
shape: Shape4D;
onSampleIndexChange: () => void;
}
interface Conv2dLayerViewerProps {
layerIdx: number;
stride: number;
padding: number | "same" | "valid";
activationType?: string;
kernels: NumericArray;
output: NumericArray;
kernelShape: [number, number, number, number];
outputShape: Shape4D;
}
interface MaxPoolLayerViewerProps {
layerIdx: number;
stride: number;
size: number;
output: NumericArray;
shape: Shape4D;
}
interface OutputLayerViewerProps {
probs: NumericArray;
}
interface DenseLayerViewerProps {
inputUnits: number;
outputUnits: number;
activationType?: string;
}
function asLayerInfo(layer: RunInfo): LayerInfo | null {
if (typeof layer.type !== "string") {
return null;
}
switch (layer.type) {
case "input":
case "conv2d":
case "maxpool":
case "flatten":
case "dense":
case "output":
return layer as unknown as LayerInfo;
default:
return null;
}
}
function extractImage(output: NumericArray, h: number, w: number, cCount: number): ImageData {
const buffer = new Uint8ClampedArray(h * w * 4);
for (let i = 0; i < h; ++i) {
for (let j = 0; j < w; ++j) {
for (let c = 0; c < cCount; ++c) {
const val = output[i * w * cCount + j * cCount + c];
buffer[(i * w + j) * 4 + c] = val * 255;
}
for (let c = cCount; c < 3; ++c) {
buffer[(i * w + j) * 4 + c] = buffer[(i * w + j) * 4 + (cCount - 1)];
}
buffer[(i * w + j) * 4 + 3] = 255;
}
}
return new ImageData(buffer, w, h);
}
function extractKernels(
kernels: NumericArray,
selectedOutputChannel: number,
kh: number,
kw: number,
inC: number,
outC: number,
): ImageData[] {
let minVal = Infinity;
let maxVal = -Infinity;
for (let i = 0; i < kernels.length; ++i) {
const val = kernels[i];
if (val < minVal) minVal = val;
if (val > maxVal) maxVal = val;
}
const kernelImgDatas: ImageData[] = [];
for (let ic = 0; ic < inC; ++ic) {
const buffer = new Uint8ClampedArray(kh * kw * 4);
for (let i = 0; i < kh; ++i) {
for (let j = 0; j < kw; ++j) {
const val = kernels[i * kw * inC * outC + j * inC * outC + ic * outC + selectedOutputChannel];
const normVal = (val - minVal) / (maxVal - minVal + 1e-8);
const pixel = Math.round(normVal * 255);
buffer[(i * kw + j) * 4 + 0] = pixel;
buffer[(i * kw + j) * 4 + 1] = pixel;
buffer[(i * kw + j) * 4 + 2] = pixel;
buffer[(i * kw + j) * 4 + 3] = 255;
}
}
kernelImgDatas.push(new ImageData(buffer, kw, kh));
}
return kernelImgDatas;
}
function extractActivationMaps(activations: NumericArray, h: number, w: number, cCount: number): ImageData[] {
let minVal = Infinity;
let maxVal = -Infinity;
for (let i = 0; i < activations.length; ++i) {
const val = activations[i];
if (val < minVal) minVal = val;
if (val > maxVal) maxVal = val;
}
const activationImgDatas: ImageData[] = [];
for (let c = 0; c < cCount; ++c) {
const buffer = new Uint8ClampedArray(h * w * 4);
for (let i = 0; i < h; ++i) {
for (let j = 0; j < w; ++j) {
const val = activations[i * w * cCount + j * cCount + c];
const normVal = (val - minVal) / (maxVal - minVal + 1e-8);
const pixel = Math.round(normVal * 255);
buffer[(i * w + j) * 4 + 0] = pixel;
buffer[(i * w + j) * 4 + 1] = pixel;
buffer[(i * w + j) * 4 + 2] = pixel;
buffer[(i * w + j) * 4 + 3] = 255;
}
}
activationImgDatas.push(new ImageData(buffer, w, h));
}
return activationImgDatas;
}
function imgDataToSrc(imgData: ImageData): string {
const canvas = document.createElement("canvas");
canvas.width = imgData.width;
canvas.height = imgData.height;
const ctx = canvas.getContext("2d");
if (!ctx) return "";
ctx.putImageData(imgData, 0, 0);
return canvas.toDataURL();
}
function InputViewer({ output, shape, onSampleIndexChange }: InputViewerProps) {
const [, h, w, cCount] = shape;
const imgData = useMemo(() => extractImage(output, h, w, cCount), [output, h, w, cCount]);
const imgSrc = useMemo(() => imgDataToSrc(imgData), [imgData]);
return (
<Card>
<h3 className="text-lg font-semibold text-slate-900">Input Layer</h3>
<div className="mt-2 text-sm text-slate-700">
<strong>Input size:</strong> {imgData.width} x {imgData.height}
</div>
<h4 className="mt-3 text-sm font-medium text-slate-800">Sample input</h4>
<div className="mt-2 flex flex-wrap items-center gap-3">
{imgSrc && (
<img
src={imgSrc}
alt="Input sample"
className="h-24 w-24 rounded border border-slate-200 object-contain"
/>
)}
<Button label="New Sample" onClick={onSampleIndexChange} />
</div>
</Card>
);
}
function Conv2dLayerViewer({
layerIdx,
stride,
padding,
activationType,
kernels,
output,
kernelShape,
outputShape,
}: Conv2dLayerViewerProps) {
const [selectedChannel, setSelectedChannel] = useState<number | null>(null);
const [kh, kw, inC, outC] = kernelShape;
const [, h, w, cCount] = outputShape;
const kernelImgDatas = useMemo(() => {
if (selectedChannel == null) return null;
return extractKernels(kernels, selectedChannel, kh, kw, inC, outC);
}, [kernels, selectedChannel, kh, kw, inC, outC]);
const kernelSrcs = useMemo(() => kernelImgDatas?.map(imgDataToSrc) ?? null, [kernelImgDatas]);
const activations = useMemo(
() => extractActivationMaps(output, h, w, cCount),
[output, h, w, cCount],
);
const activationSrcs = useMemo(() => activations.map(imgDataToSrc), [activations]);
return (
<Card>
<h3 className="text-lg font-semibold text-slate-900">Convolution Layer</h3>
<div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
<div>
<strong>Kernel Size:</strong> {kh} x {kw}
</div>
<div>
<strong>Stride:</strong> {stride}
</div>
<div>
<strong>Padding:</strong> {padding}
</div>
<div>
<strong>Activation:</strong> {activationType ?? "none"}
</div>
<div>
<strong>Output channels:</strong> {cCount}
</div>
</div>
{selectedChannel != null && kernelSrcs && (
<>
<h4 className="mt-3 text-sm font-medium text-slate-800">
Kernel for output {selectedChannel} (min-max normalized)
</h4>
<div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">
{kernelSrcs.map((src, idx) => (
<img
key={`${layerIdx}-${idx}-kernel`}
src={src}
alt={`Kernel ${idx}`}
className="h-24 w-24 rounded object-contain"
/>
))}
</div>
</>
)}
<h4 className="mt-3 text-sm font-medium text-slate-800">Activation Maps (min-max normalized)</h4>
<div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">
{activationSrcs.map((src, idx) => (
<img
key={`${layerIdx}-${idx}-activation`}
src={src}
alt={`Activation Map ${idx}`}
onClick={() => setSelectedChannel(selectedChannel === idx ? null : idx)}
className={`h-24 w-24 cursor-pointer rounded border object-contain ${
selectedChannel === idx ? "border-lime-500 ring-2 ring-lime-300" : "border-slate-200"
}`}
/>
))}
</div>
</Card>
);
}
function MaxPoolLayerViewer({ layerIdx, stride, size, output, shape }: MaxPoolLayerViewerProps) {
const [, h, w, cCount] = shape;
const activations = useMemo(() => extractActivationMaps(output, h, w, cCount), [output, h, w, cCount]);
const activationSrcs = useMemo(() => activations.map(imgDataToSrc), [activations]);
return (
<Card>
<h3 className="text-lg font-semibold text-slate-900">MaxPool Layer</h3>
<div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
<div>
<strong>Pool Size:</strong> {size} x {size}
</div>
<div>
<strong>Stride:</strong> {stride}
</div>
</div>
<h4 className="mt-3 text-sm font-medium text-slate-800">Outputs (min-max normalized)</h4>
<div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">
{activationSrcs.map((src, idx) => (
<img
key={`${layerIdx}-${idx}-output`}
src={src}
alt={`Output ${idx}`}
className="h-24 w-24 rounded border border-slate-200 object-contain"
/>
))}
</div>
</Card>
);
}
function OutputLayerViewer({ probs }: OutputLayerViewerProps) {
const numClasses = probs.length;
return (
<Card>
<h3 className="text-lg font-semibold text-slate-900">Output Layer (softmax)</h3>
<div className="mt-2 text-sm text-slate-700">
<strong>Number of Classes:</strong> {numClasses}
</div>
<h4 className="mt-3 text-sm font-medium text-slate-800">Class Probabilities</h4>
<div className="mt-2 grid grid-cols-2 gap-2 text-sm sm:grid-cols-3">
{Array.from({ length: numClasses }).map((_, i) => (
<div key={i} className="rounded border border-slate-200 bg-slate-50 px-2 py-1">
<strong>Class {i}:</strong> {Number(probs[i]).toFixed(2)}
</div>
))}
</div>
</Card>
);
}
function DenseLayerViewer({ inputUnits, outputUnits, activationType }: DenseLayerViewerProps) {
return (
<Card>
<h3 className="text-lg font-semibold text-slate-900">Dense Layer</h3>
<div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
<div>
<strong>Input Units:</strong> {inputUnits}
</div>
<div>
<strong>Output Units:</strong> {outputUnits}
</div>
<div>
<strong>Activation:</strong> {activationType ?? "none"}
</div>
</div>
</Card>
);
}
export default function InfoViewer({ info, onSampleIndexChange }: InfoViewerProps) {
const layers = useMemo(() => (info ?? []).map(asLayerInfo).filter((v): v is LayerInfo => v !== null), [info]);
function renderLayer(layer: LayerInfo, idx: number) {
switch (layer.type) {
case "input":
return (
<InputViewer
key={idx}
output={layer.output}
shape={layer.shape}
onSampleIndexChange={onSampleIndexChange}
/>
);
case "conv2d":
return (
<Conv2dLayerViewer
key={idx}
layerIdx={idx}
stride={layer.stride}
padding={layer.padding}
activationType={layer.activationType}
kernels={layer.kernels}
output={layer.output}
kernelShape={layer.kernelShape}
outputShape={layer.outputShape}
/>
);
case "maxpool":
return (
<MaxPoolLayerViewer
key={idx}
layerIdx={idx}
stride={layer.stride}
size={layer.size}
output={layer.output}
shape={layer.shape}
/>
);
case "flatten":
return null;
case "dense":
return (
<DenseLayerViewer
key={idx}
inputUnits={layer.inputUnits}
outputUnits={layer.outputUnits}
activationType={layer.activationType}
/>
);
case "output":
return null;
default:
return (
<Card key={idx} className="p-4 text-sm text-slate-700 shadow-sm">
Unknown Layer Type
</Card>
);
}
}
const lastLayer = layers.length > 0 ? layers[layers.length - 1] : null;
const outputLayer = lastLayer?.type === "output" ? lastLayer : null;
// the second last layer in layers is the dense layer for the output - don't show it.
const bodyLayers = outputLayer ? layers.slice(0, -2) : layers;
return (
<div className="flex flex-col gap-4 min-h-0">
{bodyLayers.map((layer, idx) => renderLayer(layer, idx))}
{outputLayer && <OutputLayerViewer probs={outputLayer.output} />}
</div>
);
}