Spaces:
Running
Running
ping98k
Refactor heatmap event handling to update x-axis labels for clarity and ensure heatmap is plotted after processing search group logic.
048e22e | // Handles the heatmap event, group similarity logic, and text reordering for cluster visualization | |
| import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js'; | |
| import { plotHeatmap } from './plotting.js'; | |
// Instruction text handed to the embedding helpers (getGroupEmbeddings / getLineEmbeddings) as their second argument.
| const task = "Given a textual input sentence, retrieve relevant categories that best describe it."; | |
/**
 * Cosine similarity between two numeric vectors.
 *
 * Walks the indices of `a` and reads `b` at the same positions.
 *
 * @param {ArrayLike<number>} a - First vector.
 * @param {ArrayLike<number>} b - Second vector.
 * @returns {number} The dot product divided by the square root of the product
 *   of the squared norms, or 0 when either squared norm is falsy (e.g. 0).
 */
function cosine(a, b) {
  let dotProduct = 0;
  let sqNormA = 0;
  let sqNormB = 0;

  let idx = 0;
  while (idx < a.length) {
    const x = a[idx];
    const y = b[idx];
    dotProduct += x * y;
    sqNormA += x * x;
    sqNormB += y * y;
    idx += 1;
  }

  // A zero-length vector has no direction; report "no similarity" instead of dividing by zero.
  if (!sqNormA || !sqNormB) {
    return 0;
  }
  return dotProduct / Math.sqrt(sqNormA * sqNormB);
}
/**
 * Turn each group's raw text into its list of content lines.
 *
 * A line is dropped when it is empty or when it begins with "##" (a group
 * header). Lines are not trimmed before the check.
 *
 * @param {string[]} groups - Raw text of each group.
 * @returns {string[][]} One array of remaining lines per group, in input order.
 */
function getCleanGroups(groups) {
  const isContentLine = (line) => line && !line.startsWith("##");
  return groups.map((groupText) => {
    const rawLines = groupText.split("\n");
    return rawLines.filter(isContentLine);
  });
}
/**
 * Concatenate the lines of every group into one list and embed them.
 *
 * @param {string[][]} cleanGroups - Lines of each group.
 * @param {string} task - Passed through unchanged to getLineEmbeddings.
 * @returns {Promise<{allLines: string[], allEmbeds: *}>} The flattened lines
 *   together with whatever getLineEmbeddings resolves to for them.
 */
async function getAllLinesAndEmbeds(cleanGroups, task) {
  const result = { allLines: cleanGroups.flat() };
  result.allEmbeds = await getLineEmbeddings(result.allLines, task);
  return result;
}
/**
 * For every group, list the positions its lines occupy in the flattened
 * (all-groups-concatenated) line array.
 *
 * @param {string[][]} cleanGroups - Lines of each group.
 * @returns {number[][]} Per group, consecutive global indices; the first group
 *   starts at 0 and each later group continues where the previous one ended.
 */
function getIdxByGroup(cleanGroups) {
  const idxByGroup = [];
  let offset = 0;
  for (const group of cleanGroups) {
    const size = group.length;
    const indices = [];
    for (let local = 0; local < size; local++) {
      indices.push(offset + local);
    }
    idxByGroup.push(indices);
    offset += size;
  }
  return idxByGroup;
}
/**
 * Assemble the output text for the reordered groups.
 *
 * Each group becomes a "## <header>" line followed by its lines; groups are
 * separated by two blank lines.
 *
 * @param {number[]} order - Original group index for each output position.
 * @param {string[][]} sortedLines - Lines per output position (parallel to `order`).
 * @param {string[]} [clusterNames] - Header names indexed by original group index.
 * @param {number} n - Expected number of groups; names are only used when
 *   `clusterNames.length` equals this, otherwise headers are "Group <position>".
 * @returns {string} The combined text.
 */
function buildFinalText(order, sortedLines, clusterNames, n) {
  const namesCoverAllGroups = clusterNames?.length === n;
  const sections = [];
  order.forEach((gIdx, position) => {
    let header;
    if (namesCoverAllGroups) {
      header = clusterNames[gIdx];
    } else {
      header = `Group ${position + 1}`;
    }
    const body = sortedLines[position].join("\n");
    sections.push(`## ${header}\n${body}`);
  });
  return sections.join("\n\n\n");
}
/**
 * Handle the heatmap event: embed the groups found in the main input, plot
 * their pairwise cosine similarity, and (when a search group is supplied)
 * reorder the groups and their lines by similarity to it.
 *
 * DOM elements read: `#input` (main text, groups separated by two or more
 * blank lines), `#search-group-input` (optional search lines) and
 * `#search-sort-mode` (optional, "line" or "group", default "group").
 * DOM elements written: `#progress-bar` / `#progress-bar-inner` (progress) and
 * `#input` (replaced with the reordered text when a search group is used and
 * the sort mode is "line" or "group").
 *
 * With a search group, the heatmap gets an extra first row/column labelled
 * "Search" and the remaining groups are ordered by descending similarity to
 * the mean of the search line embeddings.
 *
 * @returns {Promise<void>} Resolves after the heatmap has been plotted.
 */
export async function handleHeatmapEvent() {
  const progressBar = document.getElementById("progress-bar");
  const progressBarInner = document.getElementById("progress-bar-inner");
  progressBar.style.display = "block";
  progressBarInner.style.width = "0%";

  const inputEl = document.getElementById("input");
  const text = inputEl.value;
  // The search group comes from its dedicated input, never from a ##search header in the main input.
  const searchGroupText = document.getElementById("search-group-input")?.value.trim();
  // Sort mode from the dropdown; falls back to "group" when the element is missing or empty.
  const searchSortMode = document.getElementById("search-sort-mode")?.value || "group";
  const sortByMaxSearchLine = searchSortMode === "line";
  const sortBySearchGroup = searchSortMode === "group";

  // Cluster names are the "##" header lines of the main input, with the marker stripped.
  const clusterNames = text
    .split(/\n/)
    .map((line) => line.trim())
    .filter((line) => line.startsWith("##"))
    .map((line) => line.replace(/^##\s*/, ""));
  const groups = text.split(/\n{3,}/);

  // NOTE(review): the code below assumes getGroupEmbeddings returns one embedding
  // per entry of `groups`, in the same order — confirm against embedding.js.
  const groupEmbeddings = await getGroupEmbeddings(groups, task);
  const n = groupEmbeddings.length;
  progressBarInner.style.width = "30%";

  // Pairwise similarity via the shared helper so zero-norm embeddings give 0 rather than NaN.
  const sim = groupEmbeddings.map((a) => groupEmbeddings.map((b) => cosine(a, b)));
  progressBarInner.style.width = "60%";

  // Default: groups in their original order.
  let order = Array.from({ length: n }, (_, i) => i);
  let searchEmbeds = [];
  let refEmbed = null; // mean embedding of the search lines; null when no search group is in use
  let searchSims = []; // similarity of refEmbed to each group, parallel to `order`

  if (searchGroupText) {
    const searchLines = searchGroupText
      .split(/\n/)
      .map((line) => line.trim())
      .filter((line) => line);
    if (searchLines.length > 0) {
      searchEmbeds = await getLineEmbeddings(searchLines, task);
      refEmbed = searchEmbeds[0].map(
        (_, dim) => searchEmbeds.reduce((sum, e) => sum + e[dim], 0) / searchEmbeds.length
      );
      const ranked = groupEmbeddings
        .map((emb, idx) => ({ idx, sim: cosine(refEmbed, emb) }))
        .sort((a, b) => b.sim - a.sim);
      order = ranked.map((entry) => entry.idx);
      searchSims = ranked.map((entry) => entry.sim);
    }
  }
  const hasSearch = refEmbed !== null;

  // Build the (re)ordered matrix and axis labels for the heatmap.
  const labelFor = (i) => clusterNames[i] || `Group ${i + 1}`;
  let simOrdered;
  let xLabels;
  if (hasSearch) {
    // The search group is the first row/column: similarity 1 to itself, searchSims to the groups.
    simOrdered = [
      [1, ...searchSims],
      ...order.map((i, pos) => [searchSims[pos], ...order.map((j) => sim[i][j])]),
    ];
    xLabels = ["Search", ...order.map(labelFor)];
  } else {
    simOrdered = order.map((i) => order.map((j) => sim[i][j]));
    xLabels = order.map(labelFor);
  }

  // With a search group, rewrite the main input: groups follow `order`, and the
  // lines inside each group are sorted by descending similarity to the search.
  if (hasSearch && (sortByMaxSearchLine || sortBySearchGroup)) {
    // "line" mode scores by the best-matching search line; "group" mode by the mean search embedding.
    const scoreLine = sortByMaxSearchLine
      ? (e) => Math.max(...searchEmbeds.map((se) => cosine(se, e)))
      : (e) => cosine(refEmbed, e);

    const cleanGroups = getCleanGroups(groups);
    const { allLines, allEmbeds } = await getAllLinesAndEmbeds(cleanGroups, task);
    const idxByGroup = getIdxByGroup(cleanGroups);
    const sortedLines = order.map((gIdx) =>
      idxByGroup[gIdx]
        .map((i) => ({ t: allLines[i], s: scoreLine(allEmbeds[i]) }))
        .sort((a, b) => b.s - a.s)
        .map((o) => o.t)
    );
    inputEl.value = buildFinalText(order, sortedLines, clusterNames, n);
  }

  plotHeatmap(simOrdered, xLabels, xLabels);
  progressBarInner.style.width = "100%";
}