import React, { useState } from 'react';
import { lossChartStyles } from '../theme';

// Exponential moving average. alpha controls smoothness:
//   alpha → 1 = no smoothing (output equals input)
//   alpha → 0 = heavy smoothing (output flat-ish line)
// Diffusion loss is intrinsically noisy because each step samples a random
// timestep with different difficulty, so a small alpha (heavy smoothing) is
// what makes the underlying trend visible.
const EMA_ALPHA = 0.06;

/**
 * Smooth a numeric series with an exponential moving average.
 *
 * out[0] is seeded with values[0] (not 0), so the smoothed curve starts on the
 * first raw sample instead of ramping up from zero. Each later output is
 * alpha * values[i] + (1 - alpha) * out[i - 1].
 *
 * @param {number[]} values - Raw series; only read, never mutated.
 * @param {number} [alpha=EMA_ALPHA] - Weight given to the newest sample.
 * @returns {number[]} A new array the same length as `values` ([] for empty input).
 */
function smoothEMA(values, alpha = EMA_ALPHA) {
  if (values.length === 0) return [];
  const out = [values[0]];
  for (let i = 1; i < values.length; i++) {
    out.push(alpha * values[i] + (1 - alpha) * out[i - 1]);
  }
  return out;
}

/**
 * Loss-curve chart: raw loss drawn as a faint background trace, EMA-smoothed
 * loss as the primary trace, numeric tick labels on both axes, and a hover
 * tooltip showing the step, raw loss and smoothed loss of the nearest point.
 *
 * @param {Object} props
 * @param {Array<{step: number, loss: number}>} props.data - Points to plot.
 *   Presumably sorted by `step` (the traces connect points in array order);
 *   verify against callers.
 * @param {number} [props.width=600] - Chart width in chart units.
 * @param {number} [props.height=200] - Chart height in chart units.
 * @returns {JSX.Element|null} null when `data` is missing or empty.
 */
export default function LossChart({ data, width = 600, height = 200 }) {
  // Hook is called before the early return below so the hook order is the
  // same on every render (React Rules of Hooks). Holds the hovered data point.
  const [hover, setHover] = useState(null);
  const padding = lossChartStyles.padding;
  // colors / axisFontSize / tooltip are not referenced by the visible code;
  // presumably consumed by the JSX attributes (see NOTE above `return`).
  const colors = lossChartStyles.colors;
  const axisFontSize = lossChartStyles.axisFontSize;
  const tooltip = lossChartStyles.tooltip;

  if (!data || data.length === 0) return null;

  // Plot area inside the padding.
  const innerW = width - padding.left - padding.right;
  const innerH = height - padding.top - padding.bottom;

  const xs = data.map(d => d.step);
  const ys = data.map(d => d.loss);
  const smoothed = smoothEMA(ys);

  // NOTE(review): spreading into Math.min/Math.max passes one argument per
  // point; very long runs may exceed the engine's argument-count limit
  // (RangeError). A reduce-based min/max would avoid that — confirm expected
  // data sizes.
  const xMin = Math.min(...xs);
  const xMax = Math.max(...xs);
  // The y-range comes from the raw values only. Each EMA output is a convex
  // combination of inputs (0 <= alpha <= 1), so the smoothed trace always
  // fits inside [yMin, yMax] as well.
  const yMin = Math.min(...ys);
  const yMax = Math.max(...ys);
  // `|| 1` avoids dividing by zero when every step (or every loss) is equal.
  const xRange = xMax - xMin || 1;
  const yRange = yMax - yMin || 1;

  // Data -> chart coordinates. Screen y grows downward, so yScale is inverted:
  // yMin maps to the bottom of the plot area and yMax to the top.
  const xScale = v => padding.left + ((v - xMin) / xRange) * innerW;
  const yScale = v => padding.top + innerH - ((v - yMin) / yRange) * innerH;

  // Space-separated "x,y" pairs (the SVG polyline `points` format); presumably
  // fed to the raw and smoothed traces in the markup.
  const points = data.map(d => `${xScale(d.step)},${yScale(d.loss)}`).join(' ');
  const smoothedPoints = data.map((d, i) => `${xScale(d.step)},${yScale(smoothed[i])}`).join(' ');

  // yTicks intervals => yTicks + 1 labels; at most 5 x labels.
  const yTicks = 4;
  const xTicks = Math.min(5, data.length);

  // Pointer handler: finds the data point whose x position is nearest the
  // cursor and stores it as `hover`. Linear scan over `data` on every move.
  const handleMove = (e) => {
    const rect = e.currentTarget.getBoundingClientRect();
    // Convert CSS pixels to chart units. Assumes the rendered element maps a
    // `width`-unit coordinate space onto rect.width (e.g. via a viewBox) —
    // TODO confirm against the markup.
    const px = ((e.clientX - rect.left) / rect.width) * width;
    let nearest = data[0];
    let bestDist = Infinity;
    for (const d of data) {
      const dist = Math.abs(xScale(d.step) - px);
      if (dist < bestDist) { bestDist = dist; nearest = d; }
    }
    setHover(nearest);
  };

  // NOTE(review): in this copy the JSX element tags appear to have been
  // stripped — only attribute tails (e.g. `setHover(null)} >`), child
  // expressions and JSX comments remain, so the expression below does not
  // parse. Presumably it was an <svg> root with tick labels, two traces built
  // from `points` / `smoothedPoints`, and a hover tooltip. Restore the markup
  // from version control.
  // Tooltip detail: `data.indexOf(hover)` relies on `hover` being an element of
  // the current `data` array. If `data` is replaced, indexOf returns -1,
  // smoothed[-1] is undefined, and `??` falls back to the raw loss.
  // The x-tick step divides by Math.max(xTicks - 1, 1) so a single data point
  // does not divide by zero.
  return ( setHover(null)} > {Array.from({ length: yTicks + 1 }, (_, i) => { const v = yMin + (yRange * i) / yTicks; const y = yScale(v); return ( {v.toFixed(3)} ); })} {Array.from({ length: xTicks }, (_, i) => { const v = xMin + (xRange * i) / Math.max(xTicks - 1, 1); const x = xScale(v); return ( {Math.round(v)} ); })} {/* Raw loss as a faint background trace (the "peaky" curve). */} {/* EMA-smoothed loss as the primary trace. */} {hover && ( Step: {hover.step} Raw: {hover.loss.toFixed(4)} Smoothed: {(smoothed[data.indexOf(hover)] ?? hover.loss).toFixed(4)} )} );
}