File size: 5,419 Bytes
0573fbf
 
 
c7986a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0573fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
c7986a9
0573fbf
c7986a9
0573fbf
 
 
 
 
 
 
 
 
 
c7986a9
 
0573fbf
 
 
 
 
 
 
 
 
 
c7986a9
0573fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7986a9
0573fbf
 
 
 
c7986a9
 
0573fbf
c7986a9
 
0573fbf
 
 
c7986a9
0573fbf
 
c7986a9
 
 
 
 
 
 
0573fbf
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import React, { useState } from 'react';
import { lossChartStyles } from '../theme';

// Default smoothing factor for smoothEMA (exponential moving average).
// Each smoothed value is alpha * current + (1 - alpha) * previous smoothed
// value, so alpha controls smoothness:
//   alpha → 1   = no smoothing (output equals input)
//   alpha → 0   = heavy smoothing (output flat-ish line)
// Diffusion loss is intrinsically noisy because each step samples a random
// timestep with different difficulty, so a small alpha (heavy smoothing) is
// what makes the underlying trend visible.
const EMA_ALPHA = 0.06;

/**
 * Exponential moving average of a numeric series.
 *
 * The first output equals the first input; every later output is
 * `alpha * values[i] + (1 - alpha) * previousOutput`. The input is not
 * modified.
 *
 * @param {number[]} values - Series to smooth.
 * @param {number} [alpha=EMA_ALPHA] - Weight given to the current sample.
 * @returns {number[]} New array with the same length as `values`
 *   (empty when `values` is empty).
 */
function smoothEMA(values, alpha = EMA_ALPHA) {
    if (values.length === 0) {
        return [];
    }

    let running = values[0];
    const smoothed = [running];
    let idx = 1;
    while (idx < values.length) {
        // Blend the current sample with the running average so far.
        running = alpha * values[idx] + (1 - alpha) * running;
        smoothed.push(running);
        idx += 1;
    }
    return smoothed;
}

export default function LossChart({ data, width = 600, height = 200 }) {
    const [hover, setHover] = useState(null);
    const padding = lossChartStyles.padding;
    const colors = lossChartStyles.colors;
    const axisFontSize = lossChartStyles.axisFontSize;
    const tooltip = lossChartStyles.tooltip;

    if (!data || data.length === 0) return null;

    const innerW = width - padding.left - padding.right;
    const innerH = height - padding.top - padding.bottom;

    const xs = data.map(d => d.step);
    const ys = data.map(d => d.loss);
    const smoothed = smoothEMA(ys);
    const xMin = Math.min(...xs);
    const xMax = Math.max(...xs);
    const yMin = Math.min(...ys);
    const yMax = Math.max(...ys);
    const xRange = xMax - xMin || 1;
    const yRange = yMax - yMin || 1;

    const xScale = v => padding.left + ((v - xMin) / xRange) * innerW;
    const yScale = v => padding.top + innerH - ((v - yMin) / yRange) * innerH;

    const points = data.map(d => `${xScale(d.step)},${yScale(d.loss)}`).join(' ');
    const smoothedPoints = data.map((d, i) => `${xScale(d.step)},${yScale(smoothed[i])}`).join(' ');

    const yTicks = 4;
    const xTicks = Math.min(5, data.length);

    const handleMove = (e) => {
        const rect = e.currentTarget.getBoundingClientRect();
        const px = ((e.clientX - rect.left) / rect.width) * width;
        let nearest = data[0];
        let bestDist = Infinity;
        for (const d of data) {
            const dist = Math.abs(xScale(d.step) - px);
            if (dist < bestDist) { bestDist = dist; nearest = d; }
        }
        setHover(nearest);
    };

    return (
        <svg
            viewBox={`0 0 ${width} ${height}`}
            preserveAspectRatio="none"
            style={lossChartStyles.svg}
            onMouseMove={handleMove}
            onMouseLeave={() => setHover(null)}
        >
            {Array.from({ length: yTicks + 1 }, (_, i) => {
                const v = yMin + (yRange * i) / yTicks;
                const y = yScale(v);
                return (
                    <g key={`y${i}`}>
                        <line x1={padding.left} x2={width - padding.right} y1={y} y2={y}
                            stroke={colors.grid} strokeDasharray="3 3" />
                        <text x={padding.left - 6} y={y + 4} textAnchor="end"
                            fontSize={axisFontSize} fill={colors.axis}>{v.toFixed(3)}</text>
                    </g>
                );
            })}

            {Array.from({ length: xTicks }, (_, i) => {
                const v = xMin + (xRange * i) / Math.max(xTicks - 1, 1);
                const x = xScale(v);
                return (
                    <text key={`x${i}`} x={x} y={height - padding.bottom + 16}
                          textAnchor="middle" fontSize={axisFontSize} fill={colors.axis}>
                        {Math.round(v)}
                    </text>
                );
            })}

            {/* Raw loss as a faint background trace (the "peaky" curve). */}
            <polyline fill="none" stroke={colors.line} strokeWidth="1" strokeOpacity="0.25" points={points} />

            {/* EMA-smoothed loss as the primary trace. */}
            <polyline fill="none" stroke={colors.line} strokeWidth="2.5" points={smoothedPoints} />

            {hover && (
                <g>
                    <line x1={xScale(hover.step)} x2={xScale(hover.step)}
                          y1={padding.top} y2={height - padding.bottom}
                          stroke={colors.axis} strokeDasharray="2 2" />
                    <circle cx={xScale(hover.step)} cy={yScale(hover.loss)} r="3" fill={colors.line} fillOpacity="0.4" />
                    <circle cx={xScale(hover.step)} cy={yScale(smoothed[data.indexOf(hover)] ?? hover.loss)} r="4" fill={colors.line} />
                    <g transform={`translate(${Math.min(xScale(hover.step) + 8, width - (tooltip.width + 10))}, ${padding.top + 6})`}>
                        <rect width={tooltip.width} height={tooltip.height + 12} rx={tooltip.rx} fill={colors.tooltipBg} stroke={colors.tooltipBorder} />
                        <text x={tooltip.textX} y={tooltip.timeY} fontSize={axisFontSize} fill={colors.tooltipText}>Step: {hover.step}</text>
                        <text x={tooltip.textX} y={tooltip.lossY} fontSize={axisFontSize} fill={colors.tooltipText}>Raw: {hover.loss.toFixed(4)}</text>
                        <text x={tooltip.textX} y={tooltip.lossY + 14} fontSize={axisFontSize} fill={colors.tooltipText}>Smoothed: {(smoothed[data.indexOf(hover)] ?? hover.loss).toFixed(4)}</text>
                    </g>
                </g>
            )}
        </svg>
    );
}