// miniai-1m / index.js
// (Hugging Face upload metadata, kept for provenance: PyxiLabs —
//  "Rename index (2).js to index.js", commit dcacd1f verified)
const fs = require('fs');
const readline = require('readline');
/**
 * Static vector/matrix helpers used by the model.
 * All methods are pure: they return new arrays and never mutate their inputs.
 */
class MathUtils {
  /**
   * Multiply a row vector by a matrix: result[j] = sum_i vec[i] * mat[i][j].
   * @param {number[]} vec - input vector; length must equal mat.length.
   * @param {number[][]} mat - matrix as an array of rows.
   * @returns {number[]} vector of length mat[0].length.
   */
  static vecMatmul(vec, mat) {
    const cols = mat[0].length;
    const result = Array(cols).fill(0);
    for (let j = 0; j < cols; j++) {
      for (let i = 0; i < vec.length; i++) {
        result[j] += vec[i] * mat[i][j];
      }
    }
    return result;
  }
  /** Element-wise vector addition; assumes a and b have equal length. */
  static add(a, b) {
    return a.map((val, i) => val + b[i]);
  }
  /** Element-wise ReLU: max(0, x). */
  static relu(x) {
    return x.map((val) => Math.max(0, val));
  }
  /**
   * Numerically stable softmax: subtracts the max logit before
   * exponentiating so Math.exp cannot overflow.
   * @param {number[]} logits
   * @returns {number[]} probabilities summing to 1.
   */
  static softmax(logits) {
    const maxLogit = Math.max(...logits);
    const expValues = logits.map((x) => Math.exp(x - maxLogit));
    const sumExp = expValues.reduce((a, b) => a + b, 0);
    return expValues.map((x) => x / sumExp);
  }
  /**
   * Layer normalization: scale x to zero mean and (near-)unit variance.
   * @param {number[]} x
   * @param {number} [eps=1e-5] - variance floor for numerical stability.
   *   Parameterized for reuse; the default matches the previously
   *   hard-coded value, so existing callers are unaffected.
   * @returns {number[]}
   */
  static layerNorm(x, eps = 1e-5) {
    const mean = x.reduce((a, b) => a + b, 0) / x.length;
    const variance = x.reduce((a, b) => a + (b - mean) ** 2, 0) / x.length;
    const std = Math.sqrt(variance + eps);
    return x.map((val) => (val - mean) / std);
  }
}
/**
 * Word-level tokenizer backed by a fixed vocabulary map.
 * Unknown words fall back to the "<unk>" token on both encode and decode.
 */
class Tokenizer {
  constructor(vocab) {
    this.vocab = vocab;
    // Build the inverse id -> word table with a plain loop.
    const inverse = {};
    for (const [word, id] of Object.entries(vocab)) {
      inverse[id] = word;
    }
    this.reverseVocab = inverse;
    this.vocabSize = Object.keys(vocab).length;
  }
  /** Lowercase the text, split on whitespace, and map each word to its id. */
  encode(text) {
    const ids = [];
    for (const word of text.toLowerCase().split(/\s+/)) {
      if (word.length === 0) continue; // skip empty fragments from leading/trailing spaces
      ids.push(this.vocab[word] ?? this.vocab["<unk>"]);
    }
    return ids;
  }
  /** Map token ids back to words, joined by single spaces. */
  decode(tokens) {
    const words = tokens.map((id) => this.reverseVocab[id] || "<unk>");
    return words.join(" ");
  }
}
/**
 * Minimal decoder-style toy transformer. It predicts the next token from the
 * embedding of the LAST context token only — there is no positional encoding
 * and no cross-token attention (see attention()).
 */
class MiniTransformer {
  /**
   * @param {object} weights - parsed weight file; must provide vocabSize,
   *   embedDim, hiddenDim, numLayers, embedding (vocabSize x embedDim),
   *   layers (attention + mlp weights per layer) and outputWeights
   *   (embedDim x vocabSize).
   */
  constructor(weights) {
    this.vocabSize = weights.vocabSize;
    this.embedDim = weights.embedDim;
    this.hiddenDim = weights.hiddenDim;
    this.numLayers = weights.numLayers;
    this.embedding = weights.embedding;
    this.layers = weights.layers;
    this.outputWeights = weights.outputWeights;
  }
  /** Return a defensive copy of the embedding row for tokenId. */
  embed(tokenId) {
    return [...this.embedding[tokenId]];
  }
  /**
   * Run the network and return next-token probabilities.
   * Only the last token's embedding is used; earlier context tokens are
   * ignored (previously every token was embedded and all but the last
   * discarded — that wasted work is now skipped; output is unchanged).
   * @param {number[]} tokens - non-empty array of token ids.
   * @returns {number[]} softmax distribution over the vocabulary.
   */
  forward(tokens) {
    let x = this.embed(tokens[tokens.length - 1]);
    for (const layer of this.layers) {
      // Residual blocks: x = layerNorm(x + sublayer(x)).
      const attnOut = this.attention(x, layer.attention);
      x = MathUtils.layerNorm(MathUtils.add(x, attnOut));
      const mlpOut = this.mlp(x, layer.mlp);
      x = MathUtils.layerNorm(MathUtils.add(x, mlpOut));
    }
    const logits = MathUtils.vecMatmul(x, this.outputWeights);
    return MathUtils.softmax(logits);
  }
  /**
   * Degenerate single-position "attention": with a context of exactly one
   * vector, the softmax over attention scores is identically 1, so the value
   * projection passes through unweighted.
   * NOTE(review): the original computed q, k and a dot-product score that
   * were never applied (the weight was hard-coded to 1.0); that dead
   * computation has been removed. Behavior is unchanged, but wq/wk in the
   * weight file are effectively unused.
   */
  attention(x, attnWeights) {
    const v = MathUtils.vecMatmul(x, attnWeights.wv);
    // Attention weight over a single position is always 1.0.
    return MathUtils.vecMatmul(v, attnWeights.wo);
  }
  /** Two-layer feed-forward block: relu(x·w1 + b1)·w2 + b2. */
  mlp(x, mlpWeights) {
    let hidden = MathUtils.vecMatmul(x, mlpWeights.w1);
    hidden = MathUtils.add(hidden, mlpWeights.b1);
    hidden = MathUtils.relu(hidden);
    let output = MathUtils.vecMatmul(hidden, mlpWeights.w2);
    return MathUtils.add(output, mlpWeights.b2);
  }
  /**
   * Autoregressively sample up to maxTokens new tokens.
   * @param {number[]} tokens - prompt token ids (copied, not mutated).
   * @param {number} [maxTokens=20] - maximum number of tokens to append.
   * @param {number} [temperature=0.8] - base sampling temperature; scaled by
   *   the distribution's entropy and clamped to [0.5, 1.2] each step.
   * @param {number} [topK=10] - sample only among the topK most likely tokens.
   * @param {number} [repetitionPenalty=1.2] - divisor applied to the
   *   probability of any token id already present in the output.
   * @returns {number[]} prompt plus generated tokens.
   */
  generate(tokens, maxTokens = 20, temperature = 0.8, topK = 10, repetitionPenalty = 1.2) {
    const generated = [...tokens];
    for (let i = 0; i < maxTokens; i++) {
      // Window of the last 5 tokens (forward() only uses the final one).
      const contextTokens = generated.slice(-5);
      let probs = this.forward(contextTokens);
      // Penalize already-emitted tokens to reduce repetition loops.
      for (let j = 0; j < probs.length; j++) {
        if (generated.includes(j)) probs[j] /= repetitionPenalty;
      }
      // Entropy-adaptive temperature, clamped to [0.5, 1.2].
      const entropy = -probs.reduce((a, p) => a + (p > 0 ? p * Math.log(p) : 0), 0);
      const adaptiveTemp = Math.max(0.5, Math.min(1.2, temperature * (entropy + 0.5)));
      probs = probs.map((p) => Math.pow(p, 1 / adaptiveTemp));
      const sum = probs.reduce((a, b) => a + b, 0);
      probs = probs.map((p) => p / sum);
      // Keep the topK candidates and renormalize among them.
      const topIndices = probs
        .map((p, index) => ({ prob: p, index }))
        .sort((a, b) => b.prob - a.prob)
        .slice(0, topK);
      const totalProb = topIndices.reduce((a, b) => a + b.prob, 0);
      const topProbs = topIndices.map((item) => ({
        index: item.index,
        prob: item.prob / totalProb,
      }));
      const nextToken = this.sampleFromProbs(topProbs);
      generated.push(nextToken);
      // Ids 2 and 0 act as stop tokens — presumably EOS/PAD; TODO confirm
      // against the tokenizer's special-token assignments.
      if (nextToken === 2 || nextToken === 0) break;
    }
    return generated;
  }
  /**
   * Sample an index from a list of {index, prob} entries via inverse-CDF
   * sampling; falls back to the last entry to guard against floating-point
   * drift in the cumulative sum. Non-deterministic (uses Math.random).
   */
  sampleFromProbs(topProbs) {
    const rand = Math.random();
    let cumSum = 0;
    for (const item of topProbs) {
      cumSum += item.prob;
      if (rand < cumSum) return item.index;
    }
    return topProbs[topProbs.length - 1].index;
  }
}
/**
 * Load the tokenizer and model weights from the current directory and run a
 * readline-based chat REPL until the user types "quit" or "exit".
 * Reads tinychat_tokenizer.json and tinychat_weights.json; the returned
 * promise rejects if either file is missing or unparseable.
 */
async function interactiveChat() {
  console.log("\n🤖 TinyChat Model - Interactive Chat");
  console.log("=".repeat(60));
  console.log("\n📖 Loading tokenizer...");
  const tokenizerData = JSON.parse(fs.readFileSync('tinychat_tokenizer.json', 'utf8'));
  const tokenizer = new Tokenizer(tokenizerData.vocab);
  console.log(`✅ Vocabulary: ${tokenizer.vocabSize} tokens`);
  console.log("🧠 Loading model weights...");
  const weights = JSON.parse(fs.readFileSync('tinychat_weights.json', 'utf8'));
  const model = new MiniTransformer(weights);
  console.log(`✅ Model loaded (${weights.embedDim}D, ${weights.numLayers} layers)`);
  console.log("\n" + "=".repeat(60));
  console.log("💬 Chat with your AI! (type 'quit' to exit)");
  console.log("💡 Tips:");
  console.log("   - Try prompts from your training data");
  console.log("   - Use 2-4 words for best results");
  console.log("   - Model may repeat or produce gibberish (it's small!)");
  console.log("=".repeat(60) + "\n");
  const rl = readline.createInterface({
    input: process.stdin,
    output: process.stdout
  });
  // Recursive prompt loop: each answer schedules the next question.
  const askQuestion = () => {
    rl.question('You: ', (input) => {
      const prompt = input.trim();
      const lowered = prompt.toLowerCase();
      if (lowered === 'quit' || lowered === 'exit') {
        console.log("\n👋 Goodbye!\n");
        rl.close();
        return;
      }
      if (prompt.length === 0) {
        askQuestion();
        return;
      }
      const tokens = tokenizer.encode(prompt);
      if (tokens.length === 0) {
        console.log("Bot: [Unable to understand - try different words]\n");
        askQuestion();
        return;
      }
      // BUG FIX: the original passed `maxTokens = 8, temperature = 0.3,
      // topK = 3` — assignments to undeclared identifiers that created
      // implicit globals (JS has no named arguments). Same values, passed
      // positionally: maxTokens=8, temperature=0.3, topK=3.
      const generated = model.generate(tokens, 8, 0.3, 3);
      const response = tokenizer.decode(generated);
      console.log(`Bot: ${response}\n`);
      askQuestion();
    });
  };
  askQuestion();
}
/**
 * Entry point. interactiveChat() is async; attach a rejection handler so a
 * startup failure (e.g. missing weight files) reports a clean error instead
 * of an unhandled promise rejection.
 */
function main() {
  interactiveChat().catch((err) => {
    console.error(`\n❌ Fatal: ${err.message}`);
    process.exitCode = 1;
  });
}
main();