Joel Woodfield commited on
Commit
bc03848
·
1 Parent(s): e9d96d4

Add support for extra dense layers after flattening

Browse files
dist/assets/{index-BrJuj8_Q.js → index-C6SYjIbW.js} RENAMED
The diff for this file is too large to render. See raw diff
 
dist/index.html CHANGED
@@ -5,7 +5,7 @@
5
  <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
  <title>cnn_visualizer2</title>
8
- <script type="module" crossorigin src="/assets/index-BrJuj8_Q.js"></script>
9
  <link rel="stylesheet" crossorigin href="/assets/index-DSnfkyCF.css">
10
  </head>
11
  <body>
 
5
  <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
  <title>cnn_visualizer2</title>
8
+ <script type="module" crossorigin src="/assets/index-C6SYjIbW.js"></script>
9
  <link rel="stylesheet" crossorigin href="/assets/index-DSnfkyCF.css">
10
  </head>
11
  <body>
src/InfoViewer.tsx CHANGED
@@ -36,7 +36,9 @@ interface FlattenLayer {
36
 
37
  interface DenseLayer {
38
  type: "dense";
39
- details?: string;
 
 
40
  }
41
 
42
  interface OutputLayer {
@@ -86,6 +88,12 @@ interface OutputLayerViewerProps {
86
  probs: NumericArray;
87
  }
88
 
 
 
 
 
 
 
89
  function asLayerInfo(layer: RunInfo): LayerInfo | null {
90
  if (typeof layer.type !== "string") {
91
  return null;
@@ -363,6 +371,25 @@ function OutputLayerViewer({ probs }: OutputLayerViewerProps) {
363
  );
364
  }
365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  export default function InfoViewer({ info, onSampleIndexChange }: InfoViewerProps) {
367
  const layers = useMemo(() => (info ?? []).map(asLayerInfo).filter((v): v is LayerInfo => v !== null), [info]);
368
 
@@ -406,9 +433,12 @@ export default function InfoViewer({ info, onSampleIndexChange }: InfoViewerProp
406
  return null;
407
  case "dense":
408
  return (
409
- <Card key={idx} className="p-4 text-sm text-slate-700 shadow-sm">
410
- Dense Layer: {layer.details ?? "N/A"}
411
- </Card>
 
 
 
412
  );
413
  case "output":
414
  return null;
@@ -423,7 +453,9 @@ export default function InfoViewer({ info, onSampleIndexChange }: InfoViewerProp
423
 
424
  const lastLayer = layers.length > 0 ? layers[layers.length - 1] : null;
425
  const outputLayer = lastLayer?.type === "output" ? lastLayer : null;
426
- const bodyLayers = outputLayer ? layers.slice(0, -1) : layers;
 
 
427
 
428
  return (
429
  <div className="flex flex-col gap-4 min-h-0">
 
36
 
37
  interface DenseLayer {
38
  type: "dense";
39
+ inputUnits: number;
40
+ outputUnits: number;
41
+ activationType?: string;
42
  }
43
 
44
  interface OutputLayer {
 
88
  probs: NumericArray;
89
  }
90
 
91
+ interface DenseLayerViewerProps {
92
+ inputUnits: number;
93
+ outputUnits: number;
94
+ activationType?: string;
95
+ }
96
+
97
  function asLayerInfo(layer: RunInfo): LayerInfo | null {
98
  if (typeof layer.type !== "string") {
99
  return null;
 
371
  );
372
  }
373
 
374
+ function DenseLayerViewer({ inputUnits, outputUnits, activationType }: DenseLayerViewerProps) {
375
+ return (
376
+ <Card>
377
+ <h3 className="text-lg font-semibold text-slate-900">Dense Layer</h3>
378
+ <div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
379
+ <div>
380
+ <strong>Input Units:</strong> {inputUnits}
381
+ </div>
382
+ <div>
383
+ <strong>Output Units:</strong> {outputUnits}
384
+ </div>
385
+ <div>
386
+ <strong>Activation:</strong> {activationType ?? "none"}
387
+ </div>
388
+ </div>
389
+ </Card>
390
+ );
391
+ }
392
+
393
  export default function InfoViewer({ info, onSampleIndexChange }: InfoViewerProps) {
394
  const layers = useMemo(() => (info ?? []).map(asLayerInfo).filter((v): v is LayerInfo => v !== null), [info]);
395
 
 
433
  return null;
434
  case "dense":
435
  return (
436
+ <DenseLayerViewer
437
+ key={idx}
438
+ inputUnits={layer.inputUnits}
439
+ outputUnits={layer.outputUnits}
440
+ activationType={layer.activationType}
441
+ />
442
  );
443
  case "output":
444
  return null;
 
453
 
454
  const lastLayer = layers.length > 0 ? layers[layers.length - 1] : null;
455
  const outputLayer = lastLayer?.type === "output" ? lastLayer : null;
456
+
457
+ // the second last layer in layers is the dense layer for the output - don't show it.
458
+ const bodyLayers = outputLayer ? layers.slice(0, -2) : layers;
459
 
460
  return (
461
  <div className="flex flex-col gap-4 min-h-0">
src/train.ts CHANGED
@@ -247,6 +247,24 @@ export class Cnn {
247
  out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
248
  denseBiases as tf.Tensor1D,
249
  );
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  break;
251
  }
252
  default:
@@ -340,6 +358,23 @@ export class Cnn {
340
  out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
341
  denseBiases as tf.Tensor1D,
342
  );
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
  info.push({
345
  type: 'dense',
@@ -349,7 +384,11 @@ export class Cnn {
349
  outputShape: out.shape,
350
  weightShape: denseWeights.shape,
351
  biasShape: denseBiases.shape,
352
- units: getNumber(layer, 'units'),
 
 
 
 
353
  });
354
  break;
355
  }
 
247
  out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
248
  denseBiases as tf.Tensor1D,
249
  );
250
+
251
+ if (layer.activationType === 'relu') {
252
+ out = out.relu();
253
+ }
254
+
255
+ const next = this.architecture[i + 1];
256
+ if (next?.type === 'dense' && next.weights === null) {
257
+ const nextUnits = getNumber(next, 'units');
258
+ const currentUnits = getNumber(layer, 'units');
259
+ next.weights = tf.variable(
260
+ tf.randomUniform(
261
+ [currentUnits, nextUnits],
262
+ -Math.sqrt(1 / currentUnits),
263
+ Math.sqrt(1 / currentUnits),
264
+ ),
265
+ );
266
+ next.biases = tf.variable(tf.zeros([nextUnits]));
267
+ }
268
  break;
269
  }
270
  default:
 
358
  out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
359
  denseBiases as tf.Tensor1D,
360
  );
361
+ if (layer.activationType === 'relu') {
362
+ out = out.relu();
363
+ }
364
+
365
+ const next = this.architecture[i + 1];
366
+ if (next?.type === 'dense' && next.weights === null) {
367
+ const nextUnits = getNumber(next, 'units');
368
+ const currentUnits = getNumber(layer, 'units');
369
+ next.weights = tf.variable(
370
+ tf.randomUniform(
371
+ [currentUnits, nextUnits],
372
+ -Math.sqrt(1 / currentUnits),
373
+ Math.sqrt(1 / currentUnits),
374
+ ),
375
+ );
376
+ next.biases = tf.variable(tf.zeros([nextUnits]));
377
+ }
378
 
379
  info.push({
380
  type: 'dense',
 
384
  outputShape: out.shape,
385
  weightShape: denseWeights.shape,
386
  biasShape: denseBiases.shape,
387
+ inputUnits: denseWeights.shape[0],
388
+ outputUnits: getNumber(layer, 'units'),
389
+ outputSize: getNumber(layer, 'units'),
390
+ inputSize: denseWeights.shape[0],
391
+ activationType: layer.activationType,
392
  });
393
  break;
394
  }