Spaces:
Running
Running
| import * as tf from '@tensorflow/tfjs'; | |
/** Scalar or tensor values that may appear as attributes on a layer config. */
type LayerValue = string | number | tf.Variable | null | undefined;

/**
 * One parsed layer from the textual architecture description, e.g.
 * { type: 'conv2d', kernel: 3, filters: 8 }. Weight/bias Variables are
 * attached onto the same object once they have been created.
 */
interface LayerConfig {
  type: string;
  [key: string]: LayerValue;
}

/** Free-form per-layer diagnostics emitted by Cnn.forwardWithInfo / train. */
export interface RunInfo {
  [key: string]: unknown;
}

/** Mutable flags a caller (presumably a UI) uses to pause/stop training and
 * select which test sample is visualized each batch. */
export interface TrainController {
  isPaused: boolean;
  stopRequested: boolean;
  sampleIndex: number;
}

/** A mini-batch of training inputs plus one-hot labels. */
interface BatchData {
  xs: tf.Tensor;
  labels: tf.Tensor2D;
}

/** A single test example (input tensor only). */
interface TestSample {
  xs: tf.Tensor;
}

/**
 * Data source consumed by train(): dataset geometry plus batch and
 * test-sample accessors. Ownership of the returned tensors is assumed to
 * stay with the implementation — NOTE(review): confirm against callers.
 */
interface TrainingData {
  trainSize: number;
  imageSize: number;
  numInputChannels: number;
  nextTrainBatch(batchSize: number): BatchData;
  getTestSample(index: number): TestSample;
}

/**
 * Raw string-valued optimizer settings — presumably read straight from form
 * inputs; numeric parsing happens elsewhere.
 */
export interface OptimizerParams {
  learningRate: string;
  batchSize: string;
  epochs: string;
  // sgd only
  momentum?: string;
  // adam only
  beta1?: string;
  beta2?: string;
  epsilon?: string;
}

/** Invoked after each batch with the scalar loss and per-layer run info. */
type BatchEndCallback = (
  epoch: number,
  batch: number,
  loss: number,
  info: RunInfo[],
) => void | Promise<void>;
| function parseValue(raw: string): string | number { | |
| if (raw.trim() === '') { | |
| return raw; | |
| } | |
| const num = Number(raw); | |
| if (!Number.isNaN(num)) { | |
| return num; | |
| } | |
| return raw; | |
| } | |
| function parseArchitecture(text: string): LayerConfig[] { | |
| const layers: LayerConfig[] = []; | |
| const matches = text.match(/\[(.*?)\]/gs); | |
| if (!matches) return layers; | |
| for (const block of matches) { | |
| const content = block.slice(1, -1).trim(); | |
| if (content.length === 0) continue; | |
| const tokens = content.split(/\s+/); | |
| if (tokens.length === 0) continue; | |
| const type = tokens[0]; | |
| const layer: LayerConfig = { type }; | |
| for (let i = 1; i < tokens.length; ++i) { | |
| const token = tokens[i]; | |
| const [rawKey, rawValue] = token.split('=', 2); | |
| if (!rawKey || rawValue === undefined) continue; | |
| const key = rawKey === 'activation' ? 'activationType' : rawKey; | |
| layer[key] = parseValue(rawValue); | |
| } | |
| layers.push(layer); | |
| } | |
| return layers; | |
| } | |
| function getNumber(layer: LayerConfig, key: string): number { | |
| const value = layer[key]; | |
| if (typeof value !== 'number') { | |
| throw new Error(`Layer "${layer.type}" is missing numeric "${key}"`); | |
| } | |
| return value; | |
| } | |
| function getVariable(layer: LayerConfig, key: string): tf.Variable { | |
| const value = layer[key]; | |
| if (!value || typeof value !== 'object' || !('dispose' in value)) { | |
| throw new Error(`Layer "${layer.type}" is missing tensor "${key}"`); | |
| } | |
| return value as tf.Variable; | |
| } | |
| function getPadding(layer: LayerConfig): number | 'same' | 'valid' { | |
| const padding = layer.padding; | |
| if (padding === undefined) return 'valid'; | |
| if (typeof padding === 'number') return padding; | |
| if (padding === 'same' || padding === 'valid') return padding; | |
| throw new Error(`Layer "${layer.type}" has invalid padding "${String(padding)}"`); | |
| } | |
| function getFlatDim(out: tf.Tensor): number { | |
| const [, h, w, c] = out.shape; | |
| if ( | |
| typeof h !== 'number' || | |
| typeof w !== 'number' || | |
| typeof c !== 'number' | |
| ) { | |
| throw new Error('Cannot flatten tensor with unknown shape'); | |
| } | |
| return h * w * c; | |
| } | |
/**
 * A small convolutional network built directly on low-level tfjs ops
 * (tf.conv2d / tf.maxPool / tf.matMul) instead of tf.layers, driven by a
 * textual architecture description (see parseArchitecture).
 *
 * Conv kernels are created eagerly in the constructor. Dense weights are
 * created LAZILY during the first forward pass: their input width is only
 * known once the preceding flatten (or dense) output size is known. Until
 * then a dense layer's weights/biases are null.
 */
export class Cnn {
  // Parsed layer configs; weight Variables are attached onto these objects.
  architecture: LayerConfig[];
  // Channel count of the network input (e.g. 1 for grayscale images).
  inChannels: number;
  // Eagerly created conv kernels only. Lazily created dense weights are NOT
  // tracked here — they live solely on the layer configs.
  weights: tf.Variable[];

  constructor(architecture: string, inChannels: number) {
    this.architecture = parseArchitecture(architecture);
    this.inChannels = inChannels;
    this.weights = this.initWeights();
  }

  /**
   * Create conv kernels for every conv2d layer and mark dense layers for
   * lazy initialization (weights/biases = null).
   * Kernels are initialized uniform in +/- sqrt(1 / fan_in).
   */
  initWeights(): tf.Variable[] {
    const weights: tf.Variable[] = [];
    // Track the channel count flowing into each conv layer.
    let inChannels = this.inChannels;
    for (const layer of this.architecture) {
      if (layer.type === 'conv2d') {
        const kernel = getNumber(layer, 'kernel');
        const filters = getNumber(layer, 'filters');
        // tfjs conv2d kernel layout: [kh, kw, inChannels, outChannels].
        const shape: [number, number, number, number] = [
          kernel,
          kernel,
          inChannels,
          filters,
        ];
        const layerWeights = tf.variable(
          tf.randomUniform(
            shape,
            -Math.sqrt(1 / (kernel * kernel * inChannels)),
            Math.sqrt(1 / (kernel * kernel * inChannels)),
          ),
        );
        weights.push(layerWeights);
        layer.weights = layerWeights;
        inChannels = filters;
      } else if (layer.type === 'dense') {
        // Sentinel for "not yet initialized"; forward() fills these in once
        // the incoming feature width is known.
        layer.weights = null;
        layer.biases = null;
      }
    }
    return weights;
  }

  /** Free every weight/bias Variable this model created. */
  dispose(): void {
    for (const layer of this.architecture) {
      if (layer.type === 'conv2d') {
        getVariable(layer, 'weights').dispose();
      } else if (layer.type === 'dense') {
        // Dense weights may still be null if forward() was never called,
        // so dispose defensively rather than via getVariable (which throws).
        const weights = layer.weights;
        const biases = layer.biases;
        if (weights && typeof weights === 'object' && 'dispose' in weights) {
          (weights as tf.Variable).dispose();
        }
        if (biases && typeof biases === 'object' && 'dispose' in biases) {
          (biases as tf.Variable).dispose();
        }
      }
    }
  }

  /**
   * Run the network on a batch, returning the final logits.
   * Side effect: lazily creates dense weights/biases on first use. The layer
   * dispatch here is intentionally kept in sync with forwardWithInfo below.
   * NOTE(review): intermediate tensors are not tidied here — callers are
   * expected to wrap this in tf.tidy (optimizer.minimize does).
   */
  forward(x: tf.Tensor4D): tf.Tensor {
    let out: tf.Tensor = x;
    for (let i = 0; i < this.architecture.length; i += 1) {
      const layer = this.architecture[i];
      switch (layer.type) {
        case 'conv2d': {
          const layerWeights = getVariable(layer, 'weights');
          const stride = getNumber(layer, 'stride');
          const padding = getPadding(layer);
          out = tf.conv2d(
            out as tf.Tensor4D,
            layerWeights as tf.Tensor4D,
            stride,
            padding,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          break;
        }
        case 'maxpool': {
          const size = getNumber(layer, 'size');
          const stride = getNumber(layer, 'stride');
          out = tf.maxPool(out as tf.Tensor4D, [size, size], [stride, stride], 0);
          break;
        }
        case 'flatten': {
          const flatDim = getFlatDim(out);
          out = out.reshape([-1, flatDim]);
          // Lazy init: a dense layer directly after flatten gets its input
          // width from the flattened feature count.
          const next = this.architecture[i + 1];
          if (next?.type === 'dense' && next.weights === null) {
            const units = getNumber(next, 'units');
            next.weights = tf.variable(
              tf.randomUniform(
                [flatDim, units],
                -Math.sqrt(1 / flatDim),
                Math.sqrt(1 / flatDim),
              ),
            );
            next.biases = tf.variable(tf.zeros([units]));
          }
          break;
        }
        case 'dense': {
          const denseWeights = getVariable(layer, 'weights');
          const denseBiases = getVariable(layer, 'biases');
          out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
            denseBiases as tf.Tensor1D,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          // Lazy init for a dense layer that follows this dense layer.
          const next = this.architecture[i + 1];
          if (next?.type === 'dense' && next.weights === null) {
            const nextUnits = getNumber(next, 'units');
            const currentUnits = getNumber(layer, 'units');
            next.weights = tf.variable(
              tf.randomUniform(
                [currentUnits, nextUnits],
                -Math.sqrt(1 / currentUnits),
                Math.sqrt(1 / currentUnits),
              ),
            );
            next.biases = tf.variable(tf.zeros([nextUnits]));
          }
          break;
        }
        default:
          // Unknown layer types are silently skipped.
          break;
      }
    }
    return out;
  }

  /**
   * Same forward pass as forward(), but additionally records per-layer
   * activations, weights and shapes into a RunInfo[] for visualization.
   * dataSync() copies every recorded tensor to the CPU, so this is far
   * slower than forward() and intended for single-sample inspection only.
   * NOTE(review): intermediate tensors created here are not disposed;
   * callers own only the returned `output`.
   */
  forwardWithInfo(x: tf.Tensor4D): { output: tf.Tensor; info: RunInfo[] } {
    let out: tf.Tensor = x;
    const info: RunInfo[] = [];
    info.push({
      type: 'input',
      output: out.dataSync(),
      shape: x.shape,
    });
    for (let i = 0; i < this.architecture.length; i += 1) {
      const layer = this.architecture[i];
      switch (layer.type) {
        case 'conv2d': {
          const layerWeights = getVariable(layer, 'weights');
          const stride = getNumber(layer, 'stride');
          const padding = getPadding(layer);
          out = tf.conv2d(
            out as tf.Tensor4D,
            layerWeights as tf.Tensor4D,
            stride,
            padding,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          info.push({
            type: 'conv2d',
            output: out.dataSync(),
            kernels: layerWeights.dataSync(),
            outputShape: out.shape,
            kernelShape: layerWeights.shape,
            stride,
            padding,
            activationType: layer.activationType,
          });
          break;
        }
        case 'maxpool': {
          const size = getNumber(layer, 'size');
          const stride = getNumber(layer, 'stride');
          out = tf.maxPool(out as tf.Tensor4D, [size, size], [stride, stride], 0);
          info.push({
            type: 'maxpool',
            output: out.dataSync(),
            shape: out.shape,
            size,
            stride,
          });
          break;
        }
        case 'flatten': {
          const flatDim = getFlatDim(out);
          out = out.reshape([-1, flatDim]);
          info.push({
            type: 'flatten',
            output: out.dataSync(),
            shape: out.shape,
          });
          // Lazy dense init — identical to forward().
          const next = this.architecture[i + 1];
          if (next?.type === 'dense' && next.weights === null) {
            const units = getNumber(next, 'units');
            next.weights = tf.variable(
              tf.randomUniform(
                [flatDim, units],
                -Math.sqrt(1 / flatDim),
                Math.sqrt(1 / flatDim),
              ),
            );
            next.biases = tf.variable(tf.zeros([units]));
          }
          break;
        }
        case 'dense': {
          const denseWeights = getVariable(layer, 'weights');
          const denseBiases = getVariable(layer, 'biases');
          out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
            denseBiases as tf.Tensor1D,
          );
          if (layer.activationType === 'relu') {
            out = out.relu();
          }
          // Lazy dense init — identical to forward().
          const next = this.architecture[i + 1];
          if (next?.type === 'dense' && next.weights === null) {
            const nextUnits = getNumber(next, 'units');
            const currentUnits = getNumber(layer, 'units');
            next.weights = tf.variable(
              tf.randomUniform(
                [currentUnits, nextUnits],
                -Math.sqrt(1 / currentUnits),
                Math.sqrt(1 / currentUnits),
              ),
            );
            next.biases = tf.variable(tf.zeros([nextUnits]));
          }
          info.push({
            type: 'dense',
            output: out.dataSync(),
            weights: denseWeights.dataSync(),
            biases: denseBiases.dataSync(),
            outputShape: out.shape,
            weightShape: denseWeights.shape,
            biasShape: denseBiases.shape,
            inputUnits: denseWeights.shape[0],
            outputUnits: getNumber(layer, 'units'),
            outputSize: getNumber(layer, 'units'),
            inputSize: denseWeights.shape[0],
            activationType: layer.activationType,
          });
          break;
        }
        default:
          // Unknown layer types are silently skipped.
          break;
      }
    }
    return { output: out, info };
  }
}
| export async function train( | |
| data: TrainingData, | |
| model: Cnn, | |
| optimizer: tf.Optimizer, | |
| batchSize: number, | |
| epochs: number, | |
| controller: TrainController, | |
| onBatchEnd: BatchEndCallback | null = null, | |
| ): Promise<void> { | |
| const numBatches = Math.floor(data.trainSize / batchSize); | |
| for (let epoch = 0; epoch < epochs; ++epoch) { | |
| for (let b = 0; b < numBatches; ++b) { | |
| if (controller.stopRequested) { | |
| console.log('Training stopped'); | |
| return; | |
| } | |
| while (controller.isPaused) { | |
| await tf.nextFrame(); | |
| } | |
| const cost = optimizer.minimize(() => { | |
| const batch = data.nextTrainBatch(batchSize); | |
| const xs = batch.xs.reshape([ | |
| batchSize, | |
| data.imageSize, | |
| data.imageSize, | |
| data.numInputChannels, | |
| ]) as tf.Tensor4D; | |
| const preds = model.forward(xs); | |
| return tf.losses.softmaxCrossEntropy(batch.labels, preds).mean(); | |
| }, true); | |
| if (!cost) { | |
| throw new Error('Optimizer did not return a loss tensor'); | |
| } | |
| const lossVal = (await cost.data())[0]; | |
| cost.dispose(); | |
| const sample = data.getTestSample(controller.sampleIndex); | |
| const { output, info } = model.forwardWithInfo( | |
| sample.xs.reshape([ | |
| 1, | |
| data.imageSize, | |
| data.imageSize, | |
| data.numInputChannels, | |
| ]) as tf.Tensor4D, | |
| ); | |
| const probs = tf.tidy(() => tf.softmax(output)); | |
| info.push({ | |
| type: 'output', | |
| output: probs.dataSync(), | |
| shape: probs.shape, | |
| }); | |
| if (controller.stopRequested) { | |
| console.log('Training stopped'); | |
| probs.dispose(); | |
| return; | |
| } | |
| if (onBatchEnd) { | |
| await onBatchEnd(epoch, b, lossVal, info); | |
| } | |
| probs.dispose(); | |
| await tf.nextFrame(); | |
| } | |
| } | |
| console.log('Training complete'); | |
| } | |