import type React from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { EmbeddingsData } from "@/types";
import type { ScatterLabelsInfo } from "@/lib/labelLegend";
import type { Dataset, GeometryMode, Modifiers, Renderer } from "hyper-scatter";

type HyperScatterModule = typeof import("hyper-scatter");

/** Upper bound on lasso polygon vertices sent to the backend (keeps payload + backend runtime bounded). */
const MAX_LASSO_VERTS = 512;

/** A drawable/selectable polygon needs at least three (x, y) vertices = six interleaved numbers. */
const MIN_POLYGON_COORDS = 6;

/** Squared pointer travel (px^2) below which press+release counts as a click, not a pan (~6px). */
const CLICK_MAX_DIST_SQ = 36;

/** Minimum squared spacing (px^2) between recorded lasso samples (~2px). */
const LASSO_SAMPLE_DIST_SQ = 4;

/** WheelEvent.deltaMode values (DOM_DELTA_PIXEL is 0 and needs no scaling). */
const DOM_DELTA_LINE = 1;
const DOM_DELTA_PAGE = 2;
/** Approximate pixel height of one "line" when the browser reports wheel deltas in lines. */
const WHEEL_LINE_HEIGHT_PX = 16;

const LASSO_STROKE_STYLE = "rgba(79,70,229,0.9)"; // indigo-ish
const LASSO_FILL_STYLE = "rgba(79,70,229,0.15)";
const BACKGROUND_COLOR = "#161b22"; // Match HyperView theme: --card is #161b22

const MSG_NO_WEBGL2 =
  "This browser doesn't support WebGL2, so the scatter plot can't be displayed. Please use Chrome/Edge/Firefox.";
const MSG_RENDER_FAILED =
  "This browser can't render the scatter plot (WebGL2 is required). Please use Chrome/Edge/Firefox.";
const MSG_INIT_FAILED =
  "Failed to initialize the scatter renderer in this browser. Please use Chrome/Edge/Firefox.";

/** Returns true when a throwaway canvas can create a WebGL2 context (false during SSR or on error). */
function supportsWebGL2(): boolean {
  try {
    if (typeof document === "undefined") return false;
    const canvas = document.createElement("canvas");
    return !!canvas.getContext("webgl2");
  } catch {
    return false;
  }
}

/**
 * Down-samples an interleaved `[x0, y0, x1, y1, ...]` polyline to at most `maxVerts` vertices.
 *
 * Vertices are picked at evenly spaced source indices, so the result keeps the overall shape.
 * A dangling odd trailing coordinate (an x with no y) is dropped.
 *
 * @param points Interleaved x/y coordinates.
 * @param maxVerts Maximum number of vertices to keep. Non-integers are floored; negatives are treated as 0.
 * @returns A plain array of interleaved coordinates with an even length.
 */
function capInterleavedXY(points: ArrayLike<number>, maxVerts: number): number[] {
  const vertexCount = Math.floor(points.length / 2);
  const cap = Math.max(0, Math.floor(maxVerts));
  if (vertexCount <= cap) return Array.from(points).slice(0, vertexCount * 2);

  const out = new Array<number>(cap * 2);
  for (let i = 0; i < cap; i++) {
    const src = Math.floor((i * vertexCount) / cap);
    out[i * 2] = points[src * 2];
    out[i * 2 + 1] = points[src * 2 + 1];
  }
  return out;
}

interface UseHyperScatterArgs {
  embeddings: EmbeddingsData | null;
  labelsInfo: ScatterLabelsInfo | null;
  selectedIds: Set<string>;
  hoveredId: string | null;
  setSelectedIds: (ids: Set<string>, source?: "scatter" | "grid") => void;
  beginLassoSelection: (query: { layoutKey: string; polygon: number[] }) => void;
  setHoveredId: (id: string | null) => void;
  hoverEnabled?: boolean;
}

/** Maps DOM modifier-key flags onto hyper-scatter's `Modifiers` shape. */
function toModifiers(e: { shiftKey: boolean; ctrlKey: boolean; altKey: boolean; metaKey: boolean }): Modifiers {
  return {
    shift: e.shiftKey,
    ctrl: e.ctrlKey,
    alt: e.altKey,
    meta: e.metaKey,
  };
}

/** Clears the whole 2D overlay canvas (resetting any transform first). No-op for a null canvas. */
function clearOverlay(canvas: HTMLCanvasElement | null): void {
  if (!canvas) return;
  const ctx = canvas.getContext("2d");
  if (!ctx) return;
  ctx.setTransform(1, 0, 0, 1, 0, 0);
  ctx.clearRect(0, 0, canvas.width, canvas.height);
}

/**
 * Clears the overlay and draws a filled, closed lasso polygon in canvas-pixel coordinates.
 * Fewer than three vertices just leaves the overlay cleared.
 */
function drawLassoOverlay(canvas: HTMLCanvasElement | null, points: readonly number[]): void {
  if (!canvas) return;
  const ctx = canvas.getContext("2d");
  if (!ctx) return;
  clearOverlay(canvas);
  if (points.length < MIN_POLYGON_COORDS) return;

  ctx.save();
  ctx.lineWidth = 2;
  ctx.strokeStyle = LASSO_STROKE_STYLE;
  ctx.fillStyle = LASSO_FILL_STYLE;
  ctx.beginPath();
  ctx.moveTo(points[0], points[1]);
  for (let i = 2; i + 1 < points.length; i += 2) {
    ctx.lineTo(points[i], points[i + 1]);
  }
  ctx.closePath();
  ctx.fill();
  ctx.stroke();
  ctx.restore();
}

/** Sizes the overlay's backing store and CSS box to `width` x `height` CSS pixels (backing store >= 1px). */
function sizeOverlayCanvas(canvas: HTMLCanvasElement, width: number, height: number): void {
  canvas.width = Math.max(1, width);
  canvas.height = Math.max(1, height);
  canvas.style.width = `${width}px`;
  canvas.style.height = `${height}px`;
}

/**
 * Converts a wheel event's vertical delta to pixels.
 *
 * Some browsers and devices report deltas in lines or pages (`deltaMode` 1/2) rather than pixels.
 * Without this normalization, zoom steps would be tiny for those events.
 */
function wheelDeltaPixels(e: WheelEvent, pageHeightPx: number): number {
  switch (e.deltaMode) {
    case DOM_DELTA_LINE:
      return e.deltaY * WHEEL_LINE_HEIGHT_PX;
    case DOM_DELTA_PAGE:
      return e.deltaY * pageHeightPx;
    default:
      return e.deltaY;
  }
}

/**
 * Wires a hyper-scatter WebGL2 renderer to a canvas plus a 2D overlay canvas (used for the lasso outline).
 *
 * Interactions:
 * - drag pans;
 * - shift-drag lassos (sends a data-space polygon to `beginLassoSelection`);
 * - click selects a point, and ctrl/meta-click toggles it;
 * - wheel zooms;
 * - double-click clears the selection;
 * - hover reports through `setHoveredId`.
 *
 * Store selection/hover state is mirrored into the renderer whenever it (or the renderer) changes.
 *
 * Attach `canvasRef`, `overlayCanvasRef` and `containerRef` to the corresponding elements, and the
 * returned handlers to the canvas. `rendererError` is a user-facing message when WebGL2 rendering
 * is unavailable or fails, otherwise null.
 */
export function useHyperScatter({
  embeddings,
  labelsInfo,
  selectedIds,
  hoveredId,
  setSelectedIds,
  beginLassoSelection,
  setHoveredId,
  hoverEnabled = true,
}: UseHyperScatterArgs) {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const overlayCanvasRef = useRef<HTMLCanvasElement>(null);
  const containerRef = useRef<HTMLDivElement>(null);
  const rendererRef = useRef<Renderer | null>(null);
  const [rendererError, setRendererError] = useState<string | null>(null);
  // Bumped after every successful renderer init. The renderer is created asynchronously, so the
  // store -> renderer sync effect must re-run once it exists to apply the current selection/hover.
  const [rendererEpoch, setRendererEpoch] = useState(0);
  const rafPendingRef = useRef(false);

  // Interaction state (refs to avoid rerender churn)
  const isPanningRef = useRef(false);
  const isLassoingRef = useRef(false);
  const pointerDownXRef = useRef(0);
  const pointerDownYRef = useRef(0);
  const lastPointerXRef = useRef(0);
  const lastPointerYRef = useRef(0);
  const lassoPointsRef = useRef<number[]>([]);
  const persistentLassoRef = useRef<number[] | null>(null);
  const hoveredIndexRef = useRef(-1);

  const idToIndex = useMemo(() => {
    if (!embeddings) return null;
    const m = new Map<string, number>();
    for (let i = 0; i < embeddings.ids.length; i++) {
      m.set(embeddings.ids[i], i);
    }
    return m;
  }, [embeddings]);

  /** Coalesces render requests into at most one renderer.render() per animation frame. */
  const requestRender = useCallback(() => {
    if (rafPendingRef.current) return;
    rafPendingRef.current = true;
    requestAnimationFrame(() => {
      rafPendingRef.current = false;
      const renderer = rendererRef.current;
      if (!renderer) return;
      try {
        renderer.render();
      } catch (err) {
        // Avoid an exception storm that would permanently prevent the UI from updating.
        console.error("hyper-scatter renderer.render() failed:", err);
        try {
          renderer.destroy();
        } catch {
          // ignore: we're already tearing down a failed renderer
        }
        rendererRef.current = null;
        setRendererError(MSG_RENDER_FAILED);
        clearOverlay(overlayCanvasRef.current);
        return;
      }
      if (isLassoingRef.current) {
        drawLassoOverlay(overlayCanvasRef.current, lassoPointsRef.current);
      }
    });
  }, []);

  /** Converts client coordinates to CSS-pixel coordinates relative to the main canvas. */
  const getCanvasPos = useCallback((e: { clientX: number; clientY: number }) => {
    const canvas = canvasRef.current;
    if (!canvas) return { x: 0, y: 0 };
    const rect = canvas.getBoundingClientRect();
    return {
      x: e.clientX - rect.left,
      y: e.clientY - rect.top,
    };
  }, []);

  /** Clears the overlay and re-draws the persistent (finished) lasso outline, if any. */
  const redrawOverlay = useCallback(() => {
    if (!overlayCanvasRef.current) return;
    clearOverlay(overlayCanvasRef.current);
    const persistent = persistentLassoRef.current;
    if (persistent && persistent.length >= MIN_POLYGON_COORDS) {
      drawLassoOverlay(overlayCanvasRef.current, persistent);
    }
  }, []);

  const clearPersistentLasso = useCallback(() => {
    persistentLassoRef.current = null;
    clearOverlay(overlayCanvasRef.current);
  }, []);

  /** Ends any pan/lasso gesture; leaves the persistent lasso outline (if any) on screen. */
  const stopInteraction = useCallback(() => {
    isPanningRef.current = false;
    isLassoingRef.current = false;
    lassoPointsRef.current = [];
    if (persistentLassoRef.current) {
      redrawOverlay();
      return;
    }
    clearOverlay(overlayCanvasRef.current);
  }, [redrawOverlay]);

  // Initialize renderer when embeddings change.
  useEffect(() => {
    if (!embeddings || !labelsInfo) return;
    if (!canvasRef.current || !containerRef.current) return;

    let cancelled = false;

    const init = async () => {
      // Clear any previous renderer errors when we attempt to re-init.
      setRendererError(null);

      if (!supportsWebGL2()) {
        setRendererError(MSG_NO_WEBGL2);
        return;
      }

      // Tracks a renderer created in this attempt so it can be destroyed if setup throws
      // before it is handed to rendererRef (otherwise its WebGL context would leak).
      let created: Renderer | null = null;
      try {
        const viz = (await import("hyper-scatter")) as HyperScatterModule;
        if (cancelled) return;

        const container = containerRef.current;
        const canvas = canvasRef.current;
        if (!container || !canvas) return;

        // Destroy existing renderer (if any)
        if (rendererRef.current) {
          rendererRef.current.destroy();
          rendererRef.current = null;
        }

        const rect = container.getBoundingClientRect();
        const width = Math.floor(rect.width);
        const height = Math.floor(rect.height);

        if (overlayCanvasRef.current) {
          sizeOverlayCanvas(overlayCanvasRef.current, width, height);
          redrawOverlay();
        }

        // Use coords from embeddings response directly
        const coords = embeddings.coords;
        const positions = new Float32Array(coords.length * 2);
        for (let i = 0; i < coords.length; i++) {
          positions[i * 2] = coords[i][0];
          positions[i * 2 + 1] = coords[i][1];
        }

        const geometry = embeddings.geometry as GeometryMode;
        const dataset: Dataset = viz.createDataset(geometry, positions, labelsInfo.categories);

        const opts = {
          width,
          height,
          devicePixelRatio: window.devicePixelRatio,
          pointRadius: 4,
          colors: labelsInfo.palette,
          backgroundColor: BACKGROUND_COLOR,
        };

        const renderer: Renderer =
          geometry === "euclidean" ? new viz.EuclideanWebGLCandidate() : new viz.HyperbolicWebGLCandidate();
        created = renderer;

        renderer.init(canvas, opts);
        renderer.setDataset(dataset);
        rendererRef.current = renderer;

        // Force a first render to surface WebGL2 context creation failures early.
        try {
          renderer.render();
        } catch (err) {
          console.error("hyper-scatter initial render failed:", err);
          rendererRef.current = null;
          try {
            renderer.destroy();
          } catch {
            // ignore: we're already tearing down a failed renderer
          }
          setRendererError(MSG_RENDER_FAILED);
          return;
        }

        hoveredIndexRef.current = -1;
        renderer.setHovered(-1);
        // Let the store -> renderer sync effect apply the current selection/hover to the new renderer.
        setRendererEpoch((epoch) => epoch + 1);
        requestRender();
      } catch (err) {
        console.error("Failed to initialize hyper-scatter renderer:", err);
        if (created && rendererRef.current !== created) {
          try {
            created.destroy();
          } catch {
            // ignore: best-effort cleanup of a half-initialized renderer
          }
        }
        setRendererError(MSG_INIT_FAILED);
      }
    };

    void init();

    return () => {
      cancelled = true;
      stopInteraction();
      if (rendererRef.current) {
        rendererRef.current.destroy();
        rendererRef.current = null;
      }
    };
  }, [embeddings, labelsInfo, redrawOverlay, requestRender, stopInteraction]);

  // Store -> renderer sync (also re-runs after each renderer init via rendererEpoch).
  useEffect(() => {
    const renderer = rendererRef.current;
    if (!renderer || !embeddings || !idToIndex) return;

    const indices = new Set<number>();
    for (const id of selectedIds) {
      const idx = idToIndex.get(id);
      if (typeof idx === "number") indices.add(idx);
    }
    renderer.setSelection(indices);

    const hoveredIdx = hoverEnabled && hoveredId ? (idToIndex.get(hoveredId) ?? -1) : -1;
    renderer.setHovered(hoveredIdx);
    hoveredIndexRef.current = hoveredIdx;
    requestRender();
  }, [embeddings, hoveredId, hoverEnabled, idToIndex, rendererEpoch, requestRender, selectedIds]);

  // Resize handling
  useEffect(() => {
    const container = containerRef.current;
    if (!container) return;

    const resize = () => {
      const rect = container.getBoundingClientRect();
      const width = Math.floor(rect.width);
      const height = Math.floor(rect.height);
      // Skip zero/NaN sizes (e.g. while the container is hidden).
      if (!(width > 0) || !(height > 0)) return;

      if (overlayCanvasRef.current) {
        sizeOverlayCanvas(overlayCanvasRef.current, width, height);
        redrawOverlay();
      }

      const renderer = rendererRef.current;
      if (renderer) {
        renderer.resize(width, height);
        requestRender();
      }
    };

    resize();
    const ro = new ResizeObserver(resize);
    ro.observe(container);
    return () => ro.disconnect();
  }, [redrawOverlay, requestRender]);

  // Wheel zoom (native listener so we can set passive:false)
  useEffect(() => {
    const canvas = canvasRef.current;
    if (!canvas) return;

    const onWheel = (e: WheelEvent) => {
      const renderer = rendererRef.current;
      if (!renderer) return;
      e.preventDefault();
      // The persistent lasso outline lives in screen space and no longer matches the data after a
      // zoom; drop it, just as starting a pan (pointer-down) does.
      if (persistentLassoRef.current) clearPersistentLasso();
      const pos = getCanvasPos(e);
      const delta = -wheelDeltaPixels(e, canvas.clientHeight || window.innerHeight) / 100;
      renderer.zoom(pos.x, pos.y, delta, toModifiers(e));
      requestRender();
    };

    canvas.addEventListener("wheel", onWheel, { passive: false });
    return () => canvas.removeEventListener("wheel", onWheel);
  }, [clearPersistentLasso, getCanvasPos, requestRender]);

  // Pointer interactions
  const handlePointerDown = useCallback(
    (e: React.PointerEvent) => {
      const renderer = rendererRef.current;
      if (!renderer) return;

      // Left button only
      if (typeof e.button === "number" && e.button !== 0) return;

      const pos = getCanvasPos(e);
      pointerDownXRef.current = pos.x;
      pointerDownYRef.current = pos.y;
      lastPointerXRef.current = pos.x;
      lastPointerYRef.current = pos.y;

      if (persistentLassoRef.current) {
        clearPersistentLasso();
      }

      // Shift-drag = lasso, otherwise pan.
      if (e.shiftKey) {
        isLassoingRef.current = true;
        isPanningRef.current = false;
        lassoPointsRef.current = [pos.x, pos.y];
        drawLassoOverlay(overlayCanvasRef.current, lassoPointsRef.current);
      } else {
        isPanningRef.current = true;
        isLassoingRef.current = false;
      }

      try {
        e.currentTarget.setPointerCapture(e.pointerId);
      } catch {
        // ignore: capture is a nicety (keeps the drag alive outside the canvas)
      }

      e.preventDefault();
    },
    [clearPersistentLasso, getCanvasPos]
  );

  const handlePointerMove = useCallback(
    (e: React.PointerEvent) => {
      const renderer = rendererRef.current;
      if (!renderer) return;

      const pos = getCanvasPos(e);

      if (isPanningRef.current) {
        const dx = pos.x - lastPointerXRef.current;
        const dy = pos.y - lastPointerYRef.current;
        lastPointerXRef.current = pos.x;
        lastPointerYRef.current = pos.y;
        renderer.pan(dx, dy, toModifiers(e));
        requestRender();
        return;
      }

      if (isLassoingRef.current) {
        const pts = lassoPointsRef.current;
        const lastX = pts[pts.length - 2] ?? pos.x;
        const lastY = pts[pts.length - 1] ?? pos.y;
        const ddx = pos.x - lastX;
        const ddy = pos.y - lastY;
        if (ddx * ddx + ddy * ddy >= LASSO_SAMPLE_DIST_SQ) {
          pts.push(pos.x, pos.y);
          drawLassoOverlay(overlayCanvasRef.current, pts);
        }
        return;
      }

      if (!hoverEnabled) {
        if (hoveredIndexRef.current !== -1) {
          hoveredIndexRef.current = -1;
          renderer.setHovered(-1);
          requestRender();
        }
        return;
      }

      // Hover
      const hit = renderer.hitTest(pos.x, pos.y);
      const nextIndex = hit ? hit.index : -1;
      if (nextIndex === hoveredIndexRef.current) return;

      hoveredIndexRef.current = nextIndex;
      renderer.setHovered(nextIndex);

      if (embeddings) {
        setHoveredId(nextIndex >= 0 && nextIndex < embeddings.ids.length ? embeddings.ids[nextIndex] : null);
      }
      // Repaint even without embeddings: the renderer's hover state was just mutated.
      requestRender();
    },
    [embeddings, getCanvasPos, hoverEnabled, requestRender, setHoveredId]
  );

  const handlePointerUp = useCallback(
    (e: React.PointerEvent) => {
      const renderer = rendererRef.current;
      if (!renderer || !embeddings) {
        stopInteraction();
        return;
      }

      if (isLassoingRef.current) {
        const pts = lassoPointsRef.current.slice();
        const hasPolygon = pts.length >= MIN_POLYGON_COORDS;
        // Keep the finished outline on screen until the next pointer-down, zoom, or double-click.
        persistentLassoRef.current = hasPolygon ? pts : null;
        // stopInteraction() redraws the persistent outline, or clears the overlay when there is none.
        stopInteraction();

        if (hasPolygon) {
          try {
            const result = renderer.lassoSelect(new Float32Array(pts));
            // Enter server-driven lasso mode by sending a data-space polygon.
            // Backend selection runs in the same coordinate system returned by /api/embeddings.
            const dataCoords = result.geometry?.coords;
            if (dataCoords && dataCoords.length >= MIN_POLYGON_COORDS) {
              // Clear any existing manual selection highlights immediately.
              renderer.setSelection(new Set<number>());
              // Cap vertex count to keep request payload + backend runtime bounded.
              const polygon = capInterleavedXY(dataCoords, MAX_LASSO_VERTS);
              if (polygon.length >= MIN_POLYGON_COORDS) {
                beginLassoSelection({ layoutKey: embeddings.layout_key, polygon });
              }
            }
          } catch (err) {
            console.error("Lasso selection failed:", err);
          }
        }

        // Always repaint: the branches above may have mutated renderer state before bailing out.
        requestRender();
        return;
      }

      // Click-to-select (scatter -> image grid)
      // Only treat as a click if the pointer didn't move much (otherwise it's a pan).
      const pos = getCanvasPos(e);
      const dx = pos.x - pointerDownXRef.current;
      const dy = pos.y - pointerDownYRef.current;
      const isClick = dx * dx + dy * dy <= CLICK_MAX_DIST_SQ;

      if (isClick) {
        const hit = renderer.hitTest(pos.x, pos.y);
        const idx = hit ? hit.index : -1;
        if (idx >= 0 && idx < embeddings.ids.length) {
          const id = embeddings.ids[idx];
          if (e.metaKey || e.ctrlKey) {
            // Toggle membership without disturbing the rest of the selection.
            const next = new Set(selectedIds);
            if (next.has(id)) next.delete(id);
            else next.add(id);
            setSelectedIds(next, "scatter");
          } else {
            setSelectedIds(new Set([id]), "scatter");
          }
        }
      }

      stopInteraction();
      requestRender();
    },
    [beginLassoSelection, embeddings, getCanvasPos, requestRender, selectedIds, setSelectedIds, stopInteraction]
  );

  const handlePointerLeave = useCallback(
    (_e: React.PointerEvent) => {
      const renderer = rendererRef.current;
      if (renderer) {
        hoveredIndexRef.current = -1;
        setHoveredId(null);
        renderer.setHovered(-1);
        requestRender();
      }
      stopInteraction();
    },
    [requestRender, setHoveredId, stopInteraction]
  );

  const handleDoubleClick = useCallback(
    (_e: React.MouseEvent) => {
      const renderer = rendererRef.current;
      if (!renderer) return;
      clearPersistentLasso();
      stopInteraction();
      renderer.setSelection(new Set<number>());
      setSelectedIds(new Set<string>(), "scatter");
      requestRender();
    },
    [clearPersistentLasso, requestRender, setSelectedIds, stopInteraction]
  );

  return {
    canvasRef,
    overlayCanvasRef,
    containerRef,
    handlePointerDown,
    handlePointerMove,
    handlePointerUp,
    handlePointerLeave,
    handleDoubleClick,
    rendererError,
  };
}