/**
 * Custom React Flow node components (`op`, `cluster`, `group`) used to render
 * IR nodes, exported together as `nodeTypes`.
 *
 * NOTE(review): in this copy of the file the JSX markup appears to have been
 * stripped — element tags (and possibly generic type arguments, e.g. on
 * `NodeProps`) are missing. Evidence: `Box`, `Stack`, `Tooltip`, `Typography`,
 * `Handle` and `Position` are imported but never referenced, and the `<>`
 * fragments below have no matching closing tag. The render expressions below
 * are therefore incomplete; restore the markup from version control — TODO
 * confirm.
 */
import { memo } from "react";
import { Handle, Position, type NodeProps } from "reactflow";
import { Box, Stack, Tooltip, Typography } from "@mui/material";
import type { IRNode } from "../types";

interface NodeData {
  label: string;
  sublabel: string;
  /**
   * Base color. Expected in `#rrggbb` form: the heatmap glow appends a
   * 2-digit hex alpha suffix to it.
   */
  color: string;
  isCluster: boolean;
  /** 0-1 share of the model's parameters. Drives the "heatmap" glow. */
  intensity?: number;
  /** If set and > 1, render an `×N` badge to indicate a folded ModuleList. */
  repeatCount?: number;
  /** Optional tooltip rendered on hover (richer than the node body). */
  tooltip?: string;
  /** Reference to the original IRNode for richer hover content. */
  irNode?: IRNode;
  /** True when this is a visual group container (subflow parent). */
  isGroupContainer?: boolean;
}

/**
 * Convert a 0-1 alpha into a 2-digit hex suffix for `#rrggbb` strings.
 *
 * Out-of-range values are clamped to [0, 1]. NaN is treated as 0, so the
 * result is always a valid two-digit suffix; without that guard it would be
 * the literal text "NaN", yielding an invalid color.
 */
function alphaHex(a: number): string {
  const clamped = Number.isNaN(a) ? 0 : Math.max(0, Math.min(1, a));
  return Math.round(clamped * 255)
    .toString(16)
    .padStart(2, "0");
}

/**
 * Format a parameter count compactly: plain below 1,000, otherwise with two
 * decimals and a K / M / B suffix.
 *
 * A unit is only used when its rounded value stays below 1000. Otherwise the
 * next unit is used, so 999_999 renders as "1.00M" rather than "1000.00K".
 * "B" is the largest unit; counts beyond it keep growing in B.
 */
function formatParams(n: number): string {
  if (n === 0) return "0";
  if (n < 1_000) return n.toString();
  const scaled = (divisor: number): string => (n / divisor).toFixed(2);
  const thousands = scaled(1_000);
  if (Number(thousands) < 1_000) return `${thousands}K`;
  const millions = scaled(1_000_000);
  if (Number(millions) < 1_000) return `${millions}M`;
  return `${scaled(1_000_000_000)}B`;
}

/** Narrow an untyped IR attribute to a string, or `null` if it is not one. */
function stringAttr(value: unknown): string | null {
  return typeof value === "string" ? value : null;
}

/** Narrow an untyped IR attribute to a finite number, or `null` otherwise. */
function numberAttr(value: unknown): number | null {
  return typeof value === "number" && Number.isFinite(value) ? value : null;
}

/**
 * Build the hover-tooltip content for a node.
 *
 * Without `data.irNode`, this returns `data.tooltip`, or `label` and
 * `sublabel` joined by a newline.
 *
 * With an IR node, it renders:
 * - the IR `label` attribute (falling back to `data.label`);
 * - the op type, with " · forward op" appended for synthetic nodes;
 * - the optional `description` and `modulePath` attributes;
 * - the parameter total, when positive: the `totalParams` attribute, or else
 *   the sum of the weights' `numParams`;
 * - the weight-tensor count.
 */
function nodeTooltipContent(data: NodeData): React.ReactNode {
  if (!data.irNode) {
    return data.tooltip ?? `${data.label}\n${data.sublabel}`;
  }
  const n = data.irNode;
  // Attribute values are narrowed at runtime instead of cast with `as`. A
  // malformed attribute (e.g. a numeric string for `totalParams`) is then
  // ignored rather than rendered or compared as if it had the expected type.
  const desc = stringAttr(n.attrs.description);
  const label = stringAttr(n.attrs.label);
  const moduleP = stringAttr(n.attrs.modulePath);
  const totalParams =
    numberAttr(n.attrs.totalParams) ??
    n.weights.reduce((sum, w) => sum + w.numParams, 0);
  const isSynthetic = n.attrs.synthetic === true;
  // NOTE(review): JSX markup missing in this copy (see file header).
  return ( {label ?? data.label} {n.opType} {isSynthetic ? " · forward op" : ""} {desc && ( {desc} )} {moduleP && ( {moduleP} )} {totalParams > 0 && ( {formatParams(totalParams)} parameters )} {n.weights.length > 0 && ( {n.weights.length} tensor{n.weights.length > 1 ? "s" : ""} )} );
}

/**
 * Node body: label, optional sublabel, an `×N` badge when `repeatCount > 1`,
 * and a "heatmap" glow scaled by the node's parameter share.
 */
function NodeBody({
  data,
  selected,
}: {
  data: NodeData;
  selected: boolean;
}) {
  const intensity = data.intensity ?? 0;
  // For intensity in [0, 1], the glow alpha is capped at 0.85 and the blur
  // radius grows from 4px to 32px.
  const glowOpacity = Math.min(0.85, intensity);
  const glowSize = Math.round(4 + intensity * 28);
  // Extra box-shadow layer, skipped for near-zero intensity. The leading comma
  // suggests it is appended to a base shadow list — TODO confirm. Requires
  // `data.color` to be `#rrggbb` so the suffix forms a valid `#rrggbbaa`.
  const heatmapShadow =
    intensity > 0.05
      ? `, 0 0 ${glowSize}px ${data.color}${alphaHex(glowOpacity)}`
      : "";
  const repeatCount = data.repeatCount ?? 0;
  const hasRepeat = repeatCount > 1;
  // NOTE(review): JSX markup missing in this copy (see file header).
  return ( {hasRepeat && ( ×{repeatCount} )} {data.label} {data.sublabel && ( {data.sublabel} )} );
}

/**
 * Style for the regular connection handles: a fully transparent 1×1px point
 * that ignores pointer events.
 */
const handleStyle: React.CSSProperties = {
  background: "rgba(255,255,255,0.0)",
  border: "0",
  width: 1,
  height: 1,
  pointerEvents: "none",
};

/**
 * Side handles are used by skip-connections (residuals): an extra path is
 * routed along the right edge of the column so it bypasses the chain of
 * forward nodes instead of crossing them. Like the regular handles they are
 * fully transparent (background and border alpha 0) and ignore pointer
 * events, since the user shouldn't interact with them. They exist purely as
 * endpoints for the long-range edges produced by `layout.ts`.
 */
const sideHandleStyle: React.CSSProperties = {
  background: "rgba(255,255,255,0.0)",
  border: "1px solid rgba(255,255,255,0.0)",
  width: 3,
  height: 3,
  pointerEvents: "none",
};

/** Node component registered as `op` in `nodeTypes`. */
function OpNodeImpl({ data, selected }: NodeProps) {
  // NOTE(review): JSX markup missing in this copy (see file header).
  return ( <> );
}

/** Node component registered as `cluster` in `nodeTypes`. */
function ClusterNodeImpl({ data, selected }: NodeProps) {
  // NOTE(review): JSX markup missing in this copy (see file header).
  return ( <> );
}

/**
 * Visual container ("subflow parent") that wraps a module's children,
 * registered as `group` in `nodeTypes`.
 *
 * Styled like an HTML fieldset with a legend:
 * - a rounded rectangle with a thin border;
 * - a label tag laid OVER the top-left segment of that border, so the title
 *   looks "carved" into it.
 *
 * This is the layout convention used by hfviewer.com. It gives each module a
 * clear visual boundary without occupying interior space. It is meant to be
 * non-interactive (no handles, no tooltip), so the children inside stay the
 * visual focus. When `repeatCount` is a number greater than 1, an `×N` badge
 * is shown next to the label.
 */
function GroupNodeImpl({ data }: NodeProps) {
  // Fallback color when the node data supplies none.
  const color = data.color ?? "#506280";
  const repeatCount = data.repeatCount;
  // Only a numeric count above 1 earns the `×N` badge.
  const hasRepeat = typeof repeatCount === "number" && repeatCount > 1;
  // NOTE(review): JSX markup missing in this copy (see file header).
  return ( {data.label} {hasRepeat && ( ×{repeatCount} )} );
}

export const OpNode = memo(OpNodeImpl);
export const ClusterNode = memo(ClusterNodeImpl);
export const GroupNode = memo(GroupNodeImpl);

export const nodeTypes = {
  op: OpNode,
  cluster: ClusterNode,
  group: GroupNode,
} as const;