// cnn_visualizer/src/mnist.d.ts
// Author: joel-woodfield — commit 1f2f6bf: "Add refactored version of cnn visualizer"
import * as tf from "@tensorflow/tfjs";
/**
 * A pair of 2-D tensors handed out by the batch accessors of `MnistData`
 * (`nextTrainBatch`, `nextTestBatch` and `nextBatch` all return this type).
 *
 * NOTE(review): the declaration only fixes the rank (2-D), not the shape.
 * Presumably `xs` is `[batchSize, imageSize]` and `labels` is
 * `[batchSize, numClasses]` (one-hot) — confirm against the implementation.
 */
export interface BatchData {
  /** Input tensor (`xs`) for the batch. */
  xs: tf.Tensor2D;

  /** Label tensor for the batch. */
  labels: tf.Tensor2D;
}
/**
 * Result type of `MnistData.getTestSample(index)`.
 *
 * Structurally the same as `BatchData` (two 2-D tensors), but declared as a
 * separate named interface so it can evolve independently.
 *
 * NOTE(review): presumably this holds a single test example, i.e. a leading
 * dimension of 1 on both tensors — confirm against the implementation.
 */
export interface TestSample {
  /** Input tensor (`xs`) for the sample. */
  xs: tf.Tensor2D;

  /** Label tensor for the sample. */
  labels: tf.Tensor2D;
}
/**
 * Ambient type declaration for the MNIST data loader used by the CNN
 * visualizer. Only the shape of the class is described here; the runtime
 * implementation is provided elsewhere.
 *
 * NOTE(review): the data buffers are declared non-optional, but they are
 * presumably only populated once `load()` has resolved — confirm before
 * reading them earlier.
 */
export declare class MnistData {
  // Cursors used when drawing shuffled train/test examples
  // (presumably positions into `trainIndices` / `testIndices`; confirm).
  shuffledTrainIndex: number;
  shuffledTestIndex: number;

  // Dataset dimensions and split sizes.
  numClasses: number;
  numInputChannels: number;
  trainSize: number;
  testSize: number;
  imageSize: number;

  // Raw buffers for the whole dataset (images as floats, labels as bytes).
  datasetImages: Float32Array;
  datasetLabels: Uint8Array;

  // Index arrays for the train and test splits
  // (presumably shuffled permutations; confirm).
  trainIndices: Uint32Array;
  testIndices: Uint32Array;

  // Per-split buffers, typed to match `datasetImages` / `datasetLabels`.
  trainImages: Float32Array;
  testImages: Float32Array;
  trainLabels: Uint8Array;
  testLabels: Uint8Array;

  /**
   * Populates the dataset. The returned promise settles when loading is
   * finished and carries no value.
   */
  load(): Promise<void>;

  /**
   * Returns a batch of `batchSize` examples from the training split.
   * NOTE(review): presumably advances `shuffledTrainIndex`; confirm.
   */
  nextTrainBatch(batchSize: number): BatchData;

  /**
   * Returns a batch of `batchSize` examples from the test split.
   * NOTE(review): presumably advances `shuffledTestIndex`; confirm.
   */
  nextTestBatch(batchSize: number): BatchData;

  /**
   * Builds a batch of `batchSize` examples from an explicit buffer pair.
   *
   * @param batchSize - number of examples to place in the batch.
   * @param data - a `[Float32Array, Uint8Array]` tuple; presumably the
   *   image buffer and the label buffer of one split (confirm).
   * @param index - callback returning a record index; presumably invoked
   *   once per example to choose which record to copy (confirm).
   * @returns the assembled batch.
   */
  nextBatch(
    batchSize: number,
    data: [Float32Array, Uint8Array],
    index: () => number,
  ): BatchData;

  /** Returns the test example at position `index`. */
  getTestSample(index: number): TestSample;
}