// HyperView/frontend/src/lib/labelLegend.ts
// morozovdd's picture
// feat: add HyperView app for space
// 23680f2
import type { EmbeddingsData } from "@/types";
import {
FALLBACK_LABEL_COLOR,
MISSING_LABEL_COLOR,
createLabelColorMap,
normalizeLabel,
} from "@/lib/labelColors";
/**
 * Upper bound on the number of labels that should each get their own color.
 */
// Past ~20-30 categories, color-as-encoding becomes unreliable for most users.
// We choose a conservative upper bound and fall back to a single color.
// NOTE(review): this module only exports the limit; the comparison against it
// is presumably done by callers (e.g. to derive `distinctColoringDisabled`) — confirm.
export const MAX_DISTINCT_LABEL_COLORS = 20;
/**
 * Default alpha (0-1) written into the colors of labels that do not match the
 * active label filter, so the selected label stands out.
 */
export const UNSELECTED_LABEL_ALPHA = 0.12;
/**
 * Label/color encoding for a scatter plot, as produced by `buildLabelsInfo`.
 */
export interface ScatterLabelsInfo {
  /** Distinct labels; a label's position in this array is its category index. */
  uniqueLabels: string[];
  /** One entry per embedding label: the index of that label in `uniqueLabels`. */
  categories: Uint16Array;
  /** Color strings parallel to `uniqueLabels` (`palette[i]` colors `uniqueLabels[i]`). */
  palette: string[];
}
/** Matches shorthand CSS hex colors: `#rgb` or `#rgba`. */
const SHORT_HEX_COLOR = /^#[0-9a-f]{3,4}$/i;

/**
 * Clamps `v` to the closed interval [0, 1].
 *
 * NaN compares false against both bounds and is therefore returned unchanged;
 * callers that cannot tolerate NaN must check for it themselves.
 */
function clamp01(v: number): number {
  if (v < 0) return 0;
  if (v > 1) return 1;
  return v;
}

/**
 * Returns `color` with its alpha channel set to `alpha`, as `#rrggbbaa`.
 *
 * Supported inputs are `#rrggbb`, `#rrggbbaa` (existing alpha is replaced) and
 * the shorthand forms `#rgb` / `#rgba` (each digit is doubled before the new
 * alpha is appended). Anything else — non-hex notations, unexpected lengths —
 * is returned unchanged.
 *
 * @param color - Color string to modify.
 * @param alpha - Opacity; values outside [0, 1] are clamped. A NaN alpha leaves
 *   `color` unchanged.
 * @returns The color with the new alpha, or the input when it cannot be handled.
 */
function applyAlphaToHex(color: string, alpha: number): string {
  if (!color.startsWith("#")) return color;
  // NaN would otherwise be serialised as the literal text "NaN", yielding an
  // invalid color such as "#ff0000NaN".
  if (Number.isNaN(alpha)) return color;

  const alphaHex = Math.round(clamp01(alpha) * 255)
    .toString(16)
    .padStart(2, "0");

  if (SHORT_HEX_COLOR.test(color)) {
    // Expand #rgb / #rgba to #rrggbb; any shorthand alpha digit is dropped,
    // mirroring how the 9-character form below discards its old alpha.
    const r = color.charAt(1);
    const g = color.charAt(2);
    const b = color.charAt(3);
    return `#${r}${r}${g}${g}${b}${b}${alphaHex}`;
  }
  if (color.length === 7) {
    return `${color}${alphaHex}`;
  }
  if (color.length === 9) {
    return `${color.slice(0, 7)}${alphaHex}`;
  }
  return color;
}
/**
 * Fades every palette entry whose label is not the selected one.
 *
 * `palette` and `labels` are parallel arrays. When no filter is set, or the
 * filter names a label that is not present in `labels`, the input palette is
 * returned as-is (same array instance). Otherwise a new array is returned in
 * which the selected label keeps its color and all others get `unselectedAlpha`.
 */
function applyLabelFilterToPalette(params: {
  palette: string[];
  labels: string[];
  labelFilter: string | null;
  unselectedAlpha: number;
}): string[] {
  const { palette, labels, labelFilter, unselectedAlpha } = params;

  const filterApplies = !!labelFilter && labels.includes(labelFilter);
  if (!filterApplies) return palette;

  const fade = (color: string): string => applyAlphaToHex(color, unselectedAlpha);
  return palette.map((color, position) =>
    labels[position] !== labelFilter ? fade(color) : color
  );
}
/**
 * Tallies how many times each normalized label occurs in `embeddings.labels`.
 *
 * @param embeddings - Embeddings payload, or `null` when nothing is loaded.
 * @returns Map from normalized label to occurrence count; empty when
 *   `embeddings` is `null`.
 */
export function buildLabelCounts(embeddings: EmbeddingsData | null): Map<string, number> {
  const tally = new Map<string, number>();

  if (embeddings) {
    for (const rawLabel of embeddings.labels) {
      const key = normalizeLabel(rawLabel);
      const previous = tally.get(key);
      tally.set(key, previous === undefined ? 1 : previous + 1);
    }
  }

  return tally;
}
/**
 * Counts the distinct labels in `labelCounts`, not counting the `"undefined"`
 * bucket if it is present.
 */
export function getDistinctLabelCount(labelCounts: Map<string, number>): number {
  const missingBucket = labelCounts.has("undefined") ? 1 : 0;
  return labelCounts.size - missingBucket;
}
/**
 * Builds the ordered list of all labels known for a dataset.
 *
 * Ordering of the result:
 *   1. dataset labels, de-duplicated, in first-occurrence order;
 *   2. labels that occur only in `embeddingsLabels`, sorted with `localeCompare`;
 *   3. `"undefined"` last, if any input label normalized to it.
 *
 * @param datasetLabels - Labels declared by the dataset.
 * @param embeddingsLabels - Labels attached to the embeddings, or `null`.
 * @returns The de-duplicated label list in the order described above.
 */
export function buildLabelUniverse(
  datasetLabels: string[],
  embeddingsLabels: (string | null)[] | null
): string[] {
  const MISSING = "undefined";
  let sawMissing = false;

  // A Set iterates in insertion order, which gives first-occurrence ordering.
  const primary = new Set<string>();
  for (const raw of datasetLabels) {
    const label = normalizeLabel(raw);
    if (label === MISSING) {
      sawMissing = true;
    } else {
      primary.add(label);
    }
  }

  const embeddingOnly = new Set<string>();
  for (const raw of embeddingsLabels ?? []) {
    const label = normalizeLabel(raw);
    if (label === MISSING) {
      sawMissing = true;
    } else if (!primary.has(label)) {
      embeddingOnly.add(label);
    }
  }

  const sortedExtras = Array.from(embeddingOnly).sort((a, b) => a.localeCompare(b));
  const universe = Array.from(primary).concat(sortedExtras);
  if (sawMissing) universe.push(MISSING);
  return universe;
}
/**
 * Builds the per-point category indices and the matching color palette for
 * the scatter plot.
 *
 * @param params.datasetLabels - Labels declared by the dataset; they come first
 *   in the resulting label order (see `buildLabelUniverse`).
 * @param params.embeddings - Embeddings payload; `null` yields `null`.
 * @param params.distinctColoringDisabled - When true, every label gets
 *   `FALLBACK_LABEL_COLOR` (and `"undefined"` gets `MISSING_LABEL_COLOR`)
 *   instead of a distinct color.
 * @param params.labelFilter - Selected label; all other labels are faded.
 *   Ignored when null/empty or not part of the label universe.
 * @param params.unselectedAlpha - Alpha used for faded labels.
 * @returns Labels, per-point category indices and palette, or `null` when
 *   there are no embeddings.
 */
export function buildLabelsInfo(params: {
  datasetLabels: string[];
  embeddings: EmbeddingsData | null;
  distinctColoringDisabled: boolean;
  labelFilter?: string | null;
  unselectedAlpha?: number;
}): ScatterLabelsInfo | null {
  const {
    datasetLabels,
    embeddings,
    distinctColoringDisabled,
    labelFilter = null,
    unselectedAlpha = UNSELECTED_LABEL_ALPHA,
  } = params;
  if (!embeddings) return null;

  const universe = buildLabelUniverse(datasetLabels, embeddings.labels);

  // Guard: hyper-scatter categories are Uint16.
  if (universe.length > 65535) {
    console.warn(
      `Too many labels (${universe.length}) for uint16 categories; collapsing to a single color.`
    );
    return {
      uniqueLabels: ["undefined"],
      // A fresh Uint16Array is zero-filled, so every point maps to category 0.
      categories: new Uint16Array(embeddings.labels.length),
      palette: [FALLBACK_LABEL_COLOR],
    };
  }

  // Labels are arbitrary user data, so use a Map rather than a plain object:
  // on an object, a label such as "__proto__" cannot be stored as a key and
  // looking it up returns Object.prototype, which would be coerced to
  // category 0 when written into the Uint16Array.
  const labelToCategory = new Map<string, number>();
  for (let i = 0; i < universe.length; i++) {
    labelToCategory.set(universe[i], i);
  }

  const undefinedIndex = labelToCategory.get("undefined") ?? 0;
  const categories = new Uint16Array(embeddings.labels.length);
  for (let i = 0; i < embeddings.labels.length; i++) {
    const key = normalizeLabel(embeddings.labels[i]);
    categories[i] = labelToCategory.get(key) ?? undefinedIndex;
  }

  let palette: string[];
  if (distinctColoringDisabled) {
    palette = universe.map((l) => (l === "undefined" ? MISSING_LABEL_COLOR : FALLBACK_LABEL_COLOR));
  } else {
    const colors = createLabelColorMap(universe);
    palette = universe.map((l) => {
      // NOTE(review): assumes createLabelColorMap returns a plain object; the
      // string check keeps inherited members (reachable via labels such as
      // "__proto__") from leaking into the palette — confirm.
      const color = colors[l];
      return typeof color === "string" ? color : FALLBACK_LABEL_COLOR;
    });
  }

  const filteredPalette = applyLabelFilterToPalette({
    palette,
    labels: universe,
    labelFilter,
    unselectedAlpha,
  });

  return { uniqueLabels: universe, categories, palette: filteredPalette };
}
/**
 * Builds a label -> color lookup, e.g. for rendering a legend.
 *
 * When `labelsInfo` is available its palette is used verbatim (it already has
 * the label filter applied by `buildLabelsInfo`). Otherwise colors are derived
 * from `labelUniverse`, and unselected labels are faded in the same way
 * `buildLabelsInfo` does it — in both the distinct-color and the single-color
 * mode.
 *
 * @param params.labelsInfo - Precomputed scatter label info, or `null`.
 * @param params.labelUniverse - Labels to color when `labelsInfo` is `null`.
 * @param params.distinctColoringDisabled - When true, all labels share
 *   `FALLBACK_LABEL_COLOR` (and `"undefined"` gets `MISSING_LABEL_COLOR`).
 * @param params.labelFilter - Selected label; ignored when null/empty or not
 *   part of `labelUniverse`.
 * @param params.unselectedAlpha - Alpha used for faded labels.
 * @returns Plain object mapping each label to its color string.
 */
export function buildLabelColorMap(params: {
  labelsInfo: ScatterLabelsInfo | null;
  labelUniverse: string[];
  distinctColoringDisabled: boolean;
  labelFilter?: string | null;
  unselectedAlpha?: number;
}): Record<string, string> {
  const {
    labelsInfo,
    labelUniverse,
    distinctColoringDisabled,
    labelFilter = null,
    unselectedAlpha = UNSELECTED_LABEL_ALPHA,
  } = params;

  const map: Record<string, string> = {};

  if (labelsInfo) {
    for (let i = 0; i < labelsInfo.uniqueLabels.length; i++) {
      map[labelsInfo.uniqueLabels[i]] = labelsInfo.palette[i] ?? FALLBACK_LABEL_COLOR;
    }
    return map;
  }

  if (labelUniverse.length === 0) return map;

  let basePalette: string[];
  if (distinctColoringDisabled) {
    basePalette = labelUniverse.map((label) =>
      label === "undefined" ? MISSING_LABEL_COLOR : FALLBACK_LABEL_COLOR
    );
  } else {
    const colors = createLabelColorMap(labelUniverse);
    basePalette = labelUniverse.map((label) => colors[label] ?? FALLBACK_LABEL_COLOR);
  }

  // Apply the filter for both coloring modes. Previously the single-color
  // branch returned before this step, so unselected labels were not faded
  // here even though buildLabelsInfo fades them in that mode.
  const palette = applyLabelFilterToPalette({
    palette: basePalette,
    labels: labelUniverse,
    labelFilter,
    unselectedAlpha,
  });

  for (let i = 0; i < labelUniverse.length; i++) {
    map[labelUniverse[i]] = palette[i] ?? FALLBACK_LABEL_COLOR;
  }
  return map;
}
/**
 * Produces the ordered list of labels to show in the legend.
 *
 * Candidates come from `labelUniverse`, or from the keys of `labelCounts`
 * when the universe is empty. They are filtered by a case-insensitive
 * substring match on the trimmed `query`, then sorted:
 *   - `"undefined"` always goes last;
 *   - if any counts exist, higher counts come first;
 *   - remaining ties are broken with `localeCompare`.
 *
 * The input arrays are not mutated.
 */
export function buildLegendLabels(params: {
  labelUniverse: string[];
  labelCounts: Map<string, number>;
  query: string;
}): string[] {
  const MISSING = "undefined";
  const { labelUniverse, labelCounts, query } = params;

  // Always work on a copy so the in-place sort below cannot touch caller data.
  const candidates =
    labelUniverse.length === 0 ? Array.from(labelCounts.keys()) : labelUniverse.slice();

  const needle = query.trim().toLowerCase();
  const visible =
    needle === ""
      ? candidates
      : candidates.filter((label) => label.toLowerCase().includes(needle));

  const rankByCount = labelCounts.size > 0;
  const countOf = (label: string): number => labelCounts.get(label) ?? 0;
  const missingRank = (label: string): number => (label === MISSING ? 1 : 0);

  visible.sort((left, right) => {
    const byMissing = missingRank(left) - missingRank(right);
    if (byMissing !== 0) return byMissing;

    if (rankByCount) {
      const leftCount = countOf(left);
      const rightCount = countOf(right);
      if (rightCount !== leftCount) return rightCount - leftCount;
    }

    return left.localeCompare(right);
  });

  return visible;
}