// miniai-1m — train.js
// Author: PyxiLabs (commit 40f71b6, "Create train.js")
const fs = require('fs');
const chalk = require('chalk');
const logSymbols = require('log-symbols');
// Hyper-parameters and artifact paths for the training run.
// Dimensions are small enough to train on CPU in pure JavaScript.
const TRAINING_CONFIG = {
  datasetPath: '3.json',          // JSON array of "<A>...<end><B>...<end>" conversation strings
  epochs: 10,                     // full passes over the training pairs
  learningRate: 0.015,            // SGD step size
  maxSamples: 20000,              // cap on generated (context, next-token) pairs
  contextWindow: 25,              // max tokens of left context per training pair
  embedDim: 32,                   // token embedding dimension
  hiddenDim: 64,                  // MLP hidden-layer width
  numLayers: 2,                   // number of transformer blocks
  weightsFile: 'tinychat_weights.json',      // output path for model weights
  tokenizerFile: 'tinychat_tokenizer.json'   // output path for the vocabulary
};
/**
 * Load the TinyChat dataset: a JSON array of conversation strings of the form
 * "<A>hello<end><B>hi there<end>...". Each conversation is split on "<end>",
 * the <A>/<B> speaker tags are stripped, and empty fragments are dropped.
 *
 * @param {string} filename - Path to the dataset JSON file.
 * @returns {string[]} Flat list of non-empty conversation turns.
 */
function loadTinyChatDataset(filename) {
  console.log(logSymbols.info, chalk.blue(`Loading dataset from ${chalk.bold(filename)}...`));
  const rawData = JSON.parse(fs.readFileSync(filename, 'utf8'));
  const conversations = rawData.flatMap((conv) =>
    conv.split('<end>')
      .map((t) => t.replace(/<[AB]>/g, '').trim())
      .filter((t) => t.length > 0)
  );
  console.log(logSymbols.success, chalk.green(`Loaded ${conversations.length} conversation turns`));
  // Guard: the previous version crashed on conversations[0].substring when the
  // dataset produced no turns at all.
  if (conversations.length > 0) {
    console.log(chalk.dim(`📝 Sample: "${conversations[0].substring(0, 60)}..."`));
  }
  return conversations;
}
/**
 * Whitespace word-level tokenizer with three reserved special tokens.
 * Text is lowercased before lookup; unknown words map to <unk>.
 */
class Tokenizer {
  constructor() {
    // Reserved special tokens occupy the first three ids.
    this.vocab = { "<pad>": 0, "<unk>": 1, "<eos>": 2 };
    this.reverseVocab = {};
    this.vocabSize = 3;
  }
  /**
   * Build the vocabulary from a corpus, keeping words that occur at least
   * `minFreq` times. Also (re)builds the id->word reverse mapping.
   */
  buildVocab(texts, minFreq = 2) {
    // Count lowercase whitespace-delimited word frequencies across the corpus.
    const wordCounts = {};
    for (const text of texts) {
      for (const word of text.toLowerCase().split(/\s+/)) {
        if (word) wordCounts[word] = (wordCounts[word] || 0) + 1;
      }
    }
    // Assign ids (continuing after the specials) to words above the cutoff.
    let nextId = this.vocabSize;
    Object.entries(wordCounts).forEach(([word, count]) => {
      if (count >= minFreq) {
        this.vocab[word] = nextId;
        nextId += 1;
      }
    });
    this.vocabSize = nextId;
    // Invert the vocabulary for decoding.
    this.reverseVocab = {};
    for (const [word, id] of Object.entries(this.vocab)) {
      this.reverseVocab[id] = word;
    }
    console.log(logSymbols.success, chalk.green(`Vocabulary built: ${this.vocabSize} tokens`));
    console.log(chalk.dim(`📖 Sample vocab: ${Object.keys(this.vocab).slice(0, 15).join(", ")}`));
  }
  /** Encode text into token ids; out-of-vocabulary words become <unk>. */
  encode(text) {
    const words = text.toLowerCase().split(/\s+/).filter((w) => w);
    return words.map((w) => this.vocab[w] ?? this.vocab["<unk>"]);
  }
  /** Decode token ids back into a space-joined string. */
  decode(tokens) {
    return tokens.map((t) => this.reverseVocab[t] || "<unk>").join(" ");
  }
}
/**
 * Stateless linear-algebra and activation helpers used by the model.
 * Matrices are plain row-major arrays-of-arrays; vectors are flat arrays.
 *
 * Fix: the original class defined `matmul` TWICE with identical bodies; the
 * second static definition silently overrode the first. The duplicate has
 * been removed.
 */
class MathUtils {
  /** Matrix product: a (m×n) · b (n×p) -> (m×p). */
  static matmul(a, b) {
    const m = a.length, n = a[0].length, p = b[0].length;
    const result = Array(m).fill(0).map(() => Array(p).fill(0));
    for (let i = 0; i < m; i++) {
      for (let j = 0; j < p; j++) {
        for (let k = 0; k < n; k++) {
          result[i][j] += a[i][k] * b[k][j];
        }
      }
    }
    return result;
  }
  /** Row-vector times matrix: vec (n) · mat (n×m) -> (m). */
  static vecMatmul(vec, mat) {
    const m = mat[0].length, n = vec.length;
    const result = Array(m).fill(0);
    for (let j = 0; j < m; j++) {
      for (let i = 0; i < n; i++) {
        result[j] += vec[i] * mat[i][j];
      }
    }
    return result;
  }
  /** Outer product: vecA (m) ⊗ vecB (n) -> (m×n). */
  static outerProduct(vecA, vecB) {
    const result = Array(vecA.length).fill(0).map(() => Array(vecB.length).fill(0));
    for (let i = 0; i < vecA.length; i++) {
      for (let j = 0; j < vecB.length; j++) {
        result[i][j] = vecA[i] * vecB[j];
      }
    }
    return result;
  }
  /** Matrix transpose; returns [] for empty/degenerate input. */
  static transpose(matrix) {
    if (!matrix || !matrix[0]) return [];
    const rows = matrix.length;
    const cols = matrix[0].length;
    const result = Array(cols).fill(0).map(() => Array(rows).fill(0));
    for (let i = 0; i < rows; i++) {
      for (let j = 0; j < cols; j++) {
        result[j][i] = matrix[i][j];
      }
    }
    return result;
  }
  /** Element-wise vector addition. */
  static add(a, b) { return a.map((val, i) => val + b[i]); }
  /** Element-wise vector subtraction. */
  static subtract(a, b) { return a.map((val, i) => val - b[i]); }
  /** Scalar multiply. */
  static scale(vec, s) { return vec.map(v => v * s); }
  /** ReLU activation. */
  static relu(x) { return x.map(v => Math.max(0, v)); }
  /** Derivative of ReLU (0 at exactly 0). */
  static reluDerivative(x) { return x.map(v => v > 0 ? 1 : 0); }
  /** Numerically stable softmax (max-subtraction trick). */
  static softmax(logits) {
    const maxLogit = Math.max(...logits);
    const exp = logits.map(x => Math.exp(x - maxLogit));
    const sum = exp.reduce((a, b) => a + b, 0);
    return exp.map(x => x / sum);
  }
  /** Layer normalization (no learned gain/bias), eps = 1e-5. */
  static layerNorm(x) {
    const mean = x.reduce((a, b) => a + b, 0) / x.length;
    const variance = x.reduce((a, b) => a + (b - mean) ** 2, 0) / x.length;
    const std = Math.sqrt(variance + 1e-5);
    return x.map(val => (val - mean) / std);
  }
  /** Cross-entropy loss for a single target index, clamped to avoid -Inf. */
  static crossEntropy(probs, targetIdx) {
    return -Math.log(Math.max(probs[targetIdx], 1e-10));
  }
}
/**
 * Minimal transformer-style language model implemented with plain arrays.
 * Holds a token embedding table, `numLayers` blocks (a projection-based
 * "attention" sub-block plus a 2-layer ReLU MLP with residual connections and
 * layer norm), and a linear output head over the vocabulary.
 *
 * NOTE(review): this is a heavily simplified model — see forward()/backward()
 * for the shortcuts taken in this implementation.
 */
class MiniTransformer {
  constructor(vocabSize, embedDim, hiddenDim, numLayers) {
    this.vocabSize = vocabSize;
    this.embedDim = embedDim;
    this.hiddenDim = hiddenDim;
    this.numLayers = numLayers;
    // Token embedding table: one row of length embedDim per vocab id.
    this.embedding = this.randomMatrix(vocabSize, embedDim);
    // Per-layer parameters: attention projections plus MLP weights/biases.
    this.layers = Array(numLayers).fill(0).map(() => ({
      attention: {
        wq: this.randomMatrix(embedDim, embedDim),
        wk: this.randomMatrix(embedDim, embedDim),
        wv: this.randomMatrix(embedDim, embedDim),
        wo: this.randomMatrix(embedDim, embedDim)
      },
      mlp: {
        w1: this.randomMatrix(embedDim, hiddenDim),
        b1: Array(hiddenDim).fill(0),
        w2: this.randomMatrix(hiddenDim, embedDim),
        b2: Array(embedDim).fill(0)
      }
    }));
    // Final linear projection from the hidden state to vocabulary logits.
    this.outputWeights = this.randomMatrix(embedDim, vocabSize);
  }
  // He-style initialization: uniform in ±sqrt(2 / rows).
  randomMatrix(rows, cols) {
    const scale = Math.sqrt(2.0 / rows);
    return Array(rows).fill(0).map(() =>
      Array(cols).fill(0).map(() => (Math.random() - 0.5) * 2 * scale)
    );
  }
  /**
   * Forward pass. Returns softmax probabilities over the vocabulary for the
   * token following `tokens`, caching intermediates for backward().
   *
   * NOTE(review): only the LAST token's embedding is actually transformed;
   * q and k are computed but never used afterwards (there is no softmax
   * attention over the sequence), so the "attention" sub-block reduces to
   * v·wo applied to the last position.
   */
  forward(tokens) {
    this.cache = { tokens, layers: [] };
    // Look up the embedding row for each token id.
    let x_sequence = tokens.map(t => this.embedding[t]);
    for (const layer of this.layers) {
      const layerCache = {};
      const last_x = x_sequence[x_sequence.length - 1];
      const q = MathUtils.vecMatmul(last_x, layer.attention.wq);
      const k = MathUtils.vecMatmul(last_x, layer.attention.wk);
      const v = MathUtils.vecMatmul(last_x, layer.attention.wv);
      // "Attention" output: output projection of v (q and k are unused).
      const attn_out = MathUtils.vecMatmul(v, layer.attention.wo);
      // Residual + layer norm.
      let x = MathUtils.add(last_x, attn_out);
      x = MathUtils.layerNorm(x);
      layerCache.postAttn = [...x];
      // Two-layer MLP with ReLU; pre/post activation cached for backprop.
      const mlp_hidden = MathUtils.add(MathUtils.vecMatmul(x, layer.mlp.w1), layer.mlp.b1);
      layerCache.preRelu = [...mlp_hidden];
      const mlp_activated = MathUtils.relu(mlp_hidden);
      layerCache.postRelu = [...mlp_activated];
      const mlp_out = MathUtils.add(MathUtils.vecMatmul(mlp_activated, layer.mlp.w2), layer.mlp.b2);
      // Second residual + layer norm; write the result back in place.
      x = MathUtils.add(x, mlp_out);
      x = MathUtils.layerNorm(x);
      x_sequence[x_sequence.length - 1] = x;
      this.cache.layers.push(layerCache);
    }
    const finalHidden = x_sequence[x_sequence.length - 1];
    const logits = MathUtils.vecMatmul(finalHidden, this.outputWeights);
    const probs = MathUtils.softmax(logits);
    this.cache.finalHidden = finalHidden;
    this.cache.probs = probs;
    return probs;
  }
  /**
   * SGD backward pass for the sample most recently run through forward();
   * applies parameter updates in place with learning rate `lr`.
   *
   * NOTE(review): several gradient paths are skipped — the attention weights
   * (wq/wk/wv/wo) receive no updates, the layer norms and residual adds are
   * treated as identity during backprop, and only the LAST input token's
   * embedding row is updated.
   */
  backward(targetIdx, lr) {
    // dL/dlogits for softmax + cross-entropy: probs - one_hot(target).
    let dLogits = [...this.cache.probs];
    dLogits[targetIdx] -= 1;
    const outputWeightsT = MathUtils.transpose(this.outputWeights);
    const dFinalHidden = MathUtils.vecMatmul(dLogits, outputWeightsT);
    const dOutputWeights = MathUtils.outerProduct(this.cache.finalHidden, dLogits);
    // Update the output head.
    for (let i = 0; i < this.embedDim; i++) {
      for (let j = 0; j < this.vocabSize; j++) {
        this.outputWeights[i][j] -= lr * dOutputWeights[i][j];
      }
    }
    let dCurrent = dFinalHidden;
    // Walk the layers in reverse, updating each MLP.
    for (let l = this.numLayers - 1; l >= 0; l--) {
      const layer = this.layers[l];
      const cache = this.cache.layers[l];
      const dMLP_out = dCurrent;
      const w2_T = MathUtils.transpose(layer.mlp.w2);
      const dHidden_activated = MathUtils.vecMatmul(dMLP_out, w2_T);
      const dW2 = MathUtils.outerProduct(cache.postRelu, dMLP_out);
      const dB2 = dMLP_out;
      // ReLU gradient: pass through only where the pre-activation was positive.
      const dHidden_preRelu = dHidden_activated.map((g, i) => g * (cache.preRelu[i] > 0 ? 1 : 0));
      const w1_T = MathUtils.transpose(layer.mlp.w1);
      dCurrent = MathUtils.vecMatmul(dHidden_preRelu, w1_T);
      const dW1 = MathUtils.outerProduct(cache.postAttn, dHidden_preRelu);
      const dB1 = dHidden_preRelu;
      // Apply the SGD updates to the MLP parameters.
      for(let i=0; i<this.hiddenDim; i++) for(let j=0; j<this.embedDim; j++) layer.mlp.w2[i][j] -= lr * dW2[i][j];
      for(let j=0; j<this.embedDim; j++) layer.mlp.b2[j] -= lr * dB2[j];
      for(let i=0; i<this.embedDim; i++) for(let j=0; j<this.hiddenDim; j++) layer.mlp.w1[i][j] -= lr * dW1[i][j];
      for(let j=0; j<this.hiddenDim; j++) layer.mlp.b1[j] -= lr * dB1[j];
    }
    // Only the last token's embedding row gets a gradient update.
    const dEmbedding = dCurrent;
    const lastTokenId = this.cache.tokens[this.cache.tokens.length - 1];
    for (let i = 0; i < this.embedDim; i++) {
      this.embedding[lastTokenId][i] -= lr * dEmbedding[i];
    }
  }
  /** Serialize all parameters and dimensions to a JSON file. */
  saveWeights(filename) {
    const weights = {
      vocabSize: this.vocabSize, embedDim: this.embedDim, hiddenDim: this.hiddenDim,
      numLayers: this.numLayers, embedding: this.embedding, layers: this.layers,
      outputWeights: this.outputWeights
    };
    fs.writeFileSync(filename, JSON.stringify(weights));
    console.log(logSymbols.success, chalk.green(`Weights saved to ${chalk.bold(filename)}`));
  }
}
/**
 * Full training pipeline: load the dataset, build the vocabulary, create
 * (context, next-token) training pairs, run per-sample SGD for the configured
 * number of epochs, persist the weights and tokenizer, and print a few greedy
 * generation samples.
 *
 * @param {object} config - Hyper-parameters and paths; see TRAINING_CONFIG.
 */
function train(config) {
  console.log("\n" + chalk.bold.yellow("🚀 Starting Transformer Training...") + "\n" + "=".repeat(60));
  const texts = loadTinyChatDataset(config.datasetPath);
  const tokenizer = new Tokenizer();
  tokenizer.buildVocab(texts, 2);
  console.log(logSymbols.info, chalk.blue(`\nCreating training data (up to ${config.maxSamples} samples)...`));
  // Build sliding-window (context, next-token) pairs, capped at maxSamples.
  const trainingPairs = [];
  for (const text of texts) {
    const tokens = tokenizer.encode(text);
    for (let i = 0; i < tokens.length - 1; i++) {
      trainingPairs.push({
        input: tokens.slice(Math.max(0, i - config.contextWindow + 1), i + 1),
        target: tokens[i + 1]
      });
      if (trainingPairs.length >= config.maxSamples) break;
    }
    if (trainingPairs.length >= config.maxSamples) break;
  }
  console.log(logSymbols.success, chalk.green(`Created ${trainingPairs.length} training examples.`));
  const model = new MiniTransformer(tokenizer.vocabSize, config.embedDim, config.hiddenDim, config.numLayers);
  console.log(chalk.cyan.bold("\n🧠 Model Architecture:"));
  console.log(chalk.cyan(`  - Vocab Size: ${tokenizer.vocabSize}`));
  console.log(chalk.cyan(`  - Embedding Dim: ${model.embedDim}`));
  console.log(chalk.cyan(`  - Hidden Dim: ${model.hiddenDim}`));
  console.log(chalk.cyan(`  - Layers: ${model.numLayers}\n`));
  console.log("=".repeat(60) + chalk.bold.yellow("\n📈 Training Progress:\n"));
  for (let epoch = 0; epoch < config.epochs; epoch++) {
    let totalLoss = 0;
    // Unbiased Fisher-Yates shuffle of a copy. The previous
    // `trainingPairs.sort(() => Math.random() - 0.5)` was a biased shuffle
    // and mutated trainingPairs in place.
    const shuffled = [...trainingPairs];
    for (let i = shuffled.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
    }
    for (let i = 0; i < shuffled.length; i++) {
      const { input, target } = shuffled[i];
      const probs = model.forward(input);
      totalLoss += MathUtils.crossEntropy(probs, target);
      model.backward(target, config.learningRate);
      if ((i + 1) % 1000 === 0) {
        process.stdout.write(chalk.dim(`\rEpoch ${epoch + 1}/${config.epochs} | Batch ${i+1}/${shuffled.length}`));
      }
    }
    const avgLoss = totalLoss / trainingPairs.length;
    // clearLine/cursorTo only exist on TTY streams; skip when piped to a file.
    if (process.stdout.isTTY) {
      process.stdout.clearLine(0);
      process.stdout.cursorTo(0);
    }
    console.log(
      chalk.magenta(`Epoch ${(epoch + 1).toString().padStart(2)}/${config.epochs}`) +
      chalk.gray(` | `) +
      chalk.green(`Loss: ${avgLoss.toFixed(4)}`)
    );
  }
  console.log("=".repeat(60));
  // Persist the trained weights and the tokenizer vocabulary.
  model.saveWeights(config.weightsFile);
  fs.writeFileSync(config.tokenizerFile, JSON.stringify({ vocab: tokenizer.vocab, vocabSize: tokenizer.vocabSize }, null, 2));
  console.log(logSymbols.success, chalk.green(`Tokenizer saved to ${chalk.bold(config.tokenizerFile)}`));
  console.log("\n" + "=".repeat(60) + chalk.bold.yellow("\n🎯 TESTING GENERATION\n"));
  // Greedy decoding: repeatedly pick the argmax token until <eos> or 10 steps.
  const testPrompts = ["What do you", "I like to", "Do you have"];
  testPrompts.forEach(prompt => {
    let tokens = tokenizer.encode(prompt);
    let generated = [];
    for (let i = 0; i < 10; i++) {
      const probs = model.forward(tokens.slice(-config.contextWindow));
      const nextToken = probs.indexOf(Math.max(...probs));
      if (nextToken === tokenizer.vocab["<eos>"]) break;
      generated.push(nextToken);
      tokens.push(nextToken);
    }
    console.log(chalk.blue(`Prompt: "${prompt}"`));
    console.log(chalk.green(`Output: "${tokenizer.decode([...tokenizer.encode(prompt), ...generated])}"\n`));
  });
  console.log("=".repeat(60) + "\n" + logSymbols.success, chalk.bold.bgGreen.black(" Training complete! "));
}
// Entry point: only start training when the dataset file actually exists;
// otherwise report the problem and exit with a failure code.
if (fs.existsSync(TRAINING_CONFIG.datasetPath)) {
  train(TRAINING_CONFIG);
} else {
  console.error(logSymbols.error, chalk.red(`Error: Dataset file not found at ${chalk.bold(TRAINING_CONFIG.datasetPath)}!`));
  process.exit(1);
}