import * as tf from '@tensorflow/tfjs';

/** Any value that can be stored on a parsed layer (parsed options or lazily created variables). */
type LayerValue = string | number | tf.Variable | null | undefined;

/** One `[type key=value ...]` block of the architecture string, plus its runtime variables. */
interface LayerConfig {
  type: string;
  [key: string]: LayerValue;
}

/** Per-layer snapshot produced by {@link Cnn.forwardWithInfo}; keys depend on the layer type. */
export interface RunInfo {
  [key: string]: unknown;
}

/** Mutable flags the caller flips to pause/stop training and to pick the visualised test sample. */
export interface TrainController {
  isPaused: boolean;
  stopRequested: boolean;
  sampleIndex: number;
}

interface BatchData {
  xs: tf.Tensor;
  labels: tf.Tensor2D;
}

interface TestSample {
  xs: tf.Tensor;
}

interface TrainingData {
  trainSize: number;
  imageSize: number;
  numInputChannels: number;
  nextTrainBatch(batchSize: number): BatchData;
  getTestSample(index: number): TestSample;
}

export interface OptimizerParams {
  learningRate: string;
  batchSize: string;
  epochs: string;
  // sgd only
  momentum?: string;
  // adam only
  beta1?: string;
  beta2?: string;
  epsilon?: string;
}

type BatchEndCallback = (
  epoch: number,
  batch: number,
  loss: number,
  info: RunInfo[],
) => void | Promise<void>;

/**
 * Converts a raw option value to a number when it parses as one.
 * Blank strings are returned unchanged because `Number('')` would yield 0.
 */
function parseValue(raw: string): string | number {
  if (raw.trim() === '') {
    return raw;
  }
  const num = Number(raw);
  return Number.isNaN(num) ? raw : num;
}

/**
 * Parses an architecture description such as
 * `[conv2d kernel=3 filters=8 stride=1] [flatten] [dense units=10]`
 * into layer configs. The `activation` key is stored as `activationType`.
 * Tokens without a `key=value` form are ignored.
 */
function parseArchitecture(text: string): LayerConfig[] {
  const layers: LayerConfig[] = [];
  const matches = text.match(/\[(.*?)\]/gs);
  if (!matches) return layers;

  for (const block of matches) {
    const content = block.slice(1, -1).trim();
    if (content.length === 0) continue;

    const [type, ...params] = content.split(/\s+/);
    if (!type) continue;

    const layer: LayerConfig = { type };
    for (const token of params) {
      const [rawKey, rawValue] = token.split('=', 2);
      if (!rawKey || rawValue === undefined) continue;
      const key = rawKey === 'activation' ? 'activationType' : rawKey;
      layer[key] = parseValue(rawValue);
    }
    layers.push(layer);
  }
  return layers;
}

/**
 * Reads a required numeric option from a layer.
 * @throws Error when the option is absent or not a number.
 */
function getNumber(layer: LayerConfig, key: string): number {
  const value = layer[key];
  if (typeof value !== 'number') {
    throw new Error(`Layer "${layer.type}" is missing numeric "${key}"`);
  }
  return value;
}

/** Duck-typed check that a layer value is a (disposable) variable rather than a parsed option. */
function isVariable(value: LayerValue): value is tf.Variable {
  return typeof value === 'object' && value !== null && 'dispose' in value;
}

/**
 * Reads a required variable (weights/biases) from a layer.
 * @throws Error when the variable has not been created.
 */
function getVariable(layer: LayerConfig, key: string): tf.Variable {
  const value = layer[key];
  if (!isVariable(value)) {
    throw new Error(`Layer "${layer.type}" is missing tensor "${key}"`);
  }
  return value;
}

/**
 * Resolves a layer's `padding` option; defaults to `'valid'` when omitted.
 * @throws Error for any value other than a number, `'same'` or `'valid'`.
 */
function getPadding(layer: LayerConfig): number | 'same' | 'valid' {
  const padding = layer.padding;
  if (padding === undefined) return 'valid';
  if (typeof padding === 'number') return padding;
  if (padding === 'same' || padding === 'valid') return padding;
  throw new Error(`Layer "${layer.type}" has invalid padding "${String(padding)}"`);
}

/**
 * Returns the product of dimensions 1..3 of `out.shape` (the per-example size).
 * @throws Error when any of those dimensions is missing.
 */
function getFlatDim(out: tf.Tensor): number {
  const [, h, w, c] = out.shape;
  if (
    typeof h !== 'number' ||
    typeof w !== 'number' ||
    typeof c !== 'number'
  ) {
    throw new Error('Cannot flatten tensor with unknown shape');
  }
  return h * w * c;
}

/**
 * Small CNN built from a textual architecture description.
 *
 * Conv kernels are created eagerly in the constructor. Dense weights and biases
 * are created lazily during the first forward pass, when the preceding
 * `flatten`/`dense` layer reveals their input size.
 */
export class Cnn {
  architecture: LayerConfig[];

  inChannels: number;

  /** Conv2d kernels created by {@link initWeights}; lazily created dense variables are not added here. */
  weights: tf.Variable[];

  constructor(architecture: string, inChannels: number) {
    this.architecture = parseArchitecture(architecture);
    this.inChannels = inChannels;
    this.weights = this.initWeights();
  }

  /**
   * Creates a uniformly initialised kernel for every conv2d layer and marks
   * dense layers as pending (weights/biases set to `null`).
   * @returns the conv2d kernels, in layer order.
   */
  initWeights(): tf.Variable[] {
    const weights: tf.Variable[] = [];
    let inChannels = this.inChannels;

    for (const layer of this.architecture) {
      if (layer.type === 'conv2d') {
        const kernel = getNumber(layer, 'kernel');
        const filters = getNumber(layer, 'filters');
        const shape: [number, number, number, number] = [
          kernel,
          kernel,
          inChannels,
          filters,
        ];
        // Uniform in +/- sqrt(1 / fanIn).
        const limit = Math.sqrt(1 / (kernel * kernel * inChannels));
        const layerWeights = tf.variable(tf.randomUniform(shape, -limit, limit));
        weights.push(layerWeights);
        layer.weights = layerWeights;
        inChannels = filters;
      } else if (layer.type === 'dense') {
        layer.weights = null;
        layer.biases = null;
      }
    }
    return weights;
  }

  /** Disposes every variable (conv kernels, dense weights and biases) that has been created. */
  dispose(): void {
    for (const layer of this.architecture) {
      for (const key of ['weights', 'biases']) {
        const value = layer[key];
        if (isVariable(value)) {
          value.dispose();
        }
      }
    }
  }

  /**
   * Runs the network on `x` and returns the raw output of the last layer.
   *
   * Intermediate tensors are not disposed here; call this inside `tf.tidy`
   * or an optimizer's `minimize` callback.
   */
  forward(x: tf.Tensor4D): tf.Tensor {
    return this.runLayers(x, null);
  }

  /**
   * Runs the network on `x` and records a synchronous snapshot of each
   * layer's output (and parameters) for visualisation.
   *
   * Intermediate tensors are released before returning; the caller owns
   * (and must dispose) the returned `output`.
   */
  forwardWithInfo(x: tf.Tensor4D): { output: tf.Tensor; info: RunInfo[] } {
    const info: RunInfo[] = [
      {
        type: 'input',
        output: x.dataSync(),
        shape: x.shape,
      },
    ];
    // tidy disposes the intermediates but keeps the returned tensor; variables
    // created lazily inside are not affected by tidy.
    const output = tf.tidy(() => this.runLayers(x, info));
    return { output, info };
  }

  /** Creates weights/biases for `next` if it is a dense layer that is still pending. */
  private initDenseIfPending(next: LayerConfig | undefined, inputUnits: number): void {
    if (next?.type !== 'dense' || next.weights !== null) return;
    const units = getNumber(next, 'units');
    const limit = Math.sqrt(1 / inputUnits);
    next.weights = tf.variable(
      tf.randomUniform([inputUnits, units], -limit, limit),
    );
    next.biases = tf.variable(tf.zeros([units]));
  }

  /**
   * Shared layer loop behind `forward` and `forwardWithInfo`.
   * When `info` is non-null a snapshot is appended after each known layer.
   * Unknown layer types are passed through untouched.
   */
  private runLayers(x: tf.Tensor4D, info: RunInfo[] | null): tf.Tensor {
    let out: tf.Tensor = x;

    for (const [i, layer] of this.architecture.entries()) {
      const next = this.architecture[i + 1];

      switch (layer.type) {
        case 'conv2d': {
          const layerWeights = getVariable(layer, 'weights');
          const stride = getNumber(layer, 'stride');
          const padding = getPadding(layer);
          out = tf.conv2d(
            out as tf.Tensor4D,
            layerWeights as tf.Tensor4D,
            stride,
            padding,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          info?.push({
            type: 'conv2d',
            output: out.dataSync(),
            kernels: layerWeights.dataSync(),
            outputShape: out.shape,
            kernelShape: layerWeights.shape,
            stride,
            padding,
            activationType: layer.activationType,
          });
          break;
        }
        case 'maxpool': {
          const size = getNumber(layer, 'size');
          const stride = getNumber(layer, 'stride');
          out = tf.maxPool(out as tf.Tensor4D, [size, size], [stride, stride], 0);
          info?.push({
            type: 'maxpool',
            output: out.dataSync(),
            shape: out.shape,
            size,
            stride,
          });
          break;
        }
        case 'flatten': {
          const flatDim = getFlatDim(out);
          out = out.reshape([-1, flatDim]);
          info?.push({
            type: 'flatten',
            output: out.dataSync(),
            shape: out.shape,
          });
          // The flattened size is only known now, so size the next dense layer here.
          this.initDenseIfPending(next, flatDim);
          break;
        }
        case 'dense': {
          const denseWeights = getVariable(layer, 'weights');
          const denseBiases = getVariable(layer, 'biases');
          const units = getNumber(layer, 'units');
          out = tf
            .matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D)
            .add(denseBiases as tf.Tensor1D);
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          this.initDenseIfPending(next, units);
          info?.push({
            type: 'dense',
            output: out.dataSync(),
            weights: denseWeights.dataSync(),
            biases: denseBiases.dataSync(),
            outputShape: out.shape,
            weightShape: denseWeights.shape,
            biasShape: denseBiases.shape,
            inputUnits: denseWeights.shape[0],
            outputUnits: units,
            outputSize: units,
            inputSize: denseWeights.shape[0],
            activationType: layer.activationType,
          });
          break;
        }
        default:
          break;
      }
    }
    return out;
  }
}

/**
 * Trains `model` with softmax cross-entropy for `epochs` passes of
 * `floor(trainSize / batchSize)` batches.
 *
 * After every batch the test sample at `controller.sampleIndex` is run
 * through the model and the per-layer info (plus a final softmax `output`
 * entry) is handed to `onBatchEnd`. Training yields to the browser between
 * batches, waits while `controller.isPaused`, and returns early once
 * `controller.stopRequested` is set.
 *
 * @throws Error when `batchSize` is not a positive integer or the optimizer
 *   returns no loss tensor.
 */
export async function train(
  data: TrainingData,
  model: Cnn,
  optimizer: tf.Optimizer,
  batchSize: number,
  epochs: number,
  controller: TrainController,
  onBatchEnd: BatchEndCallback | null = null,
): Promise<void> {
  if (!Number.isInteger(batchSize) || batchSize <= 0) {
    // A zero/negative/NaN batch size would give an infinite or NaN batch count.
    throw new Error(`batchSize must be a positive integer, got ${batchSize}`);
  }
  const numBatches = Math.floor(data.trainSize / batchSize);

  if (epochs > 0 && numBatches > 0) {
    // Warm-up pass: create the lazily initialised dense variables before the
    // first optimizer step, so that step's trainable-variable list (collected
    // before the loss function runs) already includes them.
    tf.tidy(() => {
      model.forward(
        tf.zeros([
          1,
          data.imageSize,
          data.imageSize,
          data.numInputChannels,
        ]) as tf.Tensor4D,
      );
    });
  }

  for (let epoch = 0; epoch < epochs; ++epoch) {
    for (let b = 0; b < numBatches; ++b) {
      // Leave the pause loop as soon as a stop is requested; otherwise a
      // stop issued while paused would never be seen.
      while (controller.isPaused && !controller.stopRequested) {
        await tf.nextFrame();
      }
      if (controller.stopRequested) {
        console.log('Training stopped');
        return;
      }

      const cost = optimizer.minimize(() => {
        const batch = data.nextTrainBatch(batchSize);
        const xs = batch.xs.reshape([
          batchSize,
          data.imageSize,
          data.imageSize,
          data.numInputChannels,
        ]) as tf.Tensor4D;
        const preds = model.forward(xs);
        return tf.losses.softmaxCrossEntropy(batch.labels, preds).mean();
      }, true);
      if (!cost) {
        throw new Error('Optimizer did not return a loss tensor');
      }

      let lossVal: number;
      try {
        lossVal = (await cost.data())[0];
      } finally {
        cost.dispose();
      }

      // Fetched outside tidy: the sample tensor belongs to `data`, so it must
      // not be swept up by the scope below.
      const sample = data.getTestSample(controller.sampleIndex);
      let info: RunInfo[] = [];
      tf.tidy(() => {
        const input = sample.xs.reshape([
          1,
          data.imageSize,
          data.imageSize,
          data.numInputChannels,
        ]) as tf.Tensor4D;
        const result = model.forwardWithInfo(input);
        const probs = tf.softmax(result.output);
        result.info.push({
          type: 'output',
          output: probs.dataSync(),
          shape: probs.shape,
        });
        info = result.info;
      });

      if (controller.stopRequested) {
        console.log('Training stopped');
        return;
      }
      if (onBatchEnd) {
        await onBatchEnd(epoch, b, lossVal, info);
      }
      await tf.nextFrame();
    }
  }
  console.log('Training complete');
}