Spaces:
Sleeping
Sleeping
| import { memo } from "react"; | |
| import { Handle, Position, type NodeProps } from "reactflow"; | |
| import { Box, Stack, Tooltip, Typography } from "@mui/material"; | |
| import type { IRNode } from "../types"; | |
/**
 * Payload carried in `node.data` for every node renderer in this file
 * (op, cluster and group nodes all receive `NodeProps<NodeData>`).
 */
interface NodeData {
  // ---- Always present ----------------------------------------------------

  /** Primary text shown on the node body, and the label of a group container. */
  label: string;
  /** Secondary line under the label; only rendered when non-empty. */
  sublabel: string;
  /**
   * Base fill / accent color. Alpha suffixes (e.g. `cc`, `aa`) are appended to
   * it directly, so it is expected to be a `#rrggbb` string.
   */
  color: string;
  /** Switches the node body to the larger, gradient "cluster" card styling. */
  isCluster: boolean;

  // ---- Optional decorations ----------------------------------------------

  /** 0-1 share of the model's parameters. Drives the "heatmap" glow. */
  intensity?: number;
  /** If set and > 1, render an `×N` badge to indicate a folded ModuleList. */
  repeatCount?: number;
  /** Optional hover text, used when no `irNode` is available. */
  tooltip?: string;
  /** Reference to the original IRNode for richer hover content. */
  irNode?: IRNode;
  /** True when this is a visual group container (subflow parent). */
  isGroupContainer?: boolean;
}
/**
 * Convert a 0-1 alpha into a 2-digit lowercase hex suffix for `#rrggbb`
 * strings (e.g. `0.5` → `"80"`, `1` → `"ff"`).
 *
 * Values outside [0, 1] are clamped. `NaN` is treated as 0, so the result is
 * always a valid two-character hex byte.
 *
 * @param a - Opacity in the range 0-1.
 * @returns Two lowercase hex digits, `"00"`..`"ff"`.
 */
function alphaHex(a: number): string {
  // Math.min/Math.max propagate NaN, which would otherwise produce the
  // 3-character string "NaN" and an invalid color.
  const clamped = Number.isNaN(a) ? 0 : Math.max(0, Math.min(1, a));
  return Math.round(clamped * 255)
    .toString(16)
    .padStart(2, "0");
}
/**
 * Format a parameter count compactly. Values below 1 000 are shown as-is;
 * larger values get two decimals and a K / M / B suffix
 * (e.g. 1_500 → "1.50K", 2_500_000 → "2.50M").
 *
 * The unit is chosen from the *rounded* mantissa. Values that would round up
 * to 1000 of one unit are promoted to the next unit, so 999_999 renders as
 * "1.00M" rather than "1000.00K". B is the largest unit and is never promoted.
 *
 * @param n - Number of parameters.
 * @returns Human-readable count.
 */
function formatParams(n: number): string {
  if (n === 0) return "0";
  if (n < 1_000) return n.toString();

  const scaled = (divisor: number): string => (n / divisor).toFixed(2);

  const thousands = scaled(1_000);
  if (Number(thousands) < 1_000) return `${thousands}K`;

  const millions = scaled(1_000_000);
  if (Number(millions) < 1_000) return `${millions}M`;

  return `${scaled(1_000_000_000)}B`;
}
/**
 * Build the hover-tooltip content for an op/cluster node.
 *
 * Without an `irNode`, returns `data.tooltip`, or "label\nsublabel" when no
 * tooltip is set. With an `irNode`, renders a richer card containing:
 * - the display label
 * - the op type (tagged "· forward op" for synthetic nodes)
 * - the optional description and module path
 * - the total parameter count
 * - the number of weight tensors
 *
 * `irNode.attrs` values are untyped, so each one is narrowed at runtime
 * instead of being cast. A wrongly-typed attribute is ignored rather than
 * rendered or passed to `formatParams`.
 *
 * @param data - The node's display data.
 * @returns A string or React element suitable for an MUI `Tooltip` title.
 */
function nodeTooltipContent(data: NodeData): React.ReactNode {
  if (!data.irNode) {
    return data.tooltip ?? `${data.label}\n${data.sublabel}`;
  }
  const n = data.irNode;

  const rawDesc: unknown = n.attrs.description;
  const desc = typeof rawDesc === "string" ? rawDesc : null;
  const rawLabel: unknown = n.attrs.label;
  const label = typeof rawLabel === "string" ? rawLabel : null;
  const rawModulePath: unknown = n.attrs.modulePath;
  const moduleP = typeof rawModulePath === "string" ? rawModulePath : null;

  // Prefer a precomputed total when it is a real number. Otherwise sum the
  // node's own weight tensors.
  const rawTotal: unknown = n.attrs.totalParams;
  const totalParams =
    typeof rawTotal === "number" && Number.isFinite(rawTotal)
      ? rawTotal
      : n.weights.reduce((sum, w) => sum + w.numParams, 0);
  const isSynthetic = n.attrs.synthetic === true;

  return (
    <Box sx={{ maxWidth: 320 }}>
      <Typography variant="caption" sx={{ fontWeight: 700, display: "block" }}>
        {label ?? data.label}
      </Typography>
      <Typography
        variant="caption"
        sx={{ display: "block", opacity: 0.7, mb: 0.5 }}
      >
        {n.opType}
        {isSynthetic ? " · forward op" : ""}
      </Typography>
      {desc && (
        <Typography variant="caption" sx={{ display: "block", mb: 0.5 }}>
          {desc}
        </Typography>
      )}
      {moduleP && (
        <Typography
          variant="caption"
          sx={{
            display: "block",
            opacity: 0.8,
            fontFamily: "monospace",
            mb: 0.5,
          }}
        >
          {moduleP}
        </Typography>
      )}
      {totalParams > 0 && (
        <Typography variant="caption" sx={{ display: "block", opacity: 0.85 }}>
          {formatParams(totalParams)} parameters
        </Typography>
      )}
      {n.weights.length > 0 && (
        <Typography variant="caption" sx={{ display: "block", opacity: 0.7 }}>
          {n.weights.length} tensor{n.weights.length > 1 ? "s" : ""}
        </Typography>
      )}
    </Box>
  );
}
/**
 * Visual body shared by op and cluster nodes. It is a rounded, colored card
 * with:
 * - a bold, ellipsized label and an optional sublabel
 * - an optional `×N` repeat badge floating just outside its right edge
 * - an optional parameter-share "heatmap" glow
 *
 * Border precedence is: selected (solid white), then repeated (dashed), then
 * cluster, then plain.
 *
 * @param data - Display data for the node.
 * @param selected - Whether the node is currently selected.
 */
function NodeBody({
  data,
  selected,
}: {
  data: NodeData;
  selected: boolean;
}) {
  // `intensity` is documented as a 0-1 share. Clamp it (and map NaN to 0) so
  // an out-of-range value cannot produce an arbitrarily large glow.
  const rawIntensity = data.intensity ?? 0;
  const intensity = Number.isNaN(rawIntensity)
    ? 0
    : Math.max(0, Math.min(1, rawIntensity));
  const glowOpacity = Math.min(0.85, intensity);
  const glowSize = Math.round(4 + intensity * 28);
  // Shares at or below 5% get no glow. The glow color appends a 2-digit alpha
  // to `data.color`, so it assumes `data.color` is a `#rrggbb` string.
  const heatmapShadow =
    intensity > 0.05
      ? `, 0 0 ${glowSize}px ${data.color}${alphaHex(glowOpacity)}`
      : "";
  const repeatCount = data.repeatCount ?? 0;
  const hasRepeat = repeatCount > 1;
  return (
    <Box
      sx={{
        position: "relative",
        width: "100%",
        height: "100%",
        display: "flex",
        flexDirection: "column",
        justifyContent: "center",
        alignItems: "center",
        textAlign: "center",
        borderRadius: data.isCluster ? "12px" : "8px",
        border: selected
          ? "2px solid #fff"
          : hasRepeat
            ? "1.5px dashed rgba(255,255,255,0.55)"
            : data.isCluster
              ? "1.5px solid rgba(255,255,255,0.35)"
              : "1px solid rgba(255,255,255,0.18)",
        background: data.isCluster
          ? `linear-gradient(160deg, ${data.color} 0%, ${data.color}cc 100%)`
          : data.color,
        boxShadow:
          (data.isCluster
            ? "0 10px 24px rgba(0,0,0,0.4), inset 0 1px 0 rgba(255,255,255,0.18)"
            : "0 6px 16px rgba(0,0,0,0.25)") + heatmapShadow,
        color: "#0b1020",
        px: data.isCluster ? 1 : 0.75,
        py: data.isCluster ? 0.3 : 0.15,
        // Visible overflow lets the repeat badge sit outside the card.
        overflow: "visible",
      }}
    >
      {hasRepeat && (
        <Box
          sx={{
            position: "absolute",
            top: "50%",
            left: "calc(100% + 6px)",
            transform: "translateY(-50%)",
            color: "rgba(255,255,255,0.9)",
            fontSize: 13,
            fontWeight: 700,
            letterSpacing: 0.3,
            lineHeight: 1,
            pointerEvents: "none",
            zIndex: 2,
            whiteSpace: "nowrap",
            textShadow: "0 1px 2px rgba(0,0,0,0.6)",
          }}
        >
          ×{repeatCount}
        </Box>
      )}
      <Stack spacing={data.isCluster ? 0.1 : 0} alignItems="center" sx={{ width: "100%" }}>
        <Typography
          sx={{
            fontSize: data.isCluster ? 11 : 10.5,
            fontWeight: 700,
            lineHeight: 1.1,
            color: "#0b1020",
            letterSpacing: 0.05,
            maxWidth: "100%",
            overflow: "hidden",
            textOverflow: "ellipsis",
            whiteSpace: "nowrap",
          }}
        >
          {data.label}
        </Typography>
        {data.sublabel && (
          <Typography
            sx={{
              fontSize: 9,
              lineHeight: 1.1,
              color: "rgba(11,16,32,0.7)",
              fontWeight: 600,
              maxWidth: "100%",
              overflow: "hidden",
              textOverflow: "ellipsis",
              whiteSpace: "nowrap",
            }}
          >
            {data.sublabel}
          </Typography>
        )}
      </Stack>
    </Box>
  );
}
/**
 * Style for the primary top/bottom connection handles. Each is a 1×1px,
 * borderless dot with a fully transparent background that ignores pointer
 * events. It serves only as an invisible, non-interactive edge endpoint.
 */
const handleStyle = {
  background: "rgba(255,255,255,0.0)",
  border: "0",
  width: 1,
  height: 1,
  pointerEvents: "none",
} as const;
/**
 * Style for the right-side handles ("s-r" source / "t-r" target). These are
 * the endpoints for skip-connection (residual) edges. Such an edge is routed
 * along the right edge of the column, so it bypasses the chain of forward
 * nodes instead of crossing them. The long-range edges themselves are
 * produced by `layout.ts`.
 *
 * The handles exist purely as edge endpoints that the user shouldn't
 * interact with. Their background and border are fully transparent (alpha 0)
 * and they ignore pointer events. They are 3×3px, slightly larger than the
 * 1×1px top/bottom handles, but just as invisible.
 */
const sideHandleStyle: React.CSSProperties = {
  background: "rgba(255,255,255,0.0)",
  border: "1px solid rgba(255,255,255,0.0)",
  width: 3,
  height: 3,
  pointerEvents: "none",
};
/**
 * Shared render tree for op and cluster nodes. It contains:
 * - a top target handle ("t") and a bottom source handle ("s")
 * - right-side source/target handles ("s-r" / "t-r") for skip connections
 * - the node body, wrapped in a delayed hover tooltip placed to the right
 *
 * Op vs. cluster appearance is decided inside `NodeBody` from
 * `data.isCluster`, not by which node type renders it.
 *
 * @param data - Display data for the node.
 * @param selected - Selection flag reported for the node (may be undefined).
 */
function renderNodeWithHandles(data: NodeData, selected: boolean | undefined) {
  return (
    <>
      <Handle id="t" type="target" position={Position.Top} style={handleStyle} />
      <Handle id="s-r" type="source" position={Position.Right} style={sideHandleStyle} />
      <Handle id="t-r" type="target" position={Position.Right} style={sideHandleStyle} />
      <Tooltip
        title={nodeTooltipContent(data)}
        placement="right"
        arrow
        enterDelay={250}
        enterNextDelay={150}
      >
        <Box sx={{ width: "100%", height: "100%" }}>
          <NodeBody data={data} selected={!!selected} />
        </Box>
      </Tooltip>
      <Handle id="s" type="source" position={Position.Bottom} style={handleStyle} />
    </>
  );
}

/** Renderer for a single op node. */
function OpNodeImpl({ data, selected }: NodeProps<NodeData>) {
  return renderNodeWithHandles(data, selected);
}

/** Renderer for a cluster node (same handles/tooltip as an op node). */
function ClusterNodeImpl({ data, selected }: NodeProps<NodeData>) {
  return renderNodeWithHandles(data, selected);
}
/**
 * Visual container ("subflow parent") that wraps a module's children.
 *
 * It is styled like an HTML <fieldset>: a rounded rectangle with a thin
 * border in the module's color, and a monospace label tag laid over the
 * top-left segment of that border so the title looks "carved" into it. This
 * is the layout convention used by hfviewer.com.
 *
 * The container's body is transparent and it is non-interactive
 * (`pointerEvents: "none"`, no handles, no tooltip). The children inside
 * therefore remain the visual focus, and the boundary occupies no interior
 * space.
 *
 * When `data.repeatCount > 1`, an `×N` badge is drawn just outside the
 * right edge to indicate a folded ModuleList.
 *
 * `data.color` falls back to "#506280". Alpha suffixes (`aa`, `dd`, `ee`) are
 * appended to it, so it is expected to be a `#rrggbb` string.
 */
function GroupNodeImpl({ data }: NodeProps<NodeData>) {
  const color = data.color ?? "#506280";
  const repeatCount = data.repeatCount;
  const hasRepeat = typeof repeatCount === "number" && repeatCount > 1;
  return (
    <Box
      sx={{
        width: "100%",
        height: "100%",
        position: "relative",
        borderRadius: "8px",
        border: `1px solid ${color}aa`,
        background: "transparent",
        pointerEvents: "none",
        overflow: "visible",
      }}
    >
      {/* NOTE(review): the solid #0b1224 background presumably matches the
          canvas color so the border looks interrupted behind the label —
          confirm against the theme. */}
      <Box
        sx={{
          position: "absolute",
          top: -6,
          left: 10,
          px: 0.5,
          background: "#0b1224",
          lineHeight: 1,
        }}
      >
        <Typography
          sx={{
            fontSize: 9,
            fontWeight: 700,
            letterSpacing: 0.4,
            color: `${color}dd`,
            whiteSpace: "nowrap",
            pointerEvents: "none",
            lineHeight: 1,
            fontFamily:
              "ui-monospace, SFMono-Regular, Menlo, Consolas, monospace",
          }}
        >
          {data.label}
        </Typography>
      </Box>
      {hasRepeat && (
        <Box
          sx={{
            position: "absolute",
            top: "50%",
            left: "calc(100% + 8px)",
            transform: "translateY(-50%)",
            color: `${color}ee`,
            fontSize: 16,
            fontWeight: 700,
            letterSpacing: 0.4,
            lineHeight: 1,
            pointerEvents: "none",
            whiteSpace: "nowrap",
            textShadow: "0 1px 2px rgba(0,0,0,0.6)",
          }}
        >
          ×{repeatCount}
        </Box>
      )}
    </Box>
  );
}
// Each renderer is wrapped in React.memo, so a node skips re-rendering when
// its props are shallowly equal to the previous render's.
export const OpNode = memo(OpNodeImpl);
export const ClusterNode = memo(ClusterNodeImpl);
export const GroupNode = memo(GroupNodeImpl);

/**
 * Maps the node type names "op", "cluster" and "group" to their memoized
 * renderers. It is declared once at module scope, so its object identity
 * never changes.
 * NOTE(review): presumably handed to React Flow's `nodeTypes` prop — confirm
 * at the call site.
 */
export const nodeTypes = { op: OpNode, cluster: ClusterNode, group: GroupNode } as const;