// HyperView/frontend/src/components/useHyperScatter.ts
// Committed by morozovdd in 23680f2: "feat: add HyperView app for space"
import type React from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { EmbeddingsData } from "@/types";
import type { ScatterLabelsInfo } from "@/lib/labelLegend";
import type { Dataset, GeometryMode, Modifiers, Renderer } from "hyper-scatter";
// Shape of the "hyper-scatter" module namespace; the module itself is loaded lazily
// with `await import("hyper-scatter")` when the renderer is initialized.
type HyperScatterModule = typeof import("hyper-scatter");
// Maximum number of vertices in a lasso polygon sent to beginLassoSelection. Longer
// polygons are reduced to an evenly spaced subset by capInterleavedXY to keep the
// request payload and backend selection work bounded.
const MAX_LASSO_VERTS = 512;
/**
 * Probe whether this environment can create a WebGL2 context.
 *
 * Returns false when there is no `document` (e.g. server-side rendering) or when
 * context creation fails or throws.
 *
 * The probe context is released immediately after the check. Browsers typically cap
 * the number of live WebGL contexts and drop the oldest one when the cap is exceeded.
 * Leaking a probe context on every renderer (re)initialization could therefore cost
 * the real scatter renderer its context.
 */
function supportsWebGL2(): boolean {
  if (typeof document === "undefined") return false;
  try {
    const canvas = document.createElement("canvas");
    const gl = canvas.getContext("webgl2");
    if (!gl) return false;
    try {
      gl.getExtension("WEBGL_lose_context")?.loseContext();
    } catch {
      // Releasing the probe is best-effort; support has already been established.
    }
    return true;
  } catch {
    return false;
  }
}
/**
 * Downsample an interleaved [x0, y0, x1, y1, ...] polyline to at most `maxVerts` vertices.
 *
 * If the input already fits, it is returned as a plain array copy, unchanged. Otherwise
 * `maxVerts` vertices are picked at evenly spaced source indices, starting at vertex 0.
 */
function capInterleavedXY(points: ArrayLike<number>, maxVerts: number): number[] {
  const vertexCount = Math.floor(points.length / 2);
  if (vertexCount <= maxVerts) {
    return Array.from(points);
  }

  const sampled = new Array<number>(maxVerts * 2);
  for (let v = 0, dst = 0; v < maxVerts; v++, dst += 2) {
    // Offset of the x coordinate of the source vertex this output slot maps to.
    const base = Math.floor((v * vertexCount) / maxVerts) * 2;
    sampled[dst] = points[base];
    sampled[dst + 1] = points[base + 1];
  }
  return sampled;
}
/** Inputs for {@link useHyperScatter}: data to plot, shared selection/hover state, and store callbacks. */
interface UseHyperScatterArgs {
  /** Embeddings to plot (ids, coords, geometry, layout key). No renderer is created while this is null. */
  embeddings: EmbeddingsData | null;
  /** Category assignment and palette used to color points. No renderer is created while this is null. */
  labelsInfo: ScatterLabelsInfo | null;

  /** IDs selected in the shared store; mirrored into the renderer's selection highlight. */
  selectedIds: Set<string>;
  /** ID hovered in the shared store, or null when nothing is hovered. */
  hoveredId: string | null;

  /** Replace the shared selection. The scatter plot always passes `"scatter"` as the source. */
  setSelectedIds: (ids: Set<string>, source?: "scatter" | "grid") => void;
  /** Publish the hovered point's ID, or null when the pointer is over no point. */
  setHoveredId: (id: string | null) => void;
  /**
   * Start a lasso selection. `polygon` is a flat [x0, y0, x1, y1, ...] list in the data
   * coordinate system returned by /api/embeddings; `layoutKey` identifies that layout.
   */
  beginLassoSelection: (query: { layoutKey: string; polygon: number[] }) => void;

  /** When false, hover highlighting is suppressed. `useHyperScatter` defaults it to true. */
  hoverEnabled?: boolean;
}
/** Map a DOM/React event's modifier-key flags onto hyper-scatter's `Modifiers` shape. */
function toModifiers({
  shiftKey,
  ctrlKey,
  altKey,
  metaKey,
}: {
  shiftKey: boolean;
  ctrlKey: boolean;
  altKey: boolean;
  metaKey: boolean;
}): Modifiers {
  const modifiers: Modifiers = { shift: shiftKey, ctrl: ctrlKey, alt: altKey, meta: metaKey };
  return modifiers;
}
/** Erase the whole 2D overlay canvas. No-op when the canvas or its 2D context is unavailable. */
function clearOverlay(canvas: HTMLCanvasElement | null): void {
  const ctx = canvas?.getContext("2d");
  if (!canvas || !ctx) return;
  // Reset to the identity transform so the clear covers the full backing store
  // regardless of any transform left behind by earlier drawing.
  ctx.setTransform(1, 0, 0, 1, 0, 0);
  ctx.clearRect(0, 0, canvas.width, canvas.height);
}
/**
 * Clear the overlay and draw a filled, outlined closed polygon through `points`.
 *
 * `points` is interleaved [x0, y0, x1, y1, ...] in overlay-canvas pixels. With fewer
 * than three vertices the overlay is only cleared, since nothing can be enclosed.
 */
function drawLassoOverlay(canvas: HTMLCanvasElement | null, points: number[]): void {
  const ctx = canvas?.getContext("2d");
  if (!canvas || !ctx) return;

  clearOverlay(canvas);
  if (points.length < 6) return;

  ctx.save();
  ctx.lineWidth = 2;
  ctx.strokeStyle = "rgba(79,70,229,0.9)"; // indigo-ish
  ctx.fillStyle = "rgba(79,70,229,0.15)";

  ctx.beginPath();
  ctx.moveTo(points[0], points[1]);
  for (let v = 1; 2 * v < points.length; v++) {
    ctx.lineTo(points[2 * v], points[2 * v + 1]);
  }
  ctx.closePath();

  ctx.fill();
  ctx.stroke();
  ctx.restore();
}
/** Approximate pixel height of one wheel "line" (WheelEvent.deltaMode === 1, e.g. Firefox mouse wheels). */
const WHEEL_LINE_HEIGHT_PX = 40;
/** Approximate pixel height of one wheel "page" (WheelEvent.deltaMode === 2). */
const WHEEL_PAGE_HEIGHT_PX = 800;

/**
 * Convert a wheel event's vertical delta to pixels.
 *
 * `deltaY` is reported in pixels, lines or pages depending on `deltaMode`. Line-mode
 * wheels often report only ~3 per notch, so feeding that straight into a pixel-oriented
 * scale makes zooming far slower than in pixel-mode browsers. Normalize first.
 */
function wheelDeltaYToPixels(e: WheelEvent): number {
  if (e.deltaMode === 1) return e.deltaY * WHEEL_LINE_HEIGHT_PX; // DOM_DELTA_LINE
  if (e.deltaMode === 2) return e.deltaY * WHEEL_PAGE_HEIGHT_PX; // DOM_DELTA_PAGE
  return e.deltaY; // DOM_DELTA_PIXEL
}

/**
 * React hook that owns a hyper-scatter WebGL2 renderer plus a 2D overlay canvas used to
 * draw lasso outlines. The renderer is (re)created whenever `embeddings` or `labelsInfo`
 * change. It is a Euclidean renderer when `embeddings.geometry === "euclidean"`, and a
 * hyperbolic one otherwise.
 *
 * Interactions wired up by the returned handlers and listeners:
 * - drag: pan; Shift+drag: lasso, sent to `beginLassoSelection` as a data-space polygon
 * - wheel: zoom, passing the cursor position to `renderer.zoom`
 * - click: select a single point; Ctrl/Cmd+click: toggle a point in the selection
 * - double-click: clear the selection and any persisted lasso outline
 *
 * @returns Three refs and the canvas event handlers:
 *   - `containerRef`, `canvasRef`, `overlayCanvasRef`: for the container `<div>`, the
 *     WebGL canvas (which also receives the wheel listener) and the overlay canvas.
 *   - Pointer and double-click handlers for the interactive canvas.
 *   - `rendererError`: a user-facing message when the renderer could not be created or
 *     crashed, otherwise `null`.
 */
export function useHyperScatter({
  embeddings,
  labelsInfo,
  selectedIds,
  hoveredId,
  setSelectedIds,
  beginLassoSelection,
  setHoveredId,
  hoverEnabled = true,
}: UseHyperScatterArgs) {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const overlayCanvasRef = useRef<HTMLCanvasElement>(null);
  const containerRef = useRef<HTMLDivElement>(null);
  const rendererRef = useRef<Renderer | null>(null);
  const [rendererError, setRendererError] = useState<string | null>(null);
  // Incremented after every successful renderer (re)creation. The renderer lives in a ref
  // and is created asynchronously, so without this the store -> renderer sync effect would
  // not re-run for a fresh renderer. The current selection/hover would then stay invisible
  // until they next changed.
  const [rendererEpoch, setRendererEpoch] = useState(0);
  // True while a requestAnimationFrame callback is queued, so bursts of requestRender()
  // calls coalesce into a single frame.
  const rafPendingRef = useRef(false);

  // Interaction state (refs to avoid rerender churn)
  const isPanningRef = useRef(false);
  const isLassoingRef = useRef(false);
  // Canvas-relative position of the last pointerdown; used to tell clicks from drags.
  const pointerDownXRef = useRef(0);
  const pointerDownYRef = useRef(0);
  // Last pointer position seen while panning; pan deltas are computed against it.
  const lastPointerXRef = useRef(0);
  const lastPointerYRef = useRef(0);
  // In-progress lasso vertices in canvas pixels, interleaved [x0, y0, x1, y1, ...].
  const lassoPointsRef = useRef<number[]>([]);
  // Outline of the last completed lasso, kept on the overlay until an interaction clears it.
  const persistentLassoRef = useRef<number[] | null>(null);
  // Index currently hovered in the renderer (-1 = none); avoids redundant hover updates.
  const hoveredIndexRef = useRef<number>(-1);

  // Maps store IDs to renderer point indices.
  const idToIndex = useMemo(() => {
    if (!embeddings) return null;
    const m = new Map<string, number>();
    for (let i = 0; i < embeddings.ids.length; i++) {
      m.set(embeddings.ids[i], i);
    }
    return m;
  }, [embeddings]);

  // Schedule at most one renderer.render() per animation frame.
  const requestRender = useCallback(() => {
    if (rafPendingRef.current) return;
    rafPendingRef.current = true;
    requestAnimationFrame(() => {
      rafPendingRef.current = false;
      const renderer = rendererRef.current;
      if (!renderer) return;
      try {
        renderer.render();
      } catch (err) {
        // Avoid an exception storm that would permanently prevent the UI from updating.
        console.error("hyper-scatter renderer.render() failed:", err);
        try {
          renderer.destroy();
        } catch {
          // ignore
        }
        rendererRef.current = null;
        setRendererError(
          "This browser can't render the scatter plot (WebGL2 is required). Please use Chrome/Edge/Firefox."
        );
        clearOverlay(overlayCanvasRef.current);
        return;
      }
      if (isLassoingRef.current) {
        drawLassoOverlay(overlayCanvasRef.current, lassoPointsRef.current);
      }
    });
  }, []);

  // Convert client coordinates to coordinates relative to the WebGL canvas' top-left corner.
  const getCanvasPos = useCallback((e: { clientX: number; clientY: number }) => {
    const canvas = canvasRef.current;
    if (!canvas) return { x: 0, y: 0 };
    const rect = canvas.getBoundingClientRect();
    return {
      x: e.clientX - rect.left,
      y: e.clientY - rect.top,
    };
  }, []);

  // Repaint the overlay with only the persisted lasso outline (if any).
  const redrawOverlay = useCallback(() => {
    if (!overlayCanvasRef.current) return;
    clearOverlay(overlayCanvasRef.current);
    const persistent = persistentLassoRef.current;
    if (persistent && persistent.length >= 6) {
      drawLassoOverlay(overlayCanvasRef.current, persistent);
    }
  }, []);

  const clearPersistentLasso = useCallback(() => {
    persistentLassoRef.current = null;
    clearOverlay(overlayCanvasRef.current);
  }, []);

  // End any pan/lasso gesture, discard in-progress lasso points and restore the overlay.
  const stopInteraction = useCallback(() => {
    isPanningRef.current = false;
    isLassoingRef.current = false;
    lassoPointsRef.current = [];
    if (persistentLassoRef.current) {
      redrawOverlay();
      return;
    }
    clearOverlay(overlayCanvasRef.current);
  }, [redrawOverlay]);

  // Initialize renderer when embeddings change.
  useEffect(() => {
    if (!embeddings || !labelsInfo) return;
    if (!canvasRef.current || !containerRef.current) return;

    let cancelled = false;

    const init = async () => {
      // Clear any previous renderer errors when we attempt to re-init.
      setRendererError(null);

      if (!supportsWebGL2()) {
        setRendererError(
          "This browser doesn't support WebGL2, so the scatter plot can't be displayed. Please use Chrome/Edge/Firefox."
        );
        return;
      }

      // Renderer created by this attempt; tracked so a failure after construction can release it.
      let renderer: Renderer | null = null;
      try {
        const viz = (await import("hyper-scatter")) as HyperScatterModule;
        if (cancelled) return;

        const container = containerRef.current;
        const canvas = canvasRef.current;
        if (!container || !canvas) return;

        // Destroy existing renderer (if any)
        if (rendererRef.current) {
          rendererRef.current.destroy();
          rendererRef.current = null;
        }

        const rect = container.getBoundingClientRect();
        const width = Math.floor(rect.width);
        const height = Math.floor(rect.height);

        if (overlayCanvasRef.current) {
          overlayCanvasRef.current.width = Math.max(1, width);
          overlayCanvasRef.current.height = Math.max(1, height);
          overlayCanvasRef.current.style.width = `${width}px`;
          overlayCanvasRef.current.style.height = `${height}px`;
          redrawOverlay();
        }

        // Use coords from embeddings response directly, flattened to [x0, y0, x1, y1, ...].
        const coords = embeddings.coords;
        const positions = new Float32Array(coords.length * 2);
        for (let i = 0; i < coords.length; i++) {
          positions[i * 2] = coords[i][0];
          positions[i * 2 + 1] = coords[i][1];
        }

        const geometry = embeddings.geometry as GeometryMode;
        const dataset: Dataset = viz.createDataset(geometry, positions, labelsInfo.categories);

        const opts = {
          width,
          height,
          devicePixelRatio: window.devicePixelRatio,
          pointRadius: 4,
          colors: labelsInfo.palette,
          backgroundColor: "#161b22", // Match HyperView theme: --card is #161b22
        };

        const created: Renderer =
          geometry === "euclidean" ? new viz.EuclideanWebGLCandidate() : new viz.HyperbolicWebGLCandidate();
        renderer = created;
        created.init(canvas, opts);
        created.setDataset(dataset);
        rendererRef.current = created;

        // Force a first render to surface WebGL2 context creation failures early.
        try {
          created.render();
        } catch (err) {
          console.error("hyper-scatter initial render failed:", err);
          rendererRef.current = null;
          try {
            created.destroy();
          } catch {
            // ignore
          }
          setRendererError(
            "This browser can't render the scatter plot (WebGL2 is required). Please use Chrome/Edge/Firefox."
          );
          return;
        }

        hoveredIndexRef.current = -1;
        created.setHovered(-1);
        requestRender();
        // Let the store -> renderer sync effect apply the current selection/hover.
        setRendererEpoch((epoch) => epoch + 1);
      } catch (err) {
        console.error("Failed to initialize hyper-scatter renderer:", err);
        // Release a renderer that failed part-way through setup so its GPU resources are not leaked.
        if (renderer) {
          if (rendererRef.current === renderer) rendererRef.current = null;
          try {
            renderer.destroy();
          } catch {
            // ignore
          }
        }
        setRendererError(
          "Failed to initialize the scatter renderer in this browser. Please use Chrome/Edge/Firefox."
        );
      }
    };

    init();

    return () => {
      cancelled = true;
      stopInteraction();
      if (rendererRef.current) {
        rendererRef.current.destroy();
        rendererRef.current = null;
      }
    };
  }, [embeddings, labelsInfo, redrawOverlay, requestRender, stopInteraction]);

  // Store -> renderer sync. `rendererEpoch` re-runs this for every newly created renderer.
  useEffect(() => {
    const renderer = rendererRef.current;
    if (!renderer || !embeddings || !idToIndex) return;

    const indices = new Set<number>();
    for (const id of selectedIds) {
      const idx = idToIndex.get(id);
      if (typeof idx === "number") indices.add(idx);
    }
    renderer.setSelection(indices);

    if (!hoverEnabled) {
      renderer.setHovered(-1);
      hoveredIndexRef.current = -1;
      requestRender();
      return;
    }

    const hoveredIdx = hoveredId ? (idToIndex.get(hoveredId) ?? -1) : -1;
    renderer.setHovered(hoveredIdx);
    hoveredIndexRef.current = hoveredIdx;
    requestRender();
  }, [embeddings, hoveredId, hoverEnabled, idToIndex, rendererEpoch, requestRender, selectedIds]);

  // Resize handling: keep the overlay backing store and the renderer sized to the container.
  useEffect(() => {
    const container = containerRef.current;
    if (!container) return;

    const resize = () => {
      const rect = container.getBoundingClientRect();
      const width = Math.floor(rect.width);
      const height = Math.floor(rect.height);
      // Skip collapsed/hidden layouts (also rejects NaN).
      if (!(width > 0) || !(height > 0)) return;

      if (overlayCanvasRef.current) {
        overlayCanvasRef.current.width = Math.max(1, width);
        overlayCanvasRef.current.height = Math.max(1, height);
        overlayCanvasRef.current.style.width = `${width}px`;
        overlayCanvasRef.current.style.height = `${height}px`;
        redrawOverlay();
      }

      const renderer = rendererRef.current;
      if (renderer) {
        renderer.resize(width, height);
        requestRender();
      }
    };

    resize();
    const ro = new ResizeObserver(resize);
    ro.observe(container);
    return () => ro.disconnect();
  }, [redrawOverlay, requestRender]);

  // Wheel zoom (native listener so we can set passive:false)
  useEffect(() => {
    const canvas = canvasRef.current;
    if (!canvas) return;

    const onWheel = (e: WheelEvent) => {
      const renderer = rendererRef.current;
      if (!renderer) return;
      e.preventDefault();
      // The persisted lasso outline is drawn in screen pixels. Once the view zooms it no
      // longer lines up with the points it enclosed, so drop it; pointerdown does the
      // same before a pan.
      if (persistentLassoRef.current) {
        clearPersistentLasso();
      }
      const pos = getCanvasPos(e);
      // 100 px of (normalized) vertical scroll maps to a zoom delta of 1.
      const delta = -wheelDeltaYToPixels(e) / 100;
      renderer.zoom(pos.x, pos.y, delta, toModifiers(e));
      requestRender();
    };

    canvas.addEventListener("wheel", onWheel, { passive: false });
    return () => canvas.removeEventListener("wheel", onWheel);
  }, [clearPersistentLasso, getCanvasPos, requestRender]);

  // Pointer interactions
  const handlePointerDown = useCallback(
    (e: React.PointerEvent<HTMLCanvasElement>) => {
      const renderer = rendererRef.current;
      if (!renderer) return;

      // Left button only
      if (typeof e.button === "number" && e.button !== 0) return;

      const pos = getCanvasPos(e);
      pointerDownXRef.current = pos.x;
      pointerDownYRef.current = pos.y;
      lastPointerXRef.current = pos.x;
      lastPointerYRef.current = pos.y;

      if (persistentLassoRef.current) {
        clearPersistentLasso();
      }

      // Shift-drag = lasso, otherwise pan.
      if (e.shiftKey) {
        isLassoingRef.current = true;
        isPanningRef.current = false;
        lassoPointsRef.current = [pos.x, pos.y];
        drawLassoOverlay(overlayCanvasRef.current, lassoPointsRef.current);
      } else {
        isPanningRef.current = true;
        isLassoingRef.current = false;
      }

      // Keep receiving move/up events even if the pointer leaves the canvas mid-drag.
      try {
        e.currentTarget.setPointerCapture(e.pointerId);
      } catch {
        // ignore
      }

      e.preventDefault();
    },
    [clearPersistentLasso, getCanvasPos]
  );

  const handlePointerMove = useCallback(
    (e: React.PointerEvent<HTMLCanvasElement>) => {
      const renderer = rendererRef.current;
      if (!renderer) return;

      const pos = getCanvasPos(e);

      if (isPanningRef.current) {
        const dx = pos.x - lastPointerXRef.current;
        const dy = pos.y - lastPointerYRef.current;
        lastPointerXRef.current = pos.x;
        lastPointerYRef.current = pos.y;
        renderer.pan(dx, dy, toModifiers(e));
        requestRender();
        return;
      }

      if (isLassoingRef.current) {
        const pts = lassoPointsRef.current;
        const lastX = pts[pts.length - 2] ?? pos.x;
        const lastY = pts[pts.length - 1] ?? pos.y;
        const ddx = pos.x - lastX;
        const ddy = pos.y - lastY;
        const distSq = ddx * ddx + ddy * ddy;
        // Sample at ~2px spacing
        if (distSq >= 4) {
          pts.push(pos.x, pos.y);
          drawLassoOverlay(overlayCanvasRef.current, pts);
        }
        return;
      }

      if (!hoverEnabled) {
        if (hoveredIndexRef.current !== -1) {
          hoveredIndexRef.current = -1;
          renderer.setHovered(-1);
          requestRender();
        }
        return;
      }

      // Hover
      const hit = renderer.hitTest(pos.x, pos.y);
      const nextIndex = hit ? hit.index : -1;
      if (nextIndex === hoveredIndexRef.current) return;

      hoveredIndexRef.current = nextIndex;
      renderer.setHovered(nextIndex);
      // Repaint regardless of whether the store can be updated below.
      requestRender();

      if (!embeddings) return;
      if (nextIndex >= 0 && nextIndex < embeddings.ids.length) {
        setHoveredId(embeddings.ids[nextIndex]);
      } else {
        setHoveredId(null);
      }
    },
    [embeddings, getCanvasPos, hoverEnabled, requestRender, setHoveredId]
  );

  const handlePointerUp = useCallback(
    async (e: React.PointerEvent<HTMLCanvasElement>) => {
      const renderer = rendererRef.current;
      if (!renderer || !embeddings) {
        stopInteraction();
        return;
      }

      if (isLassoingRef.current) {
        const pts = lassoPointsRef.current.slice();
        // Keep the outline visible after release; three or more vertices are needed to enclose an area.
        persistentLassoRef.current = pts.length >= 6 ? pts : null;
        // Ends the gesture and repaints the overlay with just the persisted outline (if any).
        stopInteraction();

        if (pts.length >= 6) {
          try {
            const polyline = new Float32Array(pts);
            const result = renderer.lassoSelect(polyline);

            // Enter server-driven lasso mode by sending a data-space polygon.
            // Backend selection runs in the same coordinate system returned by /api/embeddings.
            const dataCoords = result.geometry?.coords;
            if (!dataCoords || dataCoords.length < 6) return;

            // Clear any existing manual selection highlights immediately.
            renderer.setSelection(new Set());

            // Cap vertex count to keep request payload + backend runtime bounded.
            const polygon = capInterleavedXY(dataCoords, MAX_LASSO_VERTS);
            if (polygon.length < 6) return;

            beginLassoSelection({ layoutKey: embeddings.layout_key, polygon });
          } catch (err) {
            console.error("Lasso selection failed:", err);
          }
        }

        requestRender();
        return;
      }

      // Click-to-select (scatter -> image grid)
      // Only treat as a click if the pointer didn't move much (otherwise it's a pan).
      const pos = getCanvasPos(e);
      const dx = pos.x - pointerDownXRef.current;
      const dy = pos.y - pointerDownYRef.current;
      const CLICK_MAX_DIST_SQ = 36; // ~6px
      const isClick = dx * dx + dy * dy <= CLICK_MAX_DIST_SQ;

      if (isClick) {
        const hit = renderer.hitTest(pos.x, pos.y);
        const idx = hit ? hit.index : -1;
        if (idx >= 0 && idx < embeddings.ids.length) {
          const id = embeddings.ids[idx];
          if (e.metaKey || e.ctrlKey) {
            // Ctrl/Cmd+click toggles membership in the current selection.
            const next = new Set(selectedIds);
            if (next.has(id)) next.delete(id);
            else next.add(id);
            setSelectedIds(next, "scatter");
          } else {
            setSelectedIds(new Set([id]), "scatter");
          }
        }
      }

      stopInteraction();
      requestRender();
    },
    [
      beginLassoSelection,
      embeddings,
      getCanvasPos,
      requestRender,
      selectedIds,
      setSelectedIds,
      stopInteraction,
    ]
  );

  const handlePointerLeave = useCallback(
    (_e: React.PointerEvent<HTMLCanvasElement>) => {
      const renderer = rendererRef.current;
      if (renderer) {
        hoveredIndexRef.current = -1;
        setHoveredId(null);
        renderer.setHovered(-1);
        requestRender();
      }
      stopInteraction();
    },
    [requestRender, setHoveredId, stopInteraction]
  );

  // Double-click clears both the selection and any persisted lasso outline.
  const handleDoubleClick = useCallback(
    (_e: React.MouseEvent<HTMLCanvasElement>) => {
      const renderer = rendererRef.current;
      if (!renderer) return;
      clearPersistentLasso();
      stopInteraction();
      renderer.setSelection(new Set());
      setSelectedIds(new Set<string>(), "scatter");
      requestRender();
    },
    [clearPersistentLasso, requestRender, setSelectedIds, stopInteraction]
  );

  return {
    canvasRef,
    overlayCanvasRef,
    containerRef,
    handlePointerDown,
    handlePointerMove,
    handlePointerUp,
    handlePointerLeave,
    handleDoubleClick,
    rendererError,
  };
}