hf-model-viewer / src /components /CustomNodes.tsx
tfrere's picture
tfrere HF Staff
Deploy hf-model-viewer 2026-05-22T16:59:58Z
fc01079 verified
import { memo } from "react";
import { Handle, Position, type NodeProps } from "reactflow";
import { Box, Stack, Tooltip, Typography } from "@mui/material";
import type { IRNode } from "../types";
/**
 * Payload carried in `node.data` by every custom node type in this file
 * (`op`, `cluster`, `group`).
 */
interface NodeData {
  /** Primary text shown in the node body (and as the group title). */
  label: string;
  /** Secondary line rendered under the label; skipped when empty. */
  sublabel: string;
  /**
   * Base fill/border colour. Renderers append two-digit hex alpha suffixes
   * to it, so a six-digit `#rrggbb` string is expected.
   */
  color: string;
  /** Cluster styling: larger radius, gradient fill and a stronger shadow. */
  isCluster: boolean;
  /** Share of the model's parameters in [0, 1]; drives the heatmap glow. */
  intensity?: number;
  /** When > 1, an `×N` badge marks a folded ModuleList. */
  repeatCount?: number;
  /** Hover text used when no `irNode` is attached. */
  tooltip?: string;
  /** Source IR node; enables the richer hover card. */
  irNode?: IRNode;
  /** True when this node is a visual group container (subflow parent). */
  isGroupContainer?: boolean;
}
/**
 * Convert a 0-1 alpha into a 2-digit hex suffix for `#rrggbb` strings.
 *
 * Out-of-range values (including ±Infinity) are clamped to [0, 1]. NaN maps
 * to "00", so the result is always exactly two hex digits and the
 * concatenated colour stays a valid `#rrggbbaa` string.
 */
function alphaHex(a: number): string {
  // Math.min/Math.max propagate NaN, and NaN.toString(16) is "NaN", which
  // would produce an invalid colour such as "#ff0000NaN".
  const safe = Number.isNaN(a) ? 0 : a;
  const clamped = Math.max(0, Math.min(1, safe));
  return Math.round(clamped * 255)
    .toString(16)
    .padStart(2, "0");
}
/**
 * Format a parameter count with a K/M/B suffix and two decimals
 * (e.g. 1234 → "1.23K"). Counts below 1000 are printed as-is.
 *
 * The unit is chosen from the *rounded* value, so counts just below a
 * boundary (e.g. 999_999) are promoted to the next unit ("1.00M") instead
 * of printing "1000.00K".
 */
function formatParams(n: number): string {
  if (n === 0) return "0";
  if (n < 1_000) return n.toString();
  const thousands = (n / 1_000).toFixed(2);
  if (Number(thousands) < 1_000) return `${thousands}K`;
  const millions = (n / 1_000_000).toFixed(2);
  if (Number(millions) < 1_000) return `${millions}M`;
  return `${(n / 1_000_000_000).toFixed(2)}B`;
}
/**
 * Build the hover content for an op/cluster node.
 *
 * Without an `irNode`, falls back to `data.tooltip`, or else to the label
 * and (non-empty) sublabel on separate lines. With an `irNode`, renders a
 * small card containing:
 * - the display label and op type;
 * - the optional description and module path;
 * - the parameter count and the number of weight tensors.
 *
 * `irNode.attrs` values are narrowed with `typeof` at runtime rather than
 * cast. A non-string description/label/path is therefore skipped instead of
 * being handed to React as a child.
 */
function nodeTooltipContent(data: NodeData): React.ReactNode {
  if (!data.irNode) {
    const text =
      data.tooltip ?? [data.label, data.sublabel].filter(Boolean).join("\n");
    // MUI never displays a zero-length string title; keep returning the
    // plain string so an empty tooltip stays hidden.
    if (!text) return text;
    // HTML collapses "\n" into a space by default; pre-line keeps the
    // intended line breaks between label and sublabel.
    return <Box sx={{ whiteSpace: "pre-line" }}>{text}</Box>;
  }
  const n = data.irNode;
  // attrs values are read defensively: anything that is not a string is
  // treated as absent.
  const asString = (v: unknown): string | null =>
    typeof v === "string" ? v : null;
  const desc = asString(n.attrs.description);
  const label = asString(n.attrs.label);
  const moduleP = asString(n.attrs.modulePath);
  // Prefer a precomputed total when present; otherwise sum the own weights.
  const declaredTotal: unknown = n.attrs.totalParams;
  const totalParams =
    typeof declaredTotal === "number"
      ? declaredTotal
      : n.weights.reduce((a, w) => a + w.numParams, 0);
  const isSynthetic = n.attrs.synthetic === true;
  return (
    <Box sx={{ maxWidth: 320 }}>
      <Typography variant="caption" sx={{ fontWeight: 700, display: "block" }}>
        {label ?? data.label}
      </Typography>
      <Typography
        variant="caption"
        sx={{ display: "block", opacity: 0.7, mb: 0.5 }}
      >
        {n.opType}
        {isSynthetic ? " · forward op" : ""}
      </Typography>
      {desc && (
        <Typography variant="caption" sx={{ display: "block", mb: 0.5 }}>
          {desc}
        </Typography>
      )}
      {moduleP && (
        <Typography
          variant="caption"
          sx={{
            display: "block",
            opacity: 0.8,
            fontFamily: "monospace",
            mb: 0.5,
          }}
        >
          {moduleP}
        </Typography>
      )}
      {totalParams > 0 && (
        <Typography variant="caption" sx={{ display: "block", opacity: 0.85 }}>
          {formatParams(totalParams)} parameters
        </Typography>
      )}
      {n.weights.length > 0 && (
        <Typography variant="caption" sx={{ display: "block", opacity: 0.7 }}>
          {n.weights.length} tensor{n.weights.length > 1 ? "s" : ""}
        </Typography>
      )}
    </Box>
  );
}
/** Single-line text that truncates with an ellipsis inside the node. */
const singleLineEllipsis = {
  maxWidth: "100%",
  overflow: "hidden",
  textOverflow: "ellipsis",
  whiteSpace: "nowrap",
} as const;

/**
 * Visual body shared by op and cluster nodes.
 *
 * Renders a rounded, coloured box with a bold label and an optional
 * sublabel. It adds a heatmap glow when `data.intensity` exceeds 0.05, and
 * an `×N` badge just outside the right edge when `data.repeatCount` > 1.
 * Selection takes precedence over the dashed "repeat" border, which takes
 * precedence over the cluster/op borders.
 */
function NodeBody({
  data,
  selected,
}: {
  data: NodeData;
  selected: boolean;
}) {
  const { isCluster, color } = data;
  const heat = data.intensity ?? 0;
  const repeats = data.repeatCount ?? 0;
  const showRepeat = repeats > 1;

  // Extra box-shadow layer: the glow grows with the parameter share and its
  // opacity is capped at 0.85.
  let glow = "";
  if (heat > 0.05) {
    const radius = Math.round(4 + heat * 28);
    glow = `, 0 0 ${radius}px ${color}${alphaHex(Math.min(0.85, heat))}`;
  }

  let border: string;
  if (selected) {
    border = "2px solid #fff";
  } else if (showRepeat) {
    border = "1.5px dashed rgba(255,255,255,0.55)";
  } else if (isCluster) {
    border = "1.5px solid rgba(255,255,255,0.35)";
  } else {
    border = "1px solid rgba(255,255,255,0.18)";
  }

  const baseShadow = isCluster
    ? "0 10px 24px rgba(0,0,0,0.4), inset 0 1px 0 rgba(255,255,255,0.18)"
    : "0 6px 16px rgba(0,0,0,0.25)";
  const background = isCluster
    ? `linear-gradient(160deg, ${color} 0%, ${color}cc 100%)`
    : color;

  return (
    <Box
      sx={{
        position: "relative",
        width: "100%",
        height: "100%",
        display: "flex",
        flexDirection: "column",
        justifyContent: "center",
        alignItems: "center",
        textAlign: "center",
        borderRadius: isCluster ? "12px" : "8px",
        border,
        background,
        boxShadow: baseShadow + glow,
        color: "#0b1020",
        px: isCluster ? 1 : 0.75,
        py: isCluster ? 0.3 : 0.15,
        overflow: "visible",
      }}
    >
      {showRepeat && (
        <Box
          sx={{
            position: "absolute",
            top: "50%",
            left: "calc(100% + 6px)",
            transform: "translateY(-50%)",
            color: "rgba(255,255,255,0.9)",
            fontSize: 13,
            fontWeight: 700,
            letterSpacing: 0.3,
            lineHeight: 1,
            pointerEvents: "none",
            zIndex: 2,
            whiteSpace: "nowrap",
            textShadow: "0 1px 2px rgba(0,0,0,0.6)",
          }}
        >
          ×{repeats}
        </Box>
      )}
      <Stack spacing={isCluster ? 0.1 : 0} alignItems="center" sx={{ width: "100%" }}>
        <Typography
          sx={{
            fontSize: isCluster ? 11 : 10.5,
            fontWeight: 700,
            lineHeight: 1.1,
            color: "#0b1020",
            letterSpacing: 0.05,
            ...singleLineEllipsis,
          }}
        >
          {data.label}
        </Typography>
        {data.sublabel && (
          <Typography
            sx={{
              fontSize: 9,
              lineHeight: 1.1,
              color: "rgba(11,16,32,0.7)",
              fontWeight: 600,
              ...singleLineEllipsis,
            }}
          >
            {data.sublabel}
          </Typography>
        )}
      </Stack>
    </Box>
  );
}
/**
 * Style for the top ("t", target) and bottom ("s", source) handles.
 * They are 1×1 px with a fully transparent fill, no border and pointer
 * events disabled. No connection dot is drawn and the handle cannot be
 * interacted with.
 */
const handleStyle = {
  background: "rgba(255,255,255,0.0)",
  border: "0",
  width: 1,
  height: 1,
  pointerEvents: "none" as const,
};
/**
 * Style for the right-side handles ("s-r" / "t-r") used by skip-connections
 * (residuals). An extra path is routed along the right edge of the column,
 * so it bypasses the chain of forward nodes instead of crossing them. These
 * handles exist purely as endpoints for the long-range edges produced by
 * `layout.ts`.
 *
 * Like `handleStyle`, they are invisible: both the fill and the 1px border
 * have alpha 0, and pointer events are disabled. They are 3×3 px rather
 * than 1×1 px.
 */
const sideHandleStyle: React.CSSProperties = {
  background: "rgba(255,255,255,0.0)",
  border: "1px solid rgba(255,255,255,0.0)",
  width: 3,
  height: 3,
  pointerEvents: "none",
};
/**
 * Shared renderer for the `op` and `cluster` node types. Their markup was
 * identical, so it lives in one place to keep the two from drifting apart.
 * Cluster-specific styling is handled inside `NodeBody` via
 * `data.isCluster`.
 *
 * Handles:
 * - "t": target on the top edge;
 * - "s": source on the bottom edge;
 * - "s-r" / "t-r": source/target on the right edge, for skip-connection
 *   edges.
 *
 * The body is wrapped in a tooltip placed to the right, which opens after
 * a short delay.
 */
function renderHandledNode({ data, selected }: NodeProps<NodeData>) {
  return (
    <>
      <Handle id="t" type="target" position={Position.Top} style={handleStyle} />
      <Handle id="s-r" type="source" position={Position.Right} style={sideHandleStyle} />
      <Handle id="t-r" type="target" position={Position.Right} style={sideHandleStyle} />
      <Tooltip
        title={nodeTooltipContent(data)}
        placement="right"
        arrow
        enterDelay={250}
        enterNextDelay={150}
      >
        <Box sx={{ width: "100%", height: "100%" }}>
          <NodeBody data={data} selected={!!selected} />
        </Box>
      </Tooltip>
      <Handle id="s" type="source" position={Position.Bottom} style={handleStyle} />
    </>
  );
}
/** Component registered as the `op` node type. */
function OpNodeImpl(props: NodeProps<NodeData>) {
  return renderHandledNode(props);
}
/** Component registered as the `cluster` node type. */
function ClusterNodeImpl(props: NodeProps<NodeData>) {
  return renderHandledNode(props);
}
/**
 * Visual container ("subflow parent") that wraps a module's children. It is
 * styled like an HTML <fieldset>: a rounded rectangle with a thin border,
 * plus a label tag laid over the top-left segment of that border so the
 * title looks "carved" into it. This is the layout convention used by
 * hfviewer.com. It gives each module a clear visual boundary without
 * occupying interior space.
 *
 * The container is non-interactive. It renders no handles and no tooltip,
 * its background is transparent, and it sets `pointerEvents: "none"`, so
 * the children drawn inside remain the focus. When `data.repeatCount` > 1,
 * an `×N` badge is drawn just outside the right edge.
 */
function GroupNodeImpl({ data }: NodeProps<NodeData>) {
  // `color` is typed as required; the fallback guards against data that
  // omits it anyway.
  const color = data.color ?? "#506280";
  const repeatCount = data.repeatCount;
  const hasRepeat = typeof repeatCount === "number" && repeatCount > 1;
  return (
    <Box
      sx={{
        width: "100%",
        height: "100%",
        position: "relative",
        borderRadius: "8px",
        border: `1px solid ${color}aa`,
        background: "transparent",
        pointerEvents: "none",
        overflow: "visible",
      }}
    >
      {/* Label tag: shifted up onto the border. Its opaque fill hides the
          border segment behind the text (presumably matching the canvas
          background — confirm if the theme changes). */}
      <Box
        sx={{
          position: "absolute",
          top: -6,
          left: 10,
          px: 0.5,
          background: "#0b1224",
          lineHeight: 1,
        }}
      >
        <Typography
          sx={{
            fontSize: 9,
            fontWeight: 700,
            letterSpacing: 0.4,
            color: `${color}dd`,
            whiteSpace: "nowrap",
            pointerEvents: "none",
            lineHeight: 1,
            fontFamily:
              "ui-monospace, SFMono-Regular, Menlo, Consolas, monospace",
          }}
        >
          {data.label}
        </Typography>
      </Box>
      {hasRepeat && (
        <Box
          sx={{
            position: "absolute",
            top: "50%",
            left: "calc(100% + 8px)",
            transform: "translateY(-50%)",
            color: `${color}ee`,
            fontSize: 16,
            fontWeight: 700,
            letterSpacing: 0.4,
            lineHeight: 1,
            pointerEvents: "none",
            whiteSpace: "nowrap",
            textShadow: "0 1px 2px rgba(0,0,0,0.6)",
          }}
        >
          ×{repeatCount}
        </Box>
      )}
    </Box>
  );
}
/** Memoized `op` node: skips re-rendering when its props are shallow-equal. */
export const OpNode = memo(OpNodeImpl);
/** Memoized `cluster` node: skips re-rendering when its props are shallow-equal. */
export const ClusterNode = memo(ClusterNodeImpl);
/** Memoized `group` container node. */
export const GroupNode = memo(GroupNodeImpl);

/**
 * Node-type registry mapping each `node.type` string to its component
 * (presumably passed as React Flow's `nodeTypes` prop). It is defined once
 * at module scope, so its object identity stays stable across renders.
 */
export const nodeTypes = { op: OpNode, cluster: ClusterNode, group: GroupNode } as const;