| const fs = require('fs'); | |
| const chalk = require('chalk'); | |
| const logSymbols = require('log-symbols'); | |
/**
 * Hyper-parameters and file locations for one training run.
 * Read by `train()` and by the dataset-existence check at startup.
 */
const TRAINING_CONFIG = {
  // JSON file holding the raw conversations.
  datasetPath: '3.json',
  // Number of full passes over the training pairs.
  epochs: 10,
  // Step size used for every per-example update.
  learningRate: 0.015,
  // Upper bound on the number of (context, next-token) pairs that are built.
  maxSamples: 20000,
  // Maximum number of tokens in each input context.
  contextWindow: 25,
  // Width of token embeddings and hidden states.
  embedDim: 32,
  // Width of the MLP hidden layer inside each block.
  hiddenDim: 64,
  // Number of stacked blocks.
  numLayers: 2,
  // Destination of the trained model weights.
  weightsFile: 'tinychat_weights.json',
  // Destination of the vocabulary.
  tokenizerFile: 'tinychat_tokenizer.json',
};
/**
 * Load a TinyChat-style dataset and flatten it into individual conversation turns.
 *
 * The file must contain a JSON array of strings. Each string is one conversation
 * whose turns are terminated by `<end>` and may carry `<A>` / `<B>` speaker tags;
 * the tags are stripped, turns are trimmed, and empty turns are dropped.
 *
 * @param {string} filename - Path of the JSON dataset file.
 * @returns {string[]} Every non-empty turn, in file order (empty for an empty dataset).
 * @throws {TypeError} If the file is not a JSON array of strings.
 */
function loadTinyChatDataset(filename) {
  console.log(logSymbols.info, chalk.blue(`Loading dataset from ${chalk.bold(filename)}...`));
  const rawData = JSON.parse(fs.readFileSync(filename, 'utf8'));
  if (!Array.isArray(rawData)) {
    throw new TypeError(`Dataset ${filename} must be a JSON array of conversation strings`);
  }

  // String.prototype.replace resets lastIndex for global regexes, so one literal can be reused.
  const speakerTag = /<[AB]>/g;
  const conversations = [];
  for (const [index, conv] of rawData.entries()) {
    if (typeof conv !== 'string') {
      throw new TypeError(`Dataset entry ${index} in ${filename} is not a string`);
    }
    for (const rawTurn of conv.split('<end>')) {
      const turn = rawTurn.replace(speakerTag, '').trim();
      if (turn.length > 0) conversations.push(turn);
    }
  }

  console.log(logSymbols.success, chalk.green(`Loaded ${conversations.length} conversation turns`));
  // Only show a sample when there is one; indexing [0] of an empty list would throw.
  if (conversations.length > 0) {
    console.log(chalk.dim(`📝 Sample: "${conversations[0].substring(0, 60)}..."`));
  }
  return conversations;
}
/**
 * Lower-casing whitespace tokenizer with a frequency-filtered vocabulary.
 *
 * Ids 0-2 are reserved for `<pad>`, `<unk>` and `<eos>`. `buildVocab` gives every
 * other word seen at least `minFreq` times the next free id, in first-seen order.
 */
class Tokenizer {
  constructor() {
    // Null-prototype dictionary: words such as "constructor" or "__proto__" become ordinary
    // keys instead of resolving to Object.prototype members (which broke encode()).
    this.vocab = Object.assign(Object.create(null), { "<pad>": 0, "<unk>": 1, "<eos>": 2 });
    this.vocabSize = 3;
    this.reverseVocab = {};
    this.#rebuildReverseVocab();
  }

  /**
   * Count words across `texts` and add every word with at least `minFreq` occurrences.
   * @param {string[]} texts - Raw text lines.
   * @param {number} [minFreq=2] - Minimum occurrence count for a word to get an id.
   */
  buildVocab(texts, minFreq = 2) {
    // Map keeps insertion order for every key and has no prototype keys to collide with.
    const wordCounts = new Map();
    for (const text of texts) {
      for (const word of text.toLowerCase().split(/\s+/)) {
        if (word) wordCounts.set(word, (wordCounts.get(word) ?? 0) + 1);
      }
    }
    let idx = this.vocabSize;
    for (const [word, count] of wordCounts) {
      // Never reassign an existing id (e.g. a literal "<eos>" in the corpus); doing so
      // would move a reserved token and leave a gap in the id range.
      if (count >= minFreq && !Object.hasOwn(this.vocab, word)) {
        this.vocab[word] = idx++;
      }
    }
    this.vocabSize = idx;
    this.#rebuildReverseVocab();
    console.log(logSymbols.success, chalk.green(`Vocabulary built: ${this.vocabSize} tokens`));
    console.log(chalk.dim(`📖 Sample vocab: ${Object.keys(this.vocab).slice(0, 15).join(", ")}`));
  }

  /**
   * @param {string} text - Raw text.
   * @returns {number[]} One id per whitespace-separated word; unknown words map to `<unk>`.
   */
  encode(text) {
    const unk = this.vocab["<unk>"];
    return text
      .toLowerCase()
      .split(/\s+/)
      .filter(Boolean)
      // hasOwn also guards a vocab that was replaced by a plain object (e.g. loaded from JSON).
      .map((w) => (Object.hasOwn(this.vocab, w) ? this.vocab[w] : unk));
  }

  /**
   * @param {number[]} tokens - Token ids.
   * @returns {string} Space-joined words; ids without an entry decode to `<unk>`.
   */
  decode(tokens) {
    return tokens.map((t) => this.reverseVocab[t] ?? "<unk>").join(" ");
  }

  // Keep the id -> word table in sync with `vocab`.
  #rebuildReverseVocab() {
    this.reverseVocab = Object.fromEntries(Object.entries(this.vocab).map(([word, id]) => [id, word]));
  }
}
/**
 * Stateless numeric helpers on plain arrays. Vectors are `number[]`, and matrices are
 * row-major `number[][]`. Every method returns fresh arrays and does not mutate its inputs.
 */
class MathUtils {
  /**
   * Matrix product `a · b` of an (m x n) and an (n x p) matrix.
   * @returns {number[][]} (m x p) matrix.
   */
  static matmul(a, b) {
    const m = a.length;
    const n = a[0].length;
    const p = b[0].length;
    return Array.from({ length: m }, (_, i) => {
      const row = new Array(p);
      for (let j = 0; j < p; j++) {
        let sum = 0;
        for (let k = 0; k < n; k++) sum += a[i][k] * b[k][j];
        row[j] = sum;
      }
      return row;
    });
  }

  /**
   * Row-vector times matrix: `vec` (length n) · `mat` (n x m).
   * @returns {number[]} Vector of length m.
   */
  static vecMatmul(vec, mat) {
    const cols = mat[0].length;
    const result = new Array(cols);
    for (let j = 0; j < cols; j++) {
      let sum = 0;
      for (let i = 0; i < vec.length; i++) sum += vec[i] * mat[i][j];
      result[j] = sum;
    }
    return result;
  }

  /** Outer product: result[i][j] = vecA[i] * vecB[j]. */
  static outerProduct(vecA, vecB) {
    return vecA.map((a) => vecB.map((b) => a * b));
  }

  /** Transpose; returns [] for a missing or empty matrix. Column count comes from row 0. */
  static transpose(matrix) {
    if (!matrix || !matrix[0]) return [];
    return Array.from({ length: matrix[0].length }, (_, j) => matrix.map((row) => row[j]));
  }

  static add(a, b) { return a.map((val, i) => val + b[i]); }
  static subtract(a, b) { return a.map((val, i) => val - b[i]); }
  static scale(vec, s) { return vec.map((v) => v * s); }
  static relu(x) { return x.map((v) => Math.max(0, v)); }
  static reluDerivative(x) { return x.map((v) => (v > 0 ? 1 : 0)); }

  /** Numerically stable softmax (subtracts the maximum logit before exponentiating). */
  static softmax(logits) {
    // reduce instead of Math.max(...logits): spreading a very large array exceeds the
    // engine's argument-count limit and throws a RangeError. Math.max keeps NaN propagation.
    const maxLogit = logits.reduce((max, x) => Math.max(max, x), -Infinity);
    const exp = logits.map((x) => Math.exp(x - maxLogit));
    const sum = exp.reduce((a, b) => a + b, 0);
    return exp.map((x) => x / sum);
  }

  /** Normalise to zero mean / unit variance (epsilon 1e-5 inside the sqrt); no gain or bias. */
  static layerNorm(x) {
    const mean = x.reduce((a, b) => a + b, 0) / x.length;
    const variance = x.reduce((a, b) => a + (b - mean) ** 2, 0) / x.length;
    const std = Math.sqrt(variance + 1e-5);
    return x.map((val) => (val - mean) / std);
  }

  /** Negative log-probability of `targetIdx`, clamped at 1e-10 to avoid -log(0). */
  static crossEntropy(probs, targetIdx) {
    return -Math.log(Math.max(probs[targetIdx], 1e-10));
  }
}
/**
 * A deliberately tiny transformer-style language model on plain arrays.
 *
 * Only the embedding of the LAST input token is propagated through the layers, so
 * the prediction depends on that token alone; earlier tokens in the context do not
 * change the output. Each layer computes
 *   postAttn = LN(input + (input · Wv) · Wo)
 *   output   = LN(postAttn + relu(postAttn · W1 + b1) · W2 + b2)
 * followed by a final projection to vocabulary logits and a softmax. The query/key
 * matrices (`wq`, `wk`) are created and saved with the weights but are not read by
 * the forward pass.
 */
class MiniTransformer {
  /**
   * @param {number} vocabSize - Number of token ids (rows of the embedding table).
   * @param {number} embedDim - Width of embeddings and hidden states.
   * @param {number} hiddenDim - Width of each layer's MLP hidden layer.
   * @param {number} numLayers - Number of stacked layers.
   */
  constructor(vocabSize, embedDim, hiddenDim, numLayers) {
    this.vocabSize = vocabSize;
    this.embedDim = embedDim;
    this.hiddenDim = hiddenDim;
    this.numLayers = numLayers;
    this.embedding = this.randomMatrix(vocabSize, embedDim);
    this.layers = Array.from({ length: numLayers }, () => ({
      attention: {
        wq: this.randomMatrix(embedDim, embedDim),
        wk: this.randomMatrix(embedDim, embedDim),
        wv: this.randomMatrix(embedDim, embedDim),
        wo: this.randomMatrix(embedDim, embedDim),
      },
      mlp: {
        w1: this.randomMatrix(embedDim, hiddenDim),
        b1: new Array(hiddenDim).fill(0),
        w2: this.randomMatrix(hiddenDim, embedDim),
        b2: new Array(embedDim).fill(0),
      },
    }));
    this.outputWeights = this.randomMatrix(embedDim, vocabSize);
  }

  /** (rows x cols) matrix with entries uniform in [-s, s), where s = sqrt(2 / rows). */
  randomMatrix(rows, cols) {
    const scale = Math.sqrt(2.0 / rows);
    return Array.from({ length: rows }, () =>
      Array.from({ length: cols }, () => (Math.random() - 0.5) * 2 * scale),
    );
  }

  /**
   * Run the model and cache every intermediate needed by `computeGradients`.
   * @param {number[]} tokens - Token ids; only the last one influences the result.
   * @returns {number[]} Next-token probability distribution (length vocabSize).
   * @throws {Error} If `tokens` is empty.
   * @throws {RangeError} If the last token id has no embedding row.
   */
  forward(tokens) {
    if (tokens.length === 0) {
      throw new Error('MiniTransformer.forward() needs at least one token');
    }
    const lastToken = tokens[tokens.length - 1];
    const embedded = this.embedding[lastToken];
    if (embedded === undefined) {
      throw new RangeError(`Token id ${lastToken} is outside the vocabulary (size ${this.vocabSize})`);
    }
    this.cache = { tokens, layers: [] };
    // Copy so the cached layer-0 input is not aliased to the embedding row updated later.
    let hidden = [...embedded];
    for (const { attention, mlp } of this.layers) {
      const input = hidden;
      const v = MathUtils.vecMatmul(input, attention.wv);
      const attnOut = MathUtils.vecMatmul(v, attention.wo);
      const preNorm1 = MathUtils.add(input, attnOut);
      const postAttn = MathUtils.layerNorm(preNorm1);
      const preRelu = MathUtils.add(MathUtils.vecMatmul(postAttn, mlp.w1), mlp.b1);
      const postRelu = MathUtils.relu(preRelu);
      const mlpOut = MathUtils.add(MathUtils.vecMatmul(postRelu, mlp.w2), mlp.b2);
      const preNorm2 = MathUtils.add(postAttn, mlpOut);
      hidden = MathUtils.layerNorm(preNorm2);
      this.cache.layers.push({ input, v, preNorm1, postAttn, preRelu, postRelu, preNorm2, output: hidden });
    }
    const logits = MathUtils.vecMatmul(hidden, this.outputWeights);
    const probs = MathUtils.softmax(logits);
    this.cache.finalHidden = hidden;
    this.cache.probs = probs;
    return probs;
  }

  /**
   * Exact gradients of the cross-entropy loss for `targetIdx` with respect to every
   * parameter the last `forward` call used. No parameters are modified.
   * @param {number} targetIdx - Id of the correct next token.
   * @returns {{outputWeights: number[][], layers: Array<{wv: number[][], wo: number[][],
   *   w1: number[][], b1: number[], w2: number[][], b2: number[]}>,
   *   embedding: {tokenId: number, grad: number[]}}}
   * @throws {Error} If `forward` has not been called yet.
   */
  computeGradients(targetIdx) {
    if (!this.cache) {
      throw new Error('forward() must be called before computing gradients');
    }
    const { tokens, layers: layerCaches, finalHidden, probs } = this.cache;
    // d(cross-entropy)/d(logits) of a softmax output is probs - one_hot(target).
    const dLogits = [...probs];
    dLogits[targetIdx] -= 1;

    const layerGrads = new Array(this.numLayers);
    // logits = finalHidden · outputWeights  =>  dFinalHidden = outputWeights · dLogits.
    let dHidden = MiniTransformer.#matVec(this.outputWeights, dLogits);
    for (let l = this.numLayers - 1; l >= 0; l--) {
      const { attention, mlp } = this.layers[l];
      const c = layerCaches[l];
      // output = LN(preNorm2)
      const dPreNorm2 = MiniTransformer.#layerNormBackward(dHidden, c.output, c.preNorm2);
      // mlpOut = relu(postAttn · W1 + b1) · W2 + b2
      const dPostRelu = MiniTransformer.#matVec(mlp.w2, dPreNorm2);
      const dPreRelu = dPostRelu.map((g, i) => (c.preRelu[i] > 0 ? g : 0));
      // postAttn reaches preNorm2 both through the MLP and through the residual skip.
      const dPostAttn = MathUtils.add(dPreNorm2, MiniTransformer.#matVec(mlp.w1, dPreRelu));
      // postAttn = LN(preNorm1), preNorm1 = input + v · Wo, v = input · Wv
      const dPreNorm1 = MiniTransformer.#layerNormBackward(dPostAttn, c.postAttn, c.preNorm1);
      const dV = MiniTransformer.#matVec(attention.wo, dPreNorm1);
      layerGrads[l] = {
        wv: MathUtils.outerProduct(c.input, dV),
        wo: MathUtils.outerProduct(c.v, dPreNorm1),
        w1: MathUtils.outerProduct(c.postAttn, dPreRelu),
        b1: dPreRelu,
        w2: MathUtils.outerProduct(c.postRelu, dPreNorm2),
        b2: dPreNorm2,
      };
      // input reaches preNorm1 both through the value path and through the residual skip.
      dHidden = MathUtils.add(dPreNorm1, MiniTransformer.#matVec(attention.wv, dV));
    }
    return {
      outputWeights: MathUtils.outerProduct(finalHidden, dLogits),
      layers: layerGrads,
      embedding: { tokenId: tokens[tokens.length - 1], grad: dHidden },
    };
  }

  /**
   * In-place SGD step: every parameter p becomes p - lr * grad.
   * @param {ReturnType<MiniTransformer['computeGradients']>} grads - Output of `computeGradients`.
   * @param {number} lr - Learning rate.
   */
  applyGradients(grads, lr) {
    const stepMatrix = (param, grad) => {
      grad.forEach((row, i) => {
        const target = param[i];
        row.forEach((g, j) => { target[j] -= lr * g; });
      });
    };
    const stepVector = (param, grad) => {
      grad.forEach((g, i) => { param[i] -= lr * g; });
    };
    stepMatrix(this.outputWeights, grads.outputWeights);
    this.layers.forEach((layer, l) => {
      const g = grads.layers[l];
      stepMatrix(layer.attention.wv, g.wv);
      stepMatrix(layer.attention.wo, g.wo);
      stepMatrix(layer.mlp.w1, g.w1);
      stepVector(layer.mlp.b1, g.b1);
      stepMatrix(layer.mlp.w2, g.w2);
      stepVector(layer.mlp.b2, g.b2);
    });
    stepVector(this.embedding[grads.embedding.tokenId], grads.embedding.grad);
  }

  /**
   * One SGD step on the example seen by the last `forward` call.
   * @param {number} targetIdx - Id of the correct next token.
   * @param {number} lr - Learning rate.
   */
  backward(targetIdx, lr) {
    this.applyGradients(this.computeGradients(targetIdx), lr);
  }

  /** Write dimensions and all parameters (including unused wq/wk) to `filename` as JSON. */
  saveWeights(filename) {
    const weights = {
      vocabSize: this.vocabSize, embedDim: this.embedDim, hiddenDim: this.hiddenDim,
      numLayers: this.numLayers, embedding: this.embedding, layers: this.layers,
      outputWeights: this.outputWeights,
    };
    fs.writeFileSync(filename, JSON.stringify(weights));
    console.log(logSymbols.success, chalk.green(`Weights saved to ${chalk.bold(filename)}`));
  }

  // mat · vec (equivalently vec · matᵀ) without materialising the transpose.
  static #matVec(mat, vec) {
    return mat.map((row) => {
      let sum = 0;
      for (let j = 0; j < vec.length; j++) sum += row[j] * vec[j];
      return sum;
    });
  }

  // Backward pass of a gain/bias-free LayerNorm: y = (x - mean) / sqrt(var + 1e-5).
  // The epsilon must match MathUtils.layerNorm.
  // dx = (dy - mean(dy) - y * mean(dy * y)) / std
  static #layerNormBackward(dOut, normalized, preNorm) {
    const n = preNorm.length;
    const mean = preNorm.reduce((a, b) => a + b, 0) / n;
    const variance = preNorm.reduce((a, b) => a + (b - mean) ** 2, 0) / n;
    const std = Math.sqrt(variance + 1e-5);
    const meanGrad = dOut.reduce((a, b) => a + b, 0) / n;
    const meanGradDotY = dOut.reduce((acc, g, i) => acc + g * normalized[i], 0) / n;
    return dOut.map((g, i) => (g - meanGrad - normalized[i] * meanGradDotY) / std);
  }
}
/**
 * Build (context, next-token) training pairs from text turns.
 *
 * Each turn is encoded and terminated with `<eos>`, so the model learns where a turn
 * ends; greedy generation stops when it predicts that token. Contexts never cross
 * turn boundaries and hold at most `contextWindow` tokens.
 *
 * @param {string[]} texts - Conversation turns.
 * @param {{encode: (s: string) => number[], vocab: Object<string, number>}} tokenizer
 * @param {number} contextWindow - Maximum context length in tokens.
 * @param {number} maxSamples - Maximum number of pairs returned.
 * @returns {{input: number[], target: number}[]}
 */
function buildTrainingPairs(texts, tokenizer, contextWindow, maxSamples) {
  const eos = tokenizer.vocab["<eos>"];
  const pairs = [];
  for (const text of texts) {
    const tokens = [...tokenizer.encode(text), eos];
    for (let i = 0; i < tokens.length - 1; i++) {
      if (pairs.length >= maxSamples) return pairs;
      pairs.push({
        input: tokens.slice(Math.max(0, i - contextWindow + 1), i + 1),
        target: tokens[i + 1],
      });
    }
  }
  return pairs;
}

/** Unbiased in-place Fisher–Yates shuffle; returns the same array. */
function shuffleInPlace(items) {
  for (let i = items.length - 1; i > 0; i--) {
    const j = Math.floor(Math.random() * (i + 1));
    [items[i], items[j]] = [items[j], items[i]];
  }
  return items;
}

/** Index of the first maximum element (0 for an empty or all-NaN array). */
function argmax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

/**
 * Greedy decoding: append the most probable token up to `maxNewTokens` times,
 * stopping early on `<eos>`.
 * @returns {string} Decoded prompt tokens followed by the generated tokens.
 */
function generateGreedy(model, tokenizer, prompt, contextWindow, maxNewTokens = 10) {
  const promptTokens = tokenizer.encode(prompt);
  if (promptTokens.length === 0) return "";
  const eos = tokenizer.vocab["<eos>"];
  const tokens = [...promptTokens];
  const generated = [];
  for (let i = 0; i < maxNewTokens; i++) {
    const nextToken = argmax(model.forward(tokens.slice(-contextWindow)));
    if (nextToken === eos) break;
    generated.push(nextToken);
    tokens.push(nextToken);
  }
  return tokenizer.decode([...promptTokens, ...generated]);
}

/**
 * Full pipeline: load the dataset, build the vocabulary and training pairs, train a
 * MiniTransformer with per-example SGD, save weights and vocabulary, then print
 * greedy completions for a few fixed prompts.
 *
 * @param {typeof TRAINING_CONFIG} config - Hyper-parameters and file paths.
 * @throws {Error} If no training pairs can be built from the dataset.
 */
function train(config) {
  console.log("\n" + chalk.bold.yellow("🚀 Starting Transformer Training...") + "\n" + "=".repeat(60));
  const texts = loadTinyChatDataset(config.datasetPath);
  const tokenizer = new Tokenizer();
  tokenizer.buildVocab(texts, 2);

  console.log(logSymbols.info, chalk.blue(`\nCreating training data (up to ${config.maxSamples} samples)...`));
  const trainingPairs = buildTrainingPairs(texts, tokenizer, config.contextWindow, config.maxSamples);
  // Without this guard the average loss below would be 0 / 0 = NaN.
  if (trainingPairs.length === 0) {
    throw new Error(`No training examples could be built from ${config.datasetPath}`);
  }
  console.log(logSymbols.success, chalk.green(`Created ${trainingPairs.length} training examples.`));

  const model = new MiniTransformer(tokenizer.vocabSize, config.embedDim, config.hiddenDim, config.numLayers);
  console.log(chalk.cyan.bold("\n🧠 Model Architecture:"));
  console.log(chalk.cyan(`  - Vocab Size: ${tokenizer.vocabSize}`));
  console.log(chalk.cyan(`  - Embedding Dim: ${model.embedDim}`));
  console.log(chalk.cyan(`  - Hidden Dim: ${model.hiddenDim}`));
  console.log(chalk.cyan(`  - Layers: ${model.numLayers}\n`));
  console.log("=".repeat(60) + chalk.bold.yellow("\n📈 Training Progress:\n"));

  // The \r progress line and clearLine()/cursorTo() only exist or make sense on a TTY;
  // when stdout is piped, clearLine is undefined and calling it would throw.
  const interactive = Boolean(process.stdout.isTTY);
  for (let epoch = 0; epoch < config.epochs; epoch++) {
    let totalLoss = 0;
    shuffleInPlace(trainingPairs);
    for (let i = 0; i < trainingPairs.length; i++) {
      const { input, target } = trainingPairs[i];
      const probs = model.forward(input);
      totalLoss += MathUtils.crossEntropy(probs, target);
      model.backward(target, config.learningRate);
      if (interactive && (i + 1) % 1000 === 0) {
        process.stdout.write(chalk.dim(`\rEpoch ${epoch + 1}/${config.epochs} | Batch ${i+1}/${trainingPairs.length}`));
      }
    }
    const avgLoss = totalLoss / trainingPairs.length;
    if (interactive) {
      process.stdout.clearLine(0);
      process.stdout.cursorTo(0);
    }
    console.log(
      chalk.magenta(`Epoch ${(epoch + 1).toString().padStart(2)}/${config.epochs}`) +
      chalk.gray(` | `) +
      chalk.green(`Loss: ${avgLoss.toFixed(4)}`)
    );
  }
  console.log("=".repeat(60));

  model.saveWeights(config.weightsFile);
  fs.writeFileSync(config.tokenizerFile, JSON.stringify({ vocab: tokenizer.vocab, vocabSize: tokenizer.vocabSize }, null, 2));
  console.log(logSymbols.success, chalk.green(`Tokenizer saved to ${chalk.bold(config.tokenizerFile)}`));

  console.log("\n" + "=".repeat(60) + chalk.bold.yellow("\n🎯 TESTING GENERATION\n"));
  const testPrompts = ["What do you", "I like to", "Do you have"];
  for (const prompt of testPrompts) {
    const output = generateGreedy(model, tokenizer, prompt, config.contextWindow);
    console.log(chalk.blue(`Prompt: "${prompt}"`));
    console.log(chalk.green(`Output: "${output}"\n`));
  }
  console.log("=".repeat(60) + "\n" + logSymbols.success, chalk.bold.bgGreen.black(" Training complete! "));
}
// Script entry point: train only when the dataset file is present; otherwise report and exit with status 1.
if (fs.existsSync(TRAINING_CONFIG.datasetPath)) {
  train(TRAINING_CONFIG);
} else {
  const { datasetPath } = TRAINING_CONFIG;
  console.error(logSymbols.error, chalk.red(`Error: Dataset file not found at ${chalk.bold(datasetPath)}!`));
  process.exit(1);
}