// cnn_visualizer/src/train.ts
// Author: Joel Woodfield
// Add support for extra dense layers after flattening (commit bc03848)
import * as tf from '@tensorflow/tfjs';
// Values a layer-config entry may hold: scalars parsed from the
// architecture text, plus tensor variables attached at init time
// (null marks a dense layer awaiting lazy initialization).
type LayerValue = string | number | tf.Variable | null | undefined;

// One layer parsed from the architecture string, e.g.
// "[conv2d kernel=3 filters=8 stride=1]" -> { type: 'conv2d', kernel: 3, ... }.
interface LayerConfig {
  type: string;
  [key: string]: LayerValue;
}

// Per-layer snapshot emitted by Cnn.forwardWithInfo for visualization.
export interface RunInfo {
  [key: string]: unknown;
}

// Mutable flags polled by the training loop: pause/stop control plus the
// index of the test sample visualized after each batch.
export interface TrainController {
  isPaused: boolean;
  stopRequested: boolean;
  sampleIndex: number;
}

// One training batch: input tensors plus one-hot labels.
interface BatchData {
  xs: tf.Tensor;
  labels: tf.Tensor2D;
}

// A single test example used for visualization.
interface TestSample {
  xs: tf.Tensor;
}

// Contract for the data source consumed by train().
interface TrainingData {
  trainSize: number;
  imageSize: number;
  numInputChannels: number;
  nextTrainBatch(batchSize: number): BatchData;
  getTestSample(index: number): TestSample;
}

// Raw optimizer settings as entered in the UI (parsed elsewhere).
export interface OptimizerParams {
  learningRate: string;
  batchSize: string;
  epochs: string;
  // sgd only
  momentum?: string;
  // adam only
  beta1?: string;
  beta2?: string;
  epsilon?: string;
}

// Invoked after each batch with the loss and per-layer visualization info.
type BatchEndCallback = (
  epoch: number,
  batch: number,
  loss: number,
  info: RunInfo[],
) => void | Promise<void>;
/**
 * Interprets a raw token value from the architecture text: numeric-looking
 * strings become numbers; everything else — including blank/whitespace-only
 * input, which Number() would otherwise coerce to 0 — is returned unchanged.
 */
function parseValue(raw: string): string | number {
  if (raw.trim().length > 0) {
    const asNumber = Number(raw);
    if (!Number.isNaN(asNumber)) {
      return asNumber;
    }
  }
  return raw;
}

/**
 * Parses an architecture description such as
 * "[conv2d kernel=3 filters=8 stride=1 activation=relu] [dense units=10]"
 * into layer configs. Each bracketed group becomes one layer: the first
 * token is the layer type and each remaining "key=value" token becomes a
 * field. The "activation" key is stored under "activationType".
 */
function parseArchitecture(text: string): LayerConfig[] {
  const groups = text.match(/\[(.*?)\]/gs) ?? [];
  const parsed: LayerConfig[] = [];
  for (const group of groups) {
    // Strip the surrounding brackets; skip empty groups like "[]".
    const body = group.slice(1, -1).trim();
    if (!body) continue;
    const [typeToken, ...pairs] = body.split(/\s+/);
    if (!typeToken) continue;
    const layer: LayerConfig = { type: typeToken };
    for (const pair of pairs) {
      // split limit 2 keeps only the text before a second '=' (matches
      // the original token format "key=value").
      const [rawKey, rawValue] = pair.split('=', 2);
      if (!rawKey || rawValue === undefined) continue;
      const field = rawKey === 'activation' ? 'activationType' : rawKey;
      layer[field] = parseValue(rawValue);
    }
    parsed.push(layer);
  }
  return parsed;
}
/**
 * Reads a required numeric field from a layer config.
 * Throws if the field is absent or not a number.
 */
function getNumber(layer: LayerConfig, key: string): number {
  const raw = layer[key];
  if (typeof raw === 'number') {
    return raw;
  }
  throw new Error(`Layer "${layer.type}" is missing numeric "${key}"`);
}
/**
 * Reads a required tensor-variable field from a layer config.
 * A value qualifies when it is a non-null object exposing `dispose`
 * (the duck-type check used throughout this file for tf.Variable).
 * Throws otherwise — including the `null` placeholder that marks an
 * uninitialized dense layer.
 */
function getVariable(layer: LayerConfig, key: string): tf.Variable {
  const candidate = layer[key];
  const looksLikeVariable =
    candidate !== null &&
    candidate !== undefined &&
    typeof candidate === 'object' &&
    'dispose' in candidate;
  if (!looksLikeVariable) {
    throw new Error(`Layer "${layer.type}" is missing tensor "${key}"`);
  }
  return candidate as tf.Variable;
}
/**
 * Resolves a layer's padding setting. Absent padding defaults to 'valid';
 * numbers and the literals 'same'/'valid' pass through; anything else
 * (including null) is rejected.
 */
function getPadding(layer: LayerConfig): number | 'same' | 'valid' {
  const { padding } = layer;
  switch (typeof padding) {
    case 'undefined':
      return 'valid';
    case 'number':
      return padding;
    default:
      if (padding === 'same' || padding === 'valid') {
        return padding;
      }
      throw new Error(`Layer "${layer.type}" has invalid padding "${String(padding)}"`);
  }
}
/**
 * Computes the flattened feature size (height * width * channels) of a
 * batched NHWC tensor, ignoring the batch dimension. Throws when the
 * tensor's rank is too low for all three dimensions to be present.
 */
function getFlatDim(out: tf.Tensor): number {
  const [, height, width, channels] = out.shape;
  const allKnown =
    typeof height === 'number' &&
    typeof width === 'number' &&
    typeof channels === 'number';
  if (!allKnown) {
    throw new Error('Cannot flatten tensor with unknown shape');
  }
  return height * width * channels;
}
/**
 * Minimal CNN assembled from a parsed architecture string.
 *
 * Conv2d kernels are created eagerly in the constructor; dense layers are
 * initialized lazily during the first forward pass, once their input
 * dimension is known (after flatten, or from the preceding dense layer).
 */
export class Cnn {
  // Parsed layer configs; dense entries also carry their lazily-created
  // weights/biases variables once a forward pass has run.
  architecture: LayerConfig[];
  // Number of channels in the input images.
  inChannels: number;
  // Conv2d kernel variables created eagerly by initWeights().
  // NOTE(review): dense weights created lazily in forward() are NOT
  // appended here — confirm nothing relies on this list being exhaustive.
  weights: tf.Variable[];

  constructor(architecture: string, inChannels: number) {
    this.architecture = parseArchitecture(architecture);
    this.inChannels = inChannels;
    this.weights = this.initWeights();
  }

  /**
   * Creates conv2d kernel variables and marks dense layers for lazy
   * initialization (weights/biases set to null). Returns the conv kernels
   * that were created.
   */
  initWeights(): tf.Variable[] {
    const weights: tf.Variable[] = [];
    let inChannels = this.inChannels;
    for (const layer of this.architecture) {
      if (layer.type === 'conv2d') {
        const kernel = getNumber(layer, 'kernel');
        const filters = getNumber(layer, 'filters');
        const shape: [number, number, number, number] = [
          kernel,
          kernel,
          inChannels,
          filters,
        ];
        // Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        const layerWeights = tf.variable(
          tf.randomUniform(
            shape,
            -Math.sqrt(1 / (kernel * kernel * inChannels)),
            Math.sqrt(1 / (kernel * kernel * inChannels)),
          ),
        );
        weights.push(layerWeights);
        layer.weights = layerWeights;
        inChannels = filters;
      } else if (layer.type === 'dense') {
        // null marks "not yet initialized"; the first forward pass fills
        // these in once the input dimension is known.
        layer.weights = null;
        layer.biases = null;
      }
    }
    return weights;
  }

  /** Releases every conv kernel and any lazily-created dense variables. */
  dispose(): void {
    for (const layer of this.architecture) {
      if (layer.type === 'conv2d') {
        getVariable(layer, 'weights').dispose();
      } else if (layer.type === 'dense') {
        const weights = layer.weights;
        const biases = layer.biases;
        // Dense variables may still be null if forward() never ran, so
        // duck-type before disposing rather than using getVariable().
        if (weights && typeof weights === 'object' && 'dispose' in weights) {
          (weights as tf.Variable).dispose();
        }
        if (biases && typeof biases === 'object' && 'dispose' in biases) {
          (biases as tf.Variable).dispose();
        }
      }
    }
  }

  /**
   * Forward pass over all layers. Lazily creates dense weights/biases the
   * first time data reaches them.
   *
   * NOTE(review): intermediate tensors are not disposed here — callers
   * should run this inside tf.tidy (or dispose the result chain) to avoid
   * leaking tensors.
   */
  forward(x: tf.Tensor4D): tf.Tensor {
    let out: tf.Tensor = x;
    for (let i = 0; i < this.architecture.length; i += 1) {
      const layer = this.architecture[i];
      switch (layer.type) {
        case 'conv2d': {
          const layerWeights = getVariable(layer, 'weights');
          const stride = getNumber(layer, 'stride');
          const padding = getPadding(layer);
          out = tf.conv2d(
            out as tf.Tensor4D,
            layerWeights as tf.Tensor4D,
            stride,
            padding,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          break;
        }
        case 'maxpool': {
          const size = getNumber(layer, 'size');
          const stride = getNumber(layer, 'stride');
          out = tf.maxPool(out as tf.Tensor4D, [size, size], [stride, stride], 0);
          break;
        }
        case 'flatten': {
          const flatDim = getFlatDim(out);
          out = out.reshape([-1, flatDim]);
          // Lazy init: a dense layer directly after flatten gets weights
          // sized [flatDim, units] now that flatDim is known.
          const next = this.architecture[i + 1];
          if (next?.type === 'dense' && next.weights === null) {
            const units = getNumber(next, 'units');
            next.weights = tf.variable(
              tf.randomUniform(
                [flatDim, units],
                -Math.sqrt(1 / flatDim),
                Math.sqrt(1 / flatDim),
              ),
            );
            next.biases = tf.variable(tf.zeros([units]));
          }
          break;
        }
        case 'dense': {
          const denseWeights = getVariable(layer, 'weights');
          const denseBiases = getVariable(layer, 'biases');
          out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
            denseBiases as tf.Tensor1D,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          // Lazy init for a dense layer chained directly after this one.
          const next = this.architecture[i + 1];
          if (next?.type === 'dense' && next.weights === null) {
            const nextUnits = getNumber(next, 'units');
            const currentUnits = getNumber(layer, 'units');
            next.weights = tf.variable(
              tf.randomUniform(
                [currentUnits, nextUnits],
                -Math.sqrt(1 / currentUnits),
                Math.sqrt(1 / currentUnits),
              ),
            );
            next.biases = tf.variable(tf.zeros([nextUnits]));
          }
          break;
        }
        default:
          // Unknown layer types pass the tensor through unchanged.
          break;
      }
    }
    return out;
  }

  /**
   * Same computation as forward(), but also records a per-layer snapshot
   * (outputs, weights, shapes) for visualization. Each snapshot uses
   * dataSync(), a synchronous read that blocks until the values are
   * available.
   *
   * NOTE(review): like forward(), intermediates are not disposed here —
   * run inside tf.tidy to avoid leaks. Keeping this method's lazy-init
   * logic in sync with forward() is required for correct behavior.
   */
  forwardWithInfo(x: tf.Tensor4D): { output: tf.Tensor; info: RunInfo[] } {
    let out: tf.Tensor = x;
    const info: RunInfo[] = [];
    info.push({
      type: 'input',
      output: out.dataSync(),
      shape: x.shape,
    });
    for (let i = 0; i < this.architecture.length; i += 1) {
      const layer = this.architecture[i];
      switch (layer.type) {
        case 'conv2d': {
          const layerWeights = getVariable(layer, 'weights');
          const stride = getNumber(layer, 'stride');
          const padding = getPadding(layer);
          out = tf.conv2d(
            out as tf.Tensor4D,
            layerWeights as tf.Tensor4D,
            stride,
            padding,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          info.push({
            type: 'conv2d',
            output: out.dataSync(),
            kernels: layerWeights.dataSync(),
            outputShape: out.shape,
            kernelShape: layerWeights.shape,
            stride,
            padding,
            activationType: layer.activationType,
          });
          break;
        }
        case 'maxpool': {
          const size = getNumber(layer, 'size');
          const stride = getNumber(layer, 'stride');
          out = tf.maxPool(out as tf.Tensor4D, [size, size], [stride, stride], 0);
          info.push({
            type: 'maxpool',
            output: out.dataSync(),
            shape: out.shape,
            size,
            stride,
          });
          break;
        }
        case 'flatten': {
          const flatDim = getFlatDim(out);
          out = out.reshape([-1, flatDim]);
          info.push({
            type: 'flatten',
            output: out.dataSync(),
            shape: out.shape,
          });
          // Lazy init, mirroring forward().
          const next = this.architecture[i + 1];
          if (next?.type === 'dense' && next.weights === null) {
            const units = getNumber(next, 'units');
            next.weights = tf.variable(
              tf.randomUniform(
                [flatDim, units],
                -Math.sqrt(1 / flatDim),
                Math.sqrt(1 / flatDim),
              ),
            );
            next.biases = tf.variable(tf.zeros([units]));
          }
          break;
        }
        case 'dense': {
          const denseWeights = getVariable(layer, 'weights');
          const denseBiases = getVariable(layer, 'biases');
          out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
            denseBiases as tf.Tensor1D,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          // Lazy init for a chained dense layer, mirroring forward().
          const next = this.architecture[i + 1];
          if (next?.type === 'dense' && next.weights === null) {
            const nextUnits = getNumber(next, 'units');
            const currentUnits = getNumber(layer, 'units');
            next.weights = tf.variable(
              tf.randomUniform(
                [currentUnits, nextUnits],
                -Math.sqrt(1 / currentUnits),
                Math.sqrt(1 / currentUnits),
              ),
            );
            next.biases = tf.variable(tf.zeros([nextUnits]));
          }
          info.push({
            type: 'dense',
            output: out.dataSync(),
            weights: denseWeights.dataSync(),
            biases: denseBiases.dataSync(),
            outputShape: out.shape,
            weightShape: denseWeights.shape,
            biasShape: denseBiases.shape,
            inputUnits: denseWeights.shape[0],
            outputUnits: getNumber(layer, 'units'),
            outputSize: getNumber(layer, 'units'),
            inputSize: denseWeights.shape[0],
            activationType: layer.activationType,
          });
          break;
        }
        default:
          // Unknown layer types pass the tensor through unchanged.
          break;
      }
    }
    return { output: out, info };
  }
}
export async function train(
data: TrainingData,
model: Cnn,
optimizer: tf.Optimizer,
batchSize: number,
epochs: number,
controller: TrainController,
onBatchEnd: BatchEndCallback | null = null,
): Promise<void> {
const numBatches = Math.floor(data.trainSize / batchSize);
for (let epoch = 0; epoch < epochs; ++epoch) {
for (let b = 0; b < numBatches; ++b) {
if (controller.stopRequested) {
console.log('Training stopped');
return;
}
while (controller.isPaused) {
await tf.nextFrame();
}
const cost = optimizer.minimize(() => {
const batch = data.nextTrainBatch(batchSize);
const xs = batch.xs.reshape([
batchSize,
data.imageSize,
data.imageSize,
data.numInputChannels,
]) as tf.Tensor4D;
const preds = model.forward(xs);
return tf.losses.softmaxCrossEntropy(batch.labels, preds).mean();
}, true);
if (!cost) {
throw new Error('Optimizer did not return a loss tensor');
}
const lossVal = (await cost.data())[0];
cost.dispose();
const sample = data.getTestSample(controller.sampleIndex);
const { output, info } = model.forwardWithInfo(
sample.xs.reshape([
1,
data.imageSize,
data.imageSize,
data.numInputChannels,
]) as tf.Tensor4D,
);
const probs = tf.tidy(() => tf.softmax(output));
info.push({
type: 'output',
output: probs.dataSync(),
shape: probs.shape,
});
if (controller.stopRequested) {
console.log('Training stopped');
probs.dispose();
return;
}
if (onBatchEnd) {
await onBatchEnd(epoch, b, lossVal, info);
}
probs.dispose();
await tf.nextFrame();
}
}
console.log('Training complete');
}