ping98k committed on
Commit
ee4ca8c
·
1 Parent(s): f08e6a1

Add initial scatter plot with placeholder names for K-Means clustering results

Browse files
Files changed (2) hide show
  1. index.html +1 -177
  2. main.js +25 -14
index.html CHANGED
@@ -64,183 +64,7 @@
64
  <div id="plot-heatmap" style="width:500px; height:500px;"></div>
65
  </div>
66
  <script src="https://cdn.plot.ly/plotly-2.32.0.min.js"></script>
67
- <script type="module">
68
- import { pipeline, TextStreamer, AutoTokenizer, AutoModelForCausalLM } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.0';
69
- import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
70
-
71
- const embed = await pipeline(
72
- "feature-extraction",
73
- "onnx-community/Qwen3-Embedding-0.6B-ONNX",
74
- { device: "webgpu", dtype: "q4f16" },
75
- );
76
- const tokenizer = await AutoTokenizer.from_pretrained("onnx-community/Qwen3-0.6B-ONNX");
77
- const model = await AutoModelForCausalLM.from_pretrained("onnx-community/Qwen3-0.6B-ONNX", { device: "webgpu", dtype: "q4f16" });
78
-
79
- const task = "Given a textual input sentence, retrieve relevant categories that best describe it.";
80
- document.getElementById("run").onclick = async () => {
81
- const text = document.getElementById("input").value;
82
- const groups = text.split(/\n{3,}/);
83
- const groupEmbeddings = [];
84
- for (const g of groups) {
85
- const lines = g.split(/\n/).filter(x => x.trim() != "");
86
- const prompts = lines.map(s => `Instruct: ${task}\nQuery:${s}`);
87
- const out = await embed(prompts, { pooling: "mean", normalize: true });
88
- const embeddings = typeof out.tolist === 'function' ? out.tolist() : out.data;
89
- const dim = embeddings[0].length;
90
- const avg = new Float32Array(dim);
91
- for (const e of embeddings) { for (let i = 0; i < dim; i++) avg[i] += e[i]; }
92
- for (let i = 0; i < dim; i++) avg[i] /= embeddings.length;
93
- groupEmbeddings.push(avg);
94
- }
95
- const n = groupEmbeddings.length;
96
- const sim = [];
97
- for (let i = 0; i < n; i++) {
98
- const row = [];
99
- for (let j = 0; j < n; j++) {
100
- let dot = 0, na = 0, nb = 0;
101
- for (let k = 0; k < groupEmbeddings[i].length; k++) {
102
- dot += groupEmbeddings[i][k] * groupEmbeddings[j][k];
103
- na += groupEmbeddings[i][k] ** 2;
104
- nb += groupEmbeddings[j][k] ** 2;
105
- }
106
- row.push(dot / Math.sqrt(na * nb));
107
- }
108
- sim.push(row);
109
- }
110
- const data = [{ z: sim, type: "heatmap", colorscale: "Viridis", zmin: 0, zmax: 1 }];
111
- Plotly.newPlot("plot-heatmap", data, {
112
- xaxis: { title: "Group", scaleanchor: "y", scaleratio: 1 },
113
- yaxis: { title: "Group", scaleanchor: "x", scaleratio: 1 },
114
- width: 500,
115
- height: 500,
116
- margin: { t: 40, l: 40, r: 10, b: 40 },
117
- title: "Group Similarity Heatmap"
118
- });
119
- };
120
-
121
- // --- K-Means Clustering ---
122
- document.getElementById("kmeans-btn").onclick = async () => {
123
- const progressBar = document.getElementById("progress-bar");
124
- const progressBarInner = document.getElementById("progress-bar-inner");
125
- progressBar.style.display = "block";
126
- progressBarInner.style.width = "0%";
127
-
128
- const text = document.getElementById("input").value;
129
- const lines = text.split(/\n/).map(x => x.trim()).filter(x => x);
130
- const prompts = lines.map(s => `Instruct: ${task}\nQuery:${s}`);
131
- const out = await embed(prompts, { pooling: "mean", normalize: true });
132
- const embeddings = typeof out.tolist === 'function' ? out.tolist() : out.data;
133
-
134
- // K-Means implementation
135
- const k = Math.max(2, Math.min(20, parseInt(document.getElementById("kmeans-k").value) || 3));
136
- const n = embeddings.length, dim = embeddings[0].length;
137
- let centroids = Array.from({ length: k }, () => embeddings[Math.floor(Math.random() * n)].slice());
138
- let labels = new Array(n).fill(0);
139
- for (let iter = 0; iter < 20; ++iter) {
140
- for (let i = 0; i < n; ++i) {
141
- let best = 0, bestDist = Infinity;
142
- for (let c = 0; c < k; ++c) {
143
- let dist = 0;
144
- for (let d = 0; d < dim; ++d) dist += (embeddings[i][d] - centroids[c][d]) ** 2;
145
- if (dist < bestDist) { bestDist = dist; best = c; }
146
- }
147
- labels[i] = best;
148
- }
149
- centroids = Array.from({ length: k }, () => new Array(dim).fill(0));
150
- const counts = new Array(k).fill(0);
151
- for (let i = 0; i < n; ++i) {
152
- counts[labels[i]]++;
153
- for (let d = 0; d < dim; ++d) centroids[labels[i]][d] += embeddings[i][d];
154
- }
155
- for (let c = 0; c < k; ++c) if (counts[c]) for (let d = 0; d < dim; ++d) centroids[c][d] /= counts[c];
156
- }
157
- // UMAP for 2D projection
158
- const umap = new UMAP({ nComponents: 2 });
159
- const proj = umap.fit(embeddings);
160
- // Group lines by cluster
161
- const clustered = Array.from({ length: k }, (_, c) => []);
162
- for (let i = 0; i < n; ++i) clustered[labels[i]].push(lines[i]);
163
- // Generate cluster names using text generation pipeline (async with progress)
164
- const clusterNames = [];
165
- for (let c = 0; c < k; ++c) {
166
- progressBarInner.style.width = `${Math.round(((c) / k) * 100)}%`;
167
- const joined = clustered[c].join("\n");
168
- const messages = [
169
- { role: "system", content: "You are a helpful assistant." },
170
- { role: "user", content: `Given the following texts, provide a short, descriptive name for this group:\n\n${joined}` }
171
- ];
172
- const reasonEnabled = false;
173
- const inputs = tokenizer.apply_chat_template(messages, {
174
- add_generation_prompt: true,
175
- return_dict: true,
176
- enable_thinking: reasonEnabled,
177
- });
178
- const [START_THINKING_TOKEN_ID, END_THINKING_TOKEN_ID] = tokenizer.encode("<think></think>", { add_special_tokens: false });
179
- let state = "answering";
180
- let startTime;
181
- let numTokens = 0;
182
- let tps;
183
- const token_callback_function = (tokens) => {
184
- startTime ??= performance.now();
185
- if (numTokens++ > 0) {
186
- tps = (numTokens / (performance.now() - startTime)) * 1000;
187
- }
188
- switch (Number(tokens[0])) {
189
- case START_THINKING_TOKEN_ID:
190
- state = "thinking";
191
- break;
192
- case END_THINKING_TOKEN_ID:
193
- state = "answering";
194
- break;
195
- }
196
- console.log(state, tokens, tokenizer.decode(tokens));
197
- };
198
- const callback_function = (output) => {
199
- // You can update UI here if desired
200
- console.log({ output, tps, numTokens, state });
201
- };
202
- const streamer = new TextStreamer(tokenizer, {
203
- skip_prompt: true,
204
- skip_special_tokens: true,
205
- callback_function,
206
- token_callback_function,
207
- });
208
- const outputTokens = await model.generate({
209
- ...inputs,
210
- max_new_tokens: 32,
211
- do_sample: false,
212
- streamer,
213
- });
214
- let name = tokenizer.decode(outputTokens[0], { skip_special_tokens: false }).trim();
215
- clusterNames.push(name.length > 0 ? name : `Cluster ${c + 1}`);
216
- }
217
- progressBarInner.style.width = "100%";
218
- setTimeout(() => { progressBar.style.display = "none"; }, 400);
219
- // Plot
220
- const colors = ["red", "blue", "green", "orange", "purple", "cyan", "magenta", "yellow", "brown", "black", "lime", "navy", "teal", "olive", "maroon", "pink", "gray", "gold", "aqua", "indigo"];
221
- const traces = Array.from({ length: k }, (_, c) => ({
222
- x: [], y: [], text: [], mode: "markers", type: "scatter", name: clusterNames[c],
223
- marker: { color: colors[c % colors.length], size: 12, line: { width: 1, color: '#333' } }
224
- }));
225
- for (let i = 0; i < n; ++i) {
226
- traces[labels[i]].x.push(proj[i][0]);
227
- traces[labels[i]].y.push(proj[i][1]);
228
- traces[labels[i]].text.push(lines[i]);
229
- }
230
- Plotly.newPlot("plot-scatter", traces, {
231
- xaxis: { title: "UMAP-1", scaleanchor: "y", scaleratio: 1 },
232
- yaxis: { title: "UMAP-2", scaleanchor: "x", scaleratio: 1 },
233
- width: 1000,
234
- height: 500,
235
- margin: { t: 40, l: 40, r: 10, b: 40 },
236
- title: `K-Means Clustering (k=${k})`
237
- });
238
- // Update textarea: group by cluster, separated by triple newlines
239
- document.getElementById("input").value = clustered.map(g => g.join("\n")).join("\n\n\n");
240
- // Re-run heatmap after updating textarea
241
- document.getElementById("run").onclick();
242
- };
243
- </script>
244
  </body>
245
 
246
  </html>
 
64
  <div id="plot-heatmap" style="width:500px; height:500px;"></div>
65
  </div>
66
  <script src="https://cdn.plot.ly/plotly-2.32.0.min.js"></script>
67
+ <script type="module" src="./main.js"></script>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  </body>
69
 
70
  </html>
main.js CHANGED
@@ -94,6 +94,26 @@ document.getElementById("kmeans-btn").onclick = async () => {
94
  // Group lines by cluster
95
  const clustered = Array.from({ length: k }, (_, c) => []);
96
  for (let i = 0; i < n; ++i) clustered[labels[i]].push(lines[i]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  // Generate cluster names using text generation pipeline (async with progress)
98
  const clusterNames = [];
99
  for (let c = 0; c < k; ++c) {
@@ -148,23 +168,14 @@ document.getElementById("kmeans-btn").onclick = async () => {
148
  let name = tokenizer.decode(outputTokens[0], { skip_special_tokens: false }).trim();
149
  clusterNames.push(name.length > 0 ? name : `Cluster ${c + 1}`);
150
  }
151
- progressBarInner.style.width = "100%";
152
- setTimeout(() => { progressBar.style.display = "none"; }, 400);
153
- // Plot
154
- const colors = ["red", "blue", "green", "orange", "purple", "cyan", "magenta", "yellow", "brown", "black", "lime", "navy", "teal", "olive", "maroon", "pink", "gray", "gold", "aqua", "indigo"];
155
- const traces = Array.from({ length: k }, (_, c) => ({
156
- x: [], y: [], text: [], mode: "markers", type: "scatter", name: clusterNames[c],
157
- marker: { color: colors[c % colors.length], size: 12, line: { width: 1, color: '#333' } }
158
- }));
159
- for (let i = 0; i < n; ++i) {
160
- traces[labels[i]].x.push(proj[i][0]);
161
- traces[labels[i]].y.push(proj[i][1]);
162
- traces[labels[i]].text.push(lines[i]);
163
  }
164
- Plotly.newPlot("plot-scatter", traces, {
165
  xaxis: { title: "UMAP-1", scaleanchor: "y", scaleratio: 1 },
166
  yaxis: { title: "UMAP-2", scaleanchor: "x", scaleratio: 1 },
167
- width: 1000,
168
  height: 500,
169
  margin: { t: 40, l: 40, r: 10, b: 40 },
170
  title: `K-Means Clustering (k=${k})`
 
94
  // Group lines by cluster
95
  const clustered = Array.from({ length: k }, (_, c) => []);
96
  for (let i = 0; i < n; ++i) clustered[labels[i]].push(lines[i]);
97
+ // Plot initial scatter with placeholder names before clusterNames are ready
98
+ const colors = ["red", "blue", "green", "orange", "purple", "cyan", "magenta", "yellow", "brown", "black", "lime", "navy", "teal", "olive", "maroon", "pink", "gray", "gold", "aqua", "indigo"];
99
+ const placeholderNames = Array.from({ length: k }, (_, c) => `Cluster ${c + 1}`);
100
+ let traces = Array.from({ length: k }, (_, c) => ({
101
+ x: [], y: [], text: [], mode: "markers", type: "scatter", name: placeholderNames[c],
102
+ marker: { color: colors[c % colors.length], size: 12, line: { width: 1, color: '#333' } }
103
+ }));
104
+ for (let i = 0; i < n; ++i) {
105
+ traces[labels[i]].x.push(proj[i][0]);
106
+ traces[labels[i]].y.push(proj[i][1]);
107
+ traces[labels[i]].text.push(lines[i]);
108
+ }
109
+ Plotly.newPlot("plot-scatter", traces, {
110
+ xaxis: { title: "UMAP-1", scaleanchor: "y", scaleratio: 1 },
111
+ yaxis: { title: "UMAP-2", scaleanchor: "x", scaleratio: 1 },
112
+ width: 500,
113
+ height: 500,
114
+ margin: { t: 40, l: 40, r: 10, b: 40 },
115
+ title: `K-Means Clustering (k=${k})`
116
+ });
117
  // Generate cluster names using text generation pipeline (async with progress)
118
  const clusterNames = [];
119
  for (let c = 0; c < k; ++c) {
 
168
  let name = tokenizer.decode(outputTokens[0], { skip_special_tokens: false }).trim();
169
  clusterNames.push(name.length > 0 ? name : `Cluster ${c + 1}`);
170
  }
171
+ // After all names are generated, update the trace names and render once
172
+ for (let c = 0; c < k; ++c) {
173
+ traces[c].name = clusterNames[c];
 
 
 
 
 
 
 
 
 
174
  }
175
+ Plotly.react("plot-scatter", traces, {
176
  xaxis: { title: "UMAP-1", scaleanchor: "y", scaleratio: 1 },
177
  yaxis: { title: "UMAP-2", scaleanchor: "x", scaleratio: 1 },
178
+ width: 500,
179
  height: 500,
180
  margin: { t: 40, l: 40, r: 10, b: 40 },
181
  title: `K-Means Clustering (k=${k})`