// Handles the heatmap event, group similarity logic, and text reordering for cluster visualization
import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js';
import { plotHeatmap } from './plotting.js';
// Instruction string passed as the `task` argument to getGroupEmbeddings and getLineEmbeddings in this module.
const task = "Given a textual input sentence, retrieve relevant categories that best describe it.";
/**
 * Cosine similarity of two numeric vectors.
 *
 * Iterates over the indices of `a`; `b` is read at the same positions.
 *
 * @param {ArrayLike<number>} a - First vector (its length drives the loop).
 * @param {ArrayLike<number>} b - Second vector.
 * @returns {number} dot(a, b) / sqrt(|a|^2 * |b|^2), or 0 when either squared norm is falsy
 *   (zero or NaN), which also covers empty input.
 */
function cosine(a, b) {
  const len = a.length;
  let dot = 0;
  let sqNormA = 0;
  let sqNormB = 0;
  let i = 0;
  while (i < len) {
    const x = a[i];
    const y = b[i];
    dot += x * y;
    sqNormA += x * x;
    sqNormB += y * y;
    i += 1;
  }
  // Guard against division by zero (and NaN norms) before normalising.
  if (!sqNormA || !sqNormB) {
    return 0;
  }
  return dot / Math.sqrt(sqNormA * sqNormB);
}
/**
 * Splits every group's text into its content lines.
 *
 * A line is kept when it is non-empty and does not begin with "##"
 * (header lines). Lines are not trimmed before the check.
 *
 * @param {string[]} groups - Raw text of each group.
 * @returns {string[][]} One array of content lines per group, in the same order.
 */
function getCleanGroups(groups) {
  const isContentLine = (line) => line.length > 0 && !line.startsWith("##");
  return groups.map((groupText) => {
    const rawLines = groupText.split("\n");
    return rawLines.filter(isContentLine);
  });
}
/**
 * Concatenates the lines of all groups into one list and requests their embeddings.
 *
 * @param {string[][]} cleanGroups - Content lines per group.
 * @param {string} task - Passed through unchanged to getLineEmbeddings.
 * @returns {Promise<{allLines: string[], allEmbeds: *}>} The flattened lines together with
 *   whatever getLineEmbeddings resolved to for them.
 */
async function getAllLinesAndEmbeds(cleanGroups, task) {
  // One level of flattening: group boundaries are dropped, line order is kept.
  const lines = cleanGroups.flat(1);
  const embeddings = await getLineEmbeddings(lines, task);
  return { allLines: lines, allEmbeds: embeddings };
}
/**
 * For each group, lists the positions its lines occupy in the flattened
 * (all-groups-concatenated) line array.
 *
 * @param {Array<Array<*>>} cleanGroups - Content lines per group.
 * @returns {number[][]} idxByGroup[g][k] is the global index of line k of group g.
 */
function getIdxByGroup(cleanGroups) {
  const result = [];
  let offset = 0; // global index of the current group's first line
  for (const group of cleanGroups) {
    const size = group.length;
    const indices = [];
    for (let k = 0; k < size; k += 1) {
      indices.push(offset + k);
    }
    result.push(indices);
    offset += size;
  }
  return result;
}
/**
 * Renders reordered groups back into the "## header" + lines text format.
 *
 * @param {number[]} order - Original group index for each output position.
 * @param {string[][]} sortedLines - Lines for each output position (parallel to `order`).
 * @param {string[]|undefined|null} clusterNames - Header names indexed by original group index.
 * @param {number} n - Expected number of groups; names are only used when
 *   `clusterNames.length === n`, otherwise positions are labelled "Group 1", "Group 2", ...
 * @returns {string} The group sections joined by two blank lines.
 */
function buildFinalText(order, sortedLines, clusterNames, n) {
  const namesUsable = clusterNames?.length === n;
  const sections = [];
  order.forEach((groupIndex, position) => {
    let title;
    if (namesUsable) {
      title = clusterNames[groupIndex];
    } else {
      title = `Group ${position + 1}`;
    }
    const body = sortedLines[position].join("\n");
    sections.push(`## ${title}\n${body}`);
  });
  return sections.join("\n\n\n");
}
/**
 * Rebuilds the group-similarity heatmap from the text in the "input" element and,
 * when a search group is supplied, rewrites that input with groups and lines
 * reordered by similarity to the search group.
 *
 * Flow:
 *  1. Read the main text, the optional search-group text and the sort mode from the DOM.
 *  2. Split the text into groups (separated by two or more blank lines) and collect
 *     every "##" line as a cluster name.
 *  3. Embed the groups and build the pairwise cosine-similarity matrix.
 *  4. If search text is present: embed its lines, average them into a reference
 *     embedding, rank groups by similarity to it (most similar first) and add a
 *     "Search" row/column at the front of the heatmap.
 *  5. If search text is present and the mode is "line" or "group": re-sort the lines
 *     inside every group and write the rebuilt text back to the input.
 *  6. Plot the heatmap and set the progress bar to 100%.
 *
 * Similarities are computed with the module's `cosine` helper, so a zero-norm
 * embedding yields 0 rather than NaN in the matrix.
 *
 * @returns {Promise<void>}
 * @throws {Error} When search text is given but no embeddings come back for it.
 */
export async function handleHeatmapEvent() {
  const progressBar = document.getElementById("progress-bar");
  const progressBarInner = document.getElementById("progress-bar-inner");
  progressBar.style.display = "block";
  progressBarInner.style.width = "0%";

  const inputEl = document.getElementById("input");
  const text = inputEl.value;
  // The search group comes from its dedicated input (a "##search" section in the main input is not used).
  const searchGroupText = document.getElementById("search-group-input")?.value.trim();
  // Sort mode from the dropdown: "line" or "group"; falls back to "group" when missing/empty.
  const searchSortMode = document.getElementById("search-sort-mode")?.value || "group";
  const sortByLine = searchSortMode === "line";
  const sortByGroup = searchSortMode === "group";

  // Every trimmed line starting with "##" is a cluster name (marker and following spaces removed).
  const clusterNames = text
    .split(/\n/)
    .map(x => x.trim())
    .filter(x => x.startsWith('##'))
    .map(x => x.replace(/^##\s*/, ''));
  const groups = text.split(/\n{3,}/);
  // Heatmap label for an original group index; falls back to a 1-based generic label.
  const labelFor = i => clusterNames[i] || `Group ${i + 1}`;

  const groupEmbeddings = await getGroupEmbeddings(groups, task);
  const n = groupEmbeddings.length;
  progressBarInner.style.width = "30%";

  // Pairwise group similarity, indexed by original group index.
  const sim = groupEmbeddings.map(a => groupEmbeddings.map(b => cosine(a, b)));
  progressBarInner.style.width = "60%";

  const searchLines = searchGroupText
    ? searchGroupText.split(/\n/).map(l => l.trim()).filter(l => l)
    : [];

  // Original group indices in display order; identity order unless a search re-ranks them.
  let groupOrder = Array.from({ length: n }, (_, i) => i);
  let searchEmbeds = [];
  let refEmbed = null;
  let simOrdered;
  let xLabels;

  if (searchLines.length > 0) {
    searchEmbeds = await getLineEmbeddings(searchLines, task);
    if (!searchEmbeds?.length) {
      throw new Error("No embeddings were returned for the search group.");
    }
    // Reference embedding = component-wise mean of the search-line embeddings.
    refEmbed = searchEmbeds[0].map(
      (_, d) => searchEmbeds.reduce((sum, e) => sum + e[d], 0) / searchEmbeds.length
    );
    // Rank groups once and reuse the scores for both the ordering and the heatmap.
    const ranked = groupEmbeddings
      .map((emb, idx) => ({ idx, sim: cosine(refEmbed, emb) }))
      .sort((a, b) => b.sim - a.sim);
    groupOrder = ranked.map(r => r.idx);
    // The search group is not part of groupEmbeddings, so it gets a synthetic first
    // row/column: 1 against itself, its similarity to each group elsewhere.
    simOrdered = [
      [1, ...ranked.map(r => r.sim)],
      ...ranked.map(r => [r.sim, ...groupOrder.map(j => sim[r.idx][j])]),
    ];
    xLabels = ["Search", ...groupOrder.map(i => labelFor(i))];
  } else {
    simOrdered = groupOrder.map(i => groupOrder.map(j => sim[i][j]));
    xLabels = groupOrder.map(i => labelFor(i));
  }

  // With a search group, re-sort the lines inside each group and rewrite the input.
  if (refEmbed && (sortByLine || sortByGroup)) {
    const cleanGroups = getCleanGroups(groups);
    const { allLines, allEmbeds } = await getAllLinesAndEmbeds(cleanGroups, task);
    const idxByGroup = getIdxByGroup(cleanGroups);
    // "line": best match against any single search line; "group": match against the mean embedding.
    const score = sortByLine
      ? e => Math.max(...searchEmbeds.map(se => cosine(se, e)))
      : e => cosine(refEmbed, e);
    const sortedLines = groupOrder.map(gIdx =>
      idxByGroup[gIdx]
        .map(i => ({ t: allLines[i], s: score(allEmbeds[i]) }))
        .sort((a, b) => b.s - a.s)
        .map(o => o.t)
    );
    inputEl.value = buildFinalText(groupOrder, sortedLines, clusterNames, n);
  }

  plotHeatmap(simOrdered, xLabels, xLabels);
  progressBarInner.style.width = "100%";
}
|