| const fs = require('fs'); | |
| const readline = require('readline'); | |
/**
 * Stateless numeric helpers for the tiny transformer: vector/matrix
 * products, elementwise ops, softmax, and layer normalization.
 * All methods return fresh arrays and never mutate their inputs.
 */
class MathUtils {
  /** Row-vector × matrix product: (1×n)·(n×m) → length-m array. */
  static vecMatmul(vec, mat) {
    return Array.from({ length: mat[0].length }, (_, col) =>
      vec.reduce((acc, v, row) => acc + v * mat[row][col], 0)
    );
  }

  /** Elementwise sum of two equal-length vectors. */
  static add(a, b) {
    const out = [];
    for (let i = 0; i < a.length; i++) {
      out.push(a[i] + b[i]);
    }
    return out;
  }

  /** Elementwise ReLU: clamp negatives to zero. */
  static relu(x) {
    return x.map((v) => Math.max(v, 0));
  }

  /**
   * Numerically stable softmax: subtract the max logit before
   * exponentiating so large logits cannot overflow to Infinity.
   */
  static softmax(logits) {
    const peak = Math.max(...logits);
    const exps = logits.map((z) => Math.exp(z - peak));
    const total = exps.reduce((acc, e) => acc + e, 0);
    return exps.map((e) => e / total);
  }

  /**
   * Layer normalization: zero-mean, unit-variance rescaling with a
   * 1e-5 epsilon inside the sqrt to avoid division by zero on
   * constant vectors.
   */
  static layerNorm(x) {
    const n = x.length;
    const mean = x.reduce((sum, v) => sum + v, 0) / n;
    const variance = x.reduce((sum, v) => sum + (v - mean) ** 2, 0) / n;
    const std = Math.sqrt(variance + 1e-5);
    return x.map((v) => (v - mean) / std);
  }
}
/**
 * Whitespace word tokenizer over a fixed vocabulary.
 * Expects the vocab to contain an "<unk>" entry for out-of-vocabulary
 * words — TODO confirm against the tokenizer JSON file.
 */
class Tokenizer {
  /** @param {Object<string, number>} vocab - word → token-id map. */
  constructor(vocab) {
    this.vocab = vocab;
    // Invert the map for decoding (id → word). Numeric ids become
    // string keys, but bracket indexing coerces, so lookups still work.
    this.reverseVocab = Object.fromEntries(
      Object.entries(vocab).map(([word, id]) => [id, word])
    );
    this.vocabSize = Object.keys(vocab).length;
  }

  /**
   * Lowercase the text, split on runs of whitespace, drop empty
   * fragments, and map each word to its id ("<unk>" id for unknowns).
   * @returns {number[]}
   */
  encode(text) {
    return text.toLowerCase()
      .split(/\s+/)
      .filter(w => w.length > 0)
      .map(w => this.vocab[w] ?? this.vocab["<unk>"]);
  }

  /**
   * Map token ids back to words, joined with single spaces.
   * Uses ?? (not ||) so a legitimately falsy vocab word (e.g. the
   * empty string) is not misreported as "<unk>"; only a genuinely
   * missing id falls back.
   * @returns {string}
   */
  decode(tokens) {
    return tokens.map(t => this.reverseVocab[t] ?? "<unk>").join(" ");
  }
}
/**
 * Minimal transformer-style next-token predictor.
 * NOTE(review): only the LAST context token actually feeds the network
 * (see forward), so this is effectively a per-token predictor despite
 * accepting a token sequence.
 */
class MiniTransformer {
  /**
   * @param {object} weights - deserialized bundle: vocabSize, embedDim,
   *   hiddenDim, numLayers, embedding (vocabSize×embedDim), layers
   *   ([{attention, mlp}, ...]), outputWeights (embedDim×vocabSize).
   */
  constructor(weights) {
    this.vocabSize = weights.vocabSize;
    this.embedDim = weights.embedDim;
    this.hiddenDim = weights.hiddenDim;
    this.numLayers = weights.numLayers;
    this.embedding = weights.embedding;
    this.layers = weights.layers;
    this.outputWeights = weights.outputWeights;
  }

  /** Defensive copy of the embedding row for a token id. */
  embed(tokenId) {
    return [...this.embedding[tokenId]];
  }

  /**
   * Run the network and return a softmax distribution over the vocab.
   * Only the final token's embedding is used; the original embedded
   * every context token and discarded all but the last, which was
   * wasted work.
   * @param {number[]} tokens - non-empty token-id context.
   * @returns {number[]} probabilities, length vocabSize.
   */
  forward(tokens) {
    let x = this.embed(tokens[tokens.length - 1]);
    for (const layer of this.layers) {
      // Residual + post-norm around attention, then around the MLP.
      x = MathUtils.layerNorm(MathUtils.add(x, this.attention(x, layer.attention)));
      x = MathUtils.layerNorm(MathUtils.add(x, this.mlp(x, layer.mlp)));
    }
    const logits = MathUtils.vecMatmul(x, this.outputWeights);
    return MathUtils.softmax(logits);
  }

  /**
   * Single-position self-attention. With a context of one vector the
   * softmax over the lone q·k score is exactly 1, so the value
   * projection passes straight through to the output projection.
   * The original also computed q/k and an unused dot-product score;
   * that dead computation has been removed.
   */
  attention(x, attnWeights) {
    const v = MathUtils.vecMatmul(x, attnWeights.wv);
    return MathUtils.vecMatmul(v, attnWeights.wo);
  }

  /** Two-layer feed-forward block: w1+b1 → ReLU → w2+b2. */
  mlp(x, mlpWeights) {
    let hidden = MathUtils.vecMatmul(x, mlpWeights.w1);
    hidden = MathUtils.add(hidden, mlpWeights.b1);
    hidden = MathUtils.relu(hidden);
    let output = MathUtils.vecMatmul(hidden, mlpWeights.w2);
    output = MathUtils.add(output, mlpWeights.b2);
    return output;
  }

  /**
   * Autoregressive sampling with repetition penalty, entropy-adaptive
   * temperature, and top-k renormalization.
   * @param {number[]} tokens - prompt token ids (not mutated).
   * @returns {number[]} prompt plus up to maxTokens generated ids.
   */
  generate(tokens, maxTokens = 20, temperature = 0.8, topK = 10, repetitionPenalty = 1.2) {
    const generated = [...tokens];
    for (let i = 0; i < maxTokens; i++) {
      const contextTokens = generated.slice(-5);
      let probs = this.forward(contextTokens);

      // Penalize already-emitted tokens, then renormalize. The original
      // computed entropy on the unnormalized post-penalty vector, which
      // skews the adaptive-temperature estimate.
      for (let j = 0; j < probs.length; j++) {
        if (generated.includes(j)) probs[j] /= repetitionPenalty;
      }
      let sum = probs.reduce((a, b) => a + b, 0);
      probs = probs.map(p => p / sum);

      // Adaptive temperature, clamped to [0.5, 1.2]: flatten when the
      // model is uncertain (high entropy), sharpen when confident.
      const entropy = -probs.reduce((a, p) => a + (p > 0 ? p * Math.log(p) : 0), 0);
      const adaptiveTemp = Math.max(0.5, Math.min(1.2, temperature * (entropy + 0.5)));
      probs = probs.map(p => Math.pow(p, 1 / adaptiveTemp));
      sum = probs.reduce((a, b) => a + b, 0);
      probs = probs.map(p => p / sum);

      // Top-k: keep the k most likely tokens, renormalized among themselves.
      const topEntries = probs
        .map((prob, index) => ({ prob, index }))
        .sort((a, b) => b.prob - a.prob)
        .slice(0, topK);
      const totalProb = topEntries.reduce((a, b) => a + b.prob, 0);
      const nextToken = this.sampleFromProbs(
        topEntries.map(({ index, prob }) => ({ index, prob: prob / totalProb }))
      );
      generated.push(nextToken);

      // Ids 0 and 2 are treated as stop tokens — presumably <pad>/<eos>;
      // confirm against the tokenizer vocabulary.
      if (nextToken === 2 || nextToken === 0) break;
    }
    return generated;
  }

  /**
   * Draw one index from [{index, prob}] whose probs sum to ~1.
   * Falls back to the last entry to guard against floating-point
   * undershoot in the cumulative sum.
   */
  sampleFromProbs(topProbs) {
    const rand = Math.random();
    let cumSum = 0;
    for (const item of topProbs) {
      cumSum += item.prob;
      if (rand < cumSum) return item.index;
    }
    return topProbs[topProbs.length - 1].index;
  }
}
/**
 * Load the tokenizer and model weights from the working directory and
 * run a stdin/stdout chat REPL until the user types "quit" or "exit".
 * Fails fast with a readable message if either JSON file is missing
 * or malformed instead of surfacing a raw unhandled rejection.
 */
async function interactiveChat() {
  console.log("\n🤖 TinyChat Model - Interactive Chat");
  console.log("=".repeat(60));

  let tokenizer;
  let model;
  let weights;
  try {
    console.log("\n📖 Loading tokenizer...");
    const tokenizerData = JSON.parse(fs.readFileSync('tinychat_tokenizer.json', 'utf8'));
    tokenizer = new Tokenizer(tokenizerData.vocab);
    console.log(`✅ Vocabulary: ${tokenizer.vocabSize} tokens`);

    console.log("🧠 Loading model weights...");
    weights = JSON.parse(fs.readFileSync('tinychat_weights.json', 'utf8'));
    model = new MiniTransformer(weights);
    console.log(`✅ Model loaded (${weights.embedDim}D, ${weights.numLayers} layers)`);
  } catch (err) {
    console.error(`❌ Failed to load model files: ${err.message}`);
    return;
  }

  console.log("\n" + "=".repeat(60));
  console.log("💬 Chat with your AI! (type 'quit' to exit)");
  console.log("💡 Tips:");
  console.log("   - Try prompts from your training data");
  console.log("   - Use 2-4 words for best results");
  console.log("   - Model may repeat or produce gibberish (it's small!)");
  console.log("=".repeat(60) + "\n");

  const rl = readline.createInterface({
    input: process.stdin,
    output: process.stdout
  });

  const askQuestion = () => {
    rl.question('You: ', (input) => {
      const prompt = input.trim();
      if (prompt.toLowerCase() === 'quit' || prompt.toLowerCase() === 'exit') {
        console.log("\n👋 Goodbye!\n");
        rl.close();
        return;
      }
      if (prompt.length === 0) {
        askQuestion();
        return;
      }
      const tokens = tokenizer.encode(prompt);
      if (tokens.length === 0) {
        console.log("Bot: [Unable to understand - try different words]\n");
        askQuestion();
        return;
      }
      // BUG FIX: the original wrote `maxTokens = 8, temperature = 0.3,
      // topK = 3` inside the call. JavaScript has no named arguments —
      // those were assignments to implicit globals (a ReferenceError in
      // strict mode). Pass the values positionally.
      const generated = model.generate(tokens, 8, 0.3, 3);
      const response = tokenizer.decode(generated);
      console.log(`Bot: ${response}\n`);
      askQuestion();
    });
  };

  askQuestion();
}
/**
 * Entry point. interactiveChat is async: attach a rejection handler so
 * a startup failure becomes a clean exit code instead of a floating
 * promise / unhandled-rejection crash.
 */
function main() {
  interactiveChat().catch((err) => {
    console.error(`Fatal: ${err.message}`);
    process.exit(1);
  });
}
main();