File size: 7,630 Bytes
ec160fc
f2e1fb8
 
 
 
 
3ebfd79
 
 
 
 
 
 
 
 
 
 
ec160fc
3ebfd79
 
 
 
 
 
ec160fc
3ebfd79
 
 
 
 
 
ec160fc
3ebfd79
 
 
 
 
 
 
 
 
 
ec160fc
3ebfd79
 
 
 
 
 
 
 
 
 
f2e1fb8
 
 
 
 
 
 
4656699
 
ec160fc
f2e1fb8
 
 
 
4656699
f2e1fb8
 
 
 
 
 
 
 
 
 
ec160fc
f2e1fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d2b0c3
f2e1fb8
4656699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2e1fb8
ec160fc
4656699
 
 
 
 
 
 
048e22e
4656699
 
 
 
048e22e
f2e1fb8
4656699
 
3ebfd79
 
 
 
4656699
 
f2e1fb8
 
 
 
 
4656699
f2e1fb8
 
 
4656699
 
3ebfd79
 
 
4656699
 
f2e1fb8
 
 
 
 
4656699
f2e1fb8
 
048e22e
 
f2e1fb8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
// Handles the heatmap event, group similarity logic, and text reordering for cluster visualization
import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js';
import { plotHeatmap } from './plotting.js';

// Instruction string handed as the `task` argument to the embedding helpers
// (getGroupEmbeddings / getLineEmbeddings) used throughout this module.
const task = "Given a textual input sentence, retrieve relevant categories that best describe it.";

/**
 * Cosine similarity of two numeric vectors.
 *
 * Iterates over the indices of `a`; `b` is read at the same positions.
 *
 * @param {ArrayLike<number>} a - First vector.
 * @param {ArrayLike<number>} b - Second vector.
 * @returns {number} dot(a, b) / sqrt(|a|^2 * |b|^2), or 0 when either
 *   squared norm is falsy (zero or NaN).
 */
function cosine(a, b) {
    let dotProduct = 0;
    let sqNormA = 0;
    let sqNormB = 0;
    let idx = 0;
    while (idx < a.length) {
        const x = a[idx];
        const y = b[idx];
        dotProduct += x * y;
        sqNormA += x * x;
        sqNormB += y * y;
        idx += 1;
    }
    // Guard against division by zero (or NaN norms) before normalising.
    if (!sqNormA || !sqNormB) {
        return 0;
    }
    return dotProduct / Math.sqrt(sqNormA * sqNormB);
}

/**
 * Split every group's text into its content lines.
 *
 * A line is kept when it is non-empty and does not start with "##"
 * (the header marker). Lines are not trimmed.
 *
 * @param {string[]} groups - Raw text of each group.
 * @returns {string[][]} One array of content lines per input group.
 */
function getCleanGroups(groups) {
    const isContentLine = (line) => line.length > 0 && !line.startsWith("##");
    return groups.map((groupText) => {
        const kept = [];
        for (const line of groupText.split("\n")) {
            if (isContentLine(line)) {
                kept.push(line);
            }
        }
        return kept;
    });
}

/**
 * Concatenate the lines of all groups into one list and embed them in a
 * single getLineEmbeddings call.
 *
 * @param {string[][]} cleanGroups - Content lines per group.
 * @param {string} task - Task string forwarded to getLineEmbeddings.
 * @returns {Promise<{allLines: string[], allEmbeds: *}>} The flattened lines
 *   and whatever getLineEmbeddings resolved to for them.
 */
async function getAllLinesAndEmbeds(cleanGroups, task) {
    const allLines = cleanGroups.flat();
    return {
        allLines,
        allEmbeds: await getLineEmbeddings(allLines, task),
    };
}

/**
 * For each group, list the positions its lines occupy in the flattened
 * (all-groups) line array.
 *
 * @param {string[][]} cleanGroups - Content lines per group.
 * @returns {number[][]} idxByGroup[g][k] is the global index of line k of group g.
 */
function getIdxByGroup(cleanGroups) {
    let start = 0; // global index of the current group's first line
    return cleanGroups.map(({ length }) => {
        const indices = [];
        for (let k = 0; k < length; k++) {
            indices.push(start + k);
        }
        start += length;
        return indices;
    });
}

/**
 * Render reordered groups back into the "## header" + lines text format.
 *
 * @param {number[]} order - Original group indices, in output order.
 * @param {string[][]} sortedLines - Lines per output position (parallel to `order`).
 * @param {string[]|null|undefined} clusterNames - Header names indexed by original group index.
 * @param {number} n - Expected number of groups; names are only used when
 *   `clusterNames.length === n`, otherwise headers are "Group <position>".
 * @returns {string} Sections joined by two blank lines.
 */
function buildFinalText(order, sortedLines, clusterNames, n) {
    const namesUsable = clusterNames?.length === n;
    const sections = [];
    order.forEach((gIdx, pos) => {
        const title = namesUsable ? clusterNames[gIdx] : `Group ${pos + 1}`;
        const body = sortedLines[pos].join("\n");
        sections.push(`## ${title}\n${body}`);
    });
    return sections.join("\n\n\n");
}

// Element-wise mean of a non-empty list of equal-length vectors.
function meanVector(vectors) {
    return vectors[0].map((_, k) => vectors.reduce((sum, v) => sum + v[k], 0) / vectors.length);
}

// For each group (visited in `groupOrder`), return its lines sorted by descending score of their embedding.
function sortGroupLinesByScore(groupOrder, idxByGroup, allLines, allEmbeds, scoreOf) {
    return groupOrder.map(g =>
        idxByGroup[g]
            .map(i => ({ t: allLines[i], s: scoreOf(allEmbeds[i]) }))
            .sort((a, b) => b.s - a.s)
            .map(o => o.t)
    );
}

/**
 * Handle the "heatmap" UI event.
 *
 * Reads the main text from `#input`, splits it into groups on runs of three
 * or more newlines, embeds the groups and plots their pairwise cosine
 * similarity via plotHeatmap.
 *
 * When `#search-group-input` contains at least one non-blank line, the mean
 * embedding of those lines is used as a reference: groups are ordered by
 * descending similarity to it, a "Search" row/column is prepended to the
 * heatmap, and `#input` is rewritten with the reordered groups. Inside each
 * group the lines are sorted according to `#search-sort-mode`:
 *   - "line":  by the best cosine match against any single search line;
 *   - "group": by cosine against the mean search embedding (the default when
 *              the dropdown is missing or empty).
 * Any other mode leaves `#input` untouched.
 *
 * Progress is reported by setting the width of `#progress-bar-inner`
 * (0% -> 30% -> 60% -> 100%).
 *
 * NOTE(review): this assumes getGroupEmbeddings returns one embedding per
 * entry of `groups`, in order — group indices are used interchangeably for
 * both. Confirm against embedding.js.
 *
 * @returns {Promise<void>}
 */
export async function handleHeatmapEvent() {
    const progressBar = document.getElementById("progress-bar");
    const progressBarInner = document.getElementById("progress-bar-inner");
    progressBar.style.display = "block";
    progressBarInner.style.width = "0%";

    const text = document.getElementById("input").value;
    // The search group comes from its own input (a "##search" section in the main input is not used).
    const searchGroupText = document.getElementById("search-group-input")?.value.trim();
    // `||` on purpose: an empty dropdown value should also fall back to "group".
    const searchSortMode = document.getElementById("search-sort-mode")?.value || "group";

    // Header names, in order of appearance in the main input.
    const clusterNames = text.split(/\n/)
        .map(x => x.trim())
        .filter(x => x && x.startsWith('##'))
        .map(x => x.replace(/^##\s*/, ''));
    // Heatmap label for original group index i; a missing or empty name falls back to "Group <i+1>".
    const labelFor = i => clusterNames[i] || `Group ${i + 1}`;

    const groups = text.split(/\n{3,}/);
    const groupEmbeddings = await getGroupEmbeddings(groups, task);
    const n = groupEmbeddings.length;
    progressBarInner.style.width = "30%";

    // Pairwise similarity through the shared helper, so a zero-norm embedding
    // yields 0 instead of NaN (0/0) in the plotted matrix.
    const sim = groupEmbeddings.map(a => groupEmbeddings.map(b => cosine(a, b)));
    progressBarInner.style.width = "60%";

    // Default: groups in their original order.
    let groupOrder = Array.from({ length: n }, (_, i) => i);
    let searchEmbeds = [];
    let refEmbed = null;
    let simToRef = []; // simToRef[i] = cosine(refEmbed, groupEmbeddings[i])
    if (searchGroupText) {
        const searchLines = searchGroupText.split(/\n/).map(l => l.trim()).filter(l => l);
        if (searchLines.length > 0) {
            searchEmbeds = await getLineEmbeddings(searchLines, task);
        }
        // Only treat the search as active when embeddings actually came back.
        if (searchEmbeds.length > 0) {
            refEmbed = meanVector(searchEmbeds);
            simToRef = groupEmbeddings.map(emb => cosine(refEmbed, emb));
            groupOrder = simToRef
                .map((s, idx) => ({ idx, sim: s }))
                .sort((a, b) => b.sim - a.sim)
                .map(x => x.idx);
        }
    }
    const hasSearch = refEmbed !== null;

    // Group-vs-group block of the heatmap, in display order.
    const groupBlock = groupOrder.map(i => groupOrder.map(j => sim[i][j]));
    const groupLabels = groupOrder.map(labelFor);
    let simOrdered;
    let xLabels;
    if (hasSearch) {
        // Prepend the search group as first row/column: 1 against itself,
        // its similarity to each group elsewhere.
        simOrdered = [
            [1, ...groupOrder.map(i => simToRef[i])],
            ...groupOrder.map((i, row) => [simToRef[i], ...groupBlock[row]]),
        ];
        xLabels = ["Search", ...groupLabels];
    } else {
        simOrdered = groupBlock;
        xLabels = groupLabels;
    }

    // Pick how individual lines are scored against the search group, if at all.
    let scoreOf = null;
    if (hasSearch && searchSortMode === "line") {
        scoreOf = e => Math.max(...searchEmbeds.map(se => cosine(se, e)));
    } else if (hasSearch && searchSortMode === "group") {
        scoreOf = e => cosine(refEmbed, e);
    }
    if (scoreOf) {
        const cleanGroups = getCleanGroups(groups);
        const { allLines, allEmbeds } = await getAllLinesAndEmbeds(cleanGroups, task);
        const idxByGroup = getIdxByGroup(cleanGroups);
        const sortedLines = sortGroupLinesByScore(groupOrder, idxByGroup, allLines, allEmbeds, scoreOf);
        document.getElementById("input").value = buildFinalText(groupOrder, sortedLines, clusterNames, n);
    }

    plotHeatmap(simOrdered, xLabels, xLabels);
    progressBarInner.style.width = "100%";
}