PyxiLabs committed on
Commit
7904d8b
·
verified ·
1 Parent(s): c72af18

Upload 3 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tinychat_weights.json filter=lfs diff=lfs merge=lfs -text
index (2).js ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const fs = require('fs');
2
+ const readline = require('readline');
3
+
4
/**
 * Small dense linear-algebra helpers that operate on plain JavaScript arrays.
 * Vectors are arrays of numbers; matrices are arrays of row arrays.
 */
class MathUtils {
  /**
   * Multiplies a row vector by a matrix: result[j] = sum over i of vec[i] * mat[i][j].
   *
   * @param {number[]} vec - Row vector; `mat` must provide one row per entry.
   * @param {number[][]} mat - Matrix as an array of rows, all of the same length.
   * @returns {number[]} One value per matrix column; `[]` for an empty matrix.
   */
  static vecMatmul(vec, mat) {
    // An empty matrix has no first row to read the column count from.
    if (mat.length === 0) {
      return [];
    }

    const columnCount = mat[0].length;
    const result = new Array(columnCount).fill(0);

    // Walk the matrix one row at a time instead of one column at a time.
    // Every result[j] still receives its terms in increasing i, so the
    // floating-point sums are exactly the same as with a column-first loop.
    for (let i = 0; i < vec.length; i++) {
      const scale = vec[i];
      const row = mat[i];
      for (let j = 0; j < columnCount; j++) {
        result[j] += scale * row[j];
      }
    }

    return result;
  }

  /**
   * Element-wise sum of two vectors.
   *
   * @param {number[]} a - Left operand; determines the length of the result.
   * @param {number[]} b - Right operand, read at the same indices as `a`.
   * @returns {number[]} New array with a[i] + b[i].
   */
  static add(a, b) {
    return a.map((val, i) => val + b[i]);
  }

  /**
   * Rectified linear unit applied to every element.
   *
   * @param {number[]} x - Input vector.
   * @returns {number[]} New array with negative values clamped to 0.
   */
  static relu(x) {
    return x.map((val) => Math.max(0, val));
  }

  /**
   * Numerically stable softmax.
   *
   * @param {number[]} logits - Unnormalised scores.
   * @returns {number[]} Probabilities that sum to 1; `[]` for empty input.
   */
  static softmax(logits) {
    if (logits.length === 0) {
      return [];
    }

    // Find the maximum with a fold rather than Math.max(...logits): spreading
    // passes every element as a call argument, which throws a RangeError once
    // the array (i.e. the vocabulary) gets large enough.
    const maxLogit = logits.reduce(
      (best, value) => (value > best ? value : best),
      -Infinity
    );

    // Subtracting the maximum keeps Math.exp from overflowing.
    const expValues = logits.map((value) => Math.exp(value - maxLogit));
    const sumExp = expValues.reduce((a, b) => a + b, 0);
    return expValues.map((value) => value / sumExp);
  }

  /**
   * Normalises a vector to zero mean and (approximately) unit variance.
   * No scale or shift parameters are applied.
   *
   * @param {number[]} x - Input vector.
   * @returns {number[]} New array with (x[i] - mean) / sqrt(variance + 1e-5).
   */
  static layerNorm(x) {
    const mean = x.reduce((a, b) => a + b, 0) / x.length;
    const variance = x.reduce((a, b) => a + (b - mean) ** 2, 0) / x.length;
    // The small epsilon avoids dividing by zero for constant vectors.
    const std = Math.sqrt(variance + 1e-5);
    return x.map((val) => (val - mean) / std);
  }
}
38
+
39
/**
 * Whitespace tokenizer backed by a word -> id vocabulary object.
 */
class Tokenizer {
  /**
   * @param {Object<string, number>} vocab - Maps each known word to its token id.
   */
  constructor(vocab) {
    this.vocab = vocab;
    // Inverse mapping (id -> word) used by decode().
    this.reverseVocab = Object.fromEntries(
      Object.entries(vocab).map(([word, id]) => [id, word])
    );
    this.vocabSize = Object.keys(vocab).length;
  }

  /**
   * Lower-cases the text, splits it on whitespace and maps each word to its id.
   * Words missing from the vocabulary map to the "<unk>" id; if the vocabulary
   * has no "<unk>" entry they are dropped instead of yielding `undefined` ids.
   *
   * @param {string} text - Raw input text.
   * @returns {number[]} Token ids in input order.
   */
  encode(text) {
    // Only own properties count as vocabulary entries. A bare `vocab[word]`
    // would also find inherited members, so typing "constructor" or
    // "__proto__" would produce a function/object instead of a token id.
    const hasOwn = Object.prototype.hasOwnProperty;
    const idOf = (word) =>
      hasOwn.call(this.vocab, word) ? this.vocab[word] : undefined;
    const unknownId = idOf("<unk>");

    return text
      .toLowerCase()
      .split(/\s+/)
      .filter((word) => word.length > 0)
      .map((word) => idOf(word) ?? unknownId)
      .filter((id) => id != null);
  }

  /**
   * Converts token ids back to a space-separated string.
   *
   * @param {number[]} tokens - Token ids.
   * @returns {string} Words joined by single spaces; ids without a word become "<unk>".
   */
  decode(tokens) {
    return tokens.map((t) => this.reverseVocab[t] || "<unk>").join(" ");
  }
}
59
+
60
/**
 * Minimal transformer-style next-token predictor that runs on plain arrays.
 *
 * Only the most recent token is fed through the layers, so every layer works
 * on a single position. Relies on the MathUtils helpers defined in this file.
 */
class MiniTransformer {
  /**
   * @param {object} weights - Parsed weight file. Fields read here:
   *   vocabSize, embedDim, hiddenDim, numLayers, embedding (one row per token
   *   id), layers (each with `attention` {wq, wk, wv, wo} and `mlp`
   *   {w1, b1, w2, b2}) and outputWeights.
   */
  constructor(weights) {
    this.vocabSize = weights.vocabSize;
    this.embedDim = weights.embedDim;
    this.hiddenDim = weights.hiddenDim;
    this.numLayers = weights.numLayers;

    this.embedding = weights.embedding;
    this.layers = weights.layers;
    this.outputWeights = weights.outputWeights;
  }

  /**
   * Looks up the embedding vector of a token.
   *
   * @param {number} tokenId - Index into the embedding table.
   * @returns {number[]} A copy of the embedding row, safe to modify.
   * @throws {RangeError} If the table has no row for `tokenId`.
   */
  embed(tokenId) {
    const row = this.embedding[tokenId];
    if (row == null) {
      throw new RangeError(`No embedding for token id ${tokenId}`);
    }
    return [...row];
  }

  /**
   * Computes the next-token probability distribution.
   *
   * @param {number[]} tokens - Context token ids; only the last one affects the result.
   * @returns {number[]} Softmax probabilities, one per output column.
   * @throws {RangeError} If `tokens` is empty or the last id has no embedding.
   */
  forward(tokens) {
    if (tokens.length === 0) {
      throw new RangeError("forward() needs at least one token");
    }

    // Earlier tokens never reach the layers, so only the last one is embedded.
    let x = this.embed(tokens[tokens.length - 1]);

    for (const layer of this.layers) {
      // Residual connection followed by normalisation, once per sub-block.
      x = MathUtils.layerNorm(MathUtils.add(x, this.attention(x, layer.attention)));
      x = MathUtils.layerNorm(MathUtils.add(x, this.mlp(x, layer.mlp)));
    }

    const logits = MathUtils.vecMatmul(x, this.outputWeights);
    return MathUtils.softmax(logits);
  }

  /**
   * Attention sub-block for a single position.
   *
   * With one position the attention weight is fixed at 1, so the result is
   * the value projection passed through the output projection. The query and
   * key matrices (`wq`, `wk`) therefore do not influence the output.
   *
   * @param {number[]} x - Input vector.
   * @param {{wv: number[][], wo: number[][]}} attnWeights - Projection matrices.
   * @returns {number[]} Attention output to be added to the residual stream.
   */
  attention(x, attnWeights) {
    const value = MathUtils.vecMatmul(x, attnWeights.wv);
    return MathUtils.vecMatmul(value, attnWeights.wo);
  }

  /**
   * Two-layer feed-forward block with ReLU: (relu(x * w1 + b1)) * w2 + b2.
   *
   * @param {number[]} x - Input vector.
   * @param {{w1: number[][], b1: number[], w2: number[][], b2: number[]}} mlpWeights
   * @returns {number[]} Block output to be added to the residual stream.
   */
  mlp(x, mlpWeights) {
    const hidden = MathUtils.relu(
      MathUtils.add(MathUtils.vecMatmul(x, mlpWeights.w1), mlpWeights.b1)
    );
    return MathUtils.add(MathUtils.vecMatmul(hidden, mlpWeights.w2), mlpWeights.b2);
  }

  /**
   * Samples a continuation of `tokens`.
   *
   * @param {number[]} tokens - Prompt token ids (not modified).
   * @param {number} [maxTokens=20] - Maximum number of tokens to append.
   * @param {number} [temperature=0.8] - Base temperature, scaled by entropy and clamped to [0.5, 1.2].
   * @param {number} [topK=10] - Number of most likely candidates to sample from.
   * @param {number} [repetitionPenalty=1.2] - Divisor applied to tokens already in the sequence.
   * @returns {number[]} Prompt followed by the sampled tokens, including a final stop token if one was drawn.
   */
  generate(tokens, maxTokens = 20, temperature = 0.8, topK = 10, repetitionPenalty = 1.2) {
    const CONTEXT_WINDOW = 5;
    // Ids 0 and 2 end generation; presumably padding / end-of-sequence —
    // TODO confirm against the tokenizer vocabulary.
    const isStopToken = (id) => id === 2 || id === 0;

    const generated = [...tokens];
    // Set mirror of `generated` so the penalty pass does not rescan the
    // whole sequence for every vocabulary entry.
    const seen = new Set(generated);

    for (let step = 0; step < maxTokens; step++) {
      let probs = this.forward(generated.slice(-CONTEXT_WINDOW));

      // Down-weight tokens that already occur in the sequence.
      probs = probs.map((p, id) => (seen.has(id) ? p / repetitionPenalty : p));

      // The flatter the (penalised) distribution, the higher the temperature.
      const entropy = -probs.reduce((acc, p) => acc + (p > 0 ? p * Math.log(p) : 0), 0);
      const adaptiveTemp = Math.max(0.5, Math.min(1.2, temperature * (entropy + 0.5)));

      probs = probs.map((p) => Math.pow(p, 1 / adaptiveTemp));
      const sum = probs.reduce((a, b) => a + b, 0);
      probs = probs.map((p) => p / sum);

      const topIndices = probs
        .map((p, i) => ({ prob: p, index: i }))
        .sort((a, b) => b.prob - a.prob)
        .slice(0, topK);

      // Renormalise so the kept candidates sum to 1 before sampling.
      const totalProb = topIndices.reduce((acc, item) => acc + item.prob, 0);
      const topProbs = topIndices.map((item) => ({
        index: item.index,
        prob: item.prob / totalProb
      }));

      const nextToken = this.sampleFromProbs(topProbs);
      generated.push(nextToken);
      seen.add(nextToken);

      if (isStopToken(nextToken)) {
        break;
      }
    }

    return generated;
  }

  /**
   * Draws one candidate according to its probability.
   *
   * @param {{index: number, prob: number}[]} topProbs - Candidates whose probs sum to about 1.
   * @returns {number} The `index` of the drawn candidate.
   */
  sampleFromProbs(topProbs) {
    const rand = Math.random();
    let cumSum = 0;

    for (const item of topProbs) {
      cumSum += item.prob;
      if (rand < cumSum) {
        return item.index;
      }
    }

    // Rounding can leave the cumulative sum just below `rand`; use the last candidate.
    return topProbs[topProbs.length - 1].index;
  }
}
167
+
168
/**
 * Loads the tokenizer and model weights from the current working directory
 * and runs a read-eval-print chat loop on stdin/stdout until the user types
 * "quit" or "exit".
 *
 * @returns {Promise<void>} Resolves once the first prompt has been issued;
 *   rejects if a data file cannot be read or parsed.
 */
async function interactiveChat() {
  // Sampling settings for chat replies, passed to generate() by position.
  const MAX_NEW_TOKENS = 8;
  const TEMPERATURE = 0.3;
  const TOP_K = 3;

  const divider = "=".repeat(60);

  console.log("\n🤖 TinyChat Model - Interactive Chat");
  console.log(divider);

  console.log("\n📖 Loading tokenizer...");
  const tokenizerData = JSON.parse(fs.readFileSync('tinychat_tokenizer.json', 'utf8'));
  const tokenizer = new Tokenizer(tokenizerData.vocab);
  console.log(`✅ Vocabulary: ${tokenizer.vocabSize} tokens`);

  console.log("🧠 Loading model weights...");
  const weights = JSON.parse(fs.readFileSync('tinychat_weights.json', 'utf8'));
  const model = new MiniTransformer(weights);
  console.log(`✅ Model loaded (${weights.embedDim}D, ${weights.numLayers} layers)`);

  console.log("\n" + divider);
  console.log("💬 Chat with your AI! (type 'quit' to exit)");
  console.log("💡 Tips:");
  console.log(" - Try prompts from your training data");
  console.log(" - Use 2-4 words for best results");
  console.log(" - Model may repeat or produce gibberish (it's small!)");
  console.log(divider + "\n");

  const rl = readline.createInterface({
    input: process.stdin,
    output: process.stdout
  });

  // Asks for one line, answers it, then schedules itself again.
  const askQuestion = () => {
    rl.question('You: ', (input) => {
      const prompt = input.trim();
      const command = prompt.toLowerCase();

      if (command === 'quit' || command === 'exit') {
        console.log("\n👋 Goodbye!\n");
        rl.close();
        return;
      }

      if (prompt.length === 0) {
        askQuestion();
        return;
      }

      const tokens = tokenizer.encode(prompt);

      if (tokens.length === 0) {
        console.log("Bot: [Unable to understand - try different words]\n");
        askQuestion();
        return;
      }

      // JavaScript has no keyword arguments: writing `maxTokens = 8` inside
      // the call would assign to an undeclared variable (an implicit global,
      // or a ReferenceError in strict mode), so the values go in by position.
      const generated = model.generate(tokens, MAX_NEW_TOKENS, TEMPERATURE, TOP_K);

      const response = tokenizer.decode(generated);
      console.log(`Bot: ${response}\n`);

      askQuestion();
    });
  };

  askQuestion();
}
234
+
235
/**
 * Script entry point: starts the interactive chat.
 *
 * interactiveChat() is async, so a failure while it loads its data files
 * surfaces as a rejected promise. The rejection is handled here rather than
 * left floating: the error is reported and the process is marked as failed.
 */
function main() {
  interactiveChat().catch((err) => {
    console.error("TinyChat failed:", err);
    process.exitCode = 1;
  });
}

main();
tinychat_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tinychat_weights.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29f2b87de2a09da299c68cec7f8d63994d5426d808566211ec2144a2e1b3d7c4
3
+ size 17465721