robot-folding / app /src /content /embeds /folding /subtask-heatmap.html
pepijn223's picture
pepijn223 HF Staff
Rename delta actions to relative actions throughout
9a36854 unverified
raw
history blame
9.44 kB
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<style>
:root { --bg: transparent; --text: #e8eaf0; --subtext: #8b8fa8; --grid: #2a2d3a; --border: #2a2d3a; }
* { box-sizing: border-box; margin: 0; padding: 0; }
body { background: var(--bg); font-family: system-ui, sans-serif; color: var(--text); }
.tooltip {
position: absolute; background: #1a1d27; border: 1px solid var(--border);
border-radius: 8px; padding: 10px 14px; font-size: 13px; pointer-events: none;
opacity: 0; transition: opacity .15s; z-index: 10; min-width: 220px;
box-shadow: 0 4px 16px rgba(0,0,0,.4);
}
.tooltip strong { display: block; margin-bottom: 5px; }
.tooltip-row { display: flex; justify-content: space-between; gap: 12px; margin-top: 3px; font-size: 12px; color: var(--subtext); }
.tooltip-row span:last-child { color: var(--text); font-weight: 600; }
.legend-bar { display: flex; align-items: center; gap: 8px; margin-top: 10px; font-size: 11px; color: var(--subtext); justify-content: center; }
.legend-gradient { height: 10px; width: 180px; border-radius: 5px; flex-shrink: 0; }
</style>
</head>
<body>
<div style="position:relative">
<svg id="hm-chart"></svg>
<div class="tooltip" id="hm-tooltip"></div>
</div>
<div class="legend-bar">
<span>Fast (0s)</span>
<canvas id="lgd" class="legend-gradient" width="180" height="10"></canvas>
<span>Slow (≥120s)</span>
</div>
<script>
/**
 * Renders the sub-task timing heatmap into `#hm-chart`, paints the colour
 * legend on the `#lgd` canvas and wires the hover tooltip (`#hm-tooltip`).
 *
 * Rows are experiments sorted by total success rate (best first); columns are
 * the six sub-tasks; each cell shows the average time for that sub-task, or a
 * dash when no time was recorded. The chart is re-rendered on window resize.
 *
 * Requires the global `d3` to be loaded before it is called.
 */
function _initSubtaskHeatmap() {
  // Times at or above this many seconds all map to the "slow" end of the ramp.
  const SLOW_CAP_S = 120;
  // Tooltip placement relative to the cursor, in px.
  const TOOLTIP_OFFSET_X = 12;
  const TOOLTIP_OFFSET_Y = 70;
  // Gap left between neighbouring cells, in px.
  const CELL_GAP = 3;

  // One entry per experiment; `times` is aligned index-by-index with `subtasks`
  // (null = no time recorded for that sub-task).
  const rawData = [
    {label:"1.1 π0",           series:"1", total_sr:40, times:[null,  19.2,  42.22, 14.33, 19.88, 27.25]},
    {label:"1.2 π0.5",         series:"1", total_sr:20, times:[50,    39.27, 41.5,  12.3,  13.75, 10.75]},
    {label:"1.3 Relative",     series:"1", total_sr:35, times:[null,  19.5,  44.2,  14.8,  30.33, 22.14]},
    {label:"1.4 RABC low",     series:"1", total_sr:15, times:[null,  20.8,  36.62, 10.0,  18.8,  12.67]},
    {label:"1.5 RABC high",    series:"1", total_sr:0,  times:[240,   21.4,  100.0, null,  null,  null ]},
    {label:"1.7 Rel+RABC",     series:"1", total_sr:40, times:[157.5, 19.33, 32.64, 8.9,   11.0,  23.38]},
    {label:"2.1 HQ",           series:"2", total_sr:40, times:[77.5,  11.08, 21.09, 5.45,  5.5,   11.5 ]},
    {label:"2.2 HQ+RABC+Rel",  series:"2", total_sr:75, times:[34.33, 6.25,  12.31, 3.75,  5.31,  8.93 ]},
    {label:"2.3 HQ+mirror",    series:"2", total_sr:5,  times:[49,    14.0,  23.71, 17.5,  11.0,  4.0  ]},
    {label:"2.4 HQ chunk45",   series:"2", total_sr:20, times:[120,   10.09, 41.18, 7.89,  7.33,  10.0 ]},
    {label:"2.5 HQ+RABC+Rel★", series:"2", total_sr:90, times:[62.25, 8.28,  12.0,  5.28,  5.22,  6.83 ]},
  ];

  // Sort rows: best → worst by total_sr (heatmap: top = best). Copy first so
  // rawData keeps its declaration order.
  const data = [...rawData].sort((a, b) => b.total_sr - a.total_sr);
  const subtasks = ["Unfold","Fold 1","Fold 2","Fold 3","Fold 4","Rotation"];
  const seriesColor = (s) => (s === "2" ? "#f7934f" : "#4f8ef7");

  // Colour ramp: green (fast) → yellow → red (slow), capped at SLOW_CAP_S.
  const stops = [[0,[46,200,138]],[0.4,[247,211,79]],[1,[220,60,60]]];
  const colorScale = d3.scaleSequential().domain([0, SLOW_CAP_S])
    .interpolator((t) => {
      // Clamp so out-of-domain inputs cannot extrapolate to invalid rgb channels.
      const u = Math.max(0, Math.min(1, t));
      for (let i = 0; i < stops.length - 1; i++) {
        const [t0, c0] = stops[i];
        const [t1, c1] = stops[i + 1];
        if (u <= t1) {
          const f = (u - t0) / (t1 - t0);
          const mix = (k) => Math.round(c0[k] + (c1[k] - c0[k]) * f);
          return `rgb(${mix(0)},${mix(1)},${mix(2)})`;
        }
      }
      return "rgb(220,60,60)";
    });

  // Draw legend gradient: one 1px column per canvas pixel, with the first
  // column at 0s and the last column exactly at the cap.
  const canvas = document.getElementById("lgd");
  const ctx = canvas ? canvas.getContext("2d") : null;
  if (ctx) {
    const lastX = Math.max(1, canvas.width - 1);
    for (let x = 0; x < canvas.width; x++) {
      ctx.fillStyle = colorScale((x / lastX) * SLOW_CAP_S);
      ctx.fillRect(x, 0, 1, canvas.height);
    }
  }

  const margin = {top:12, right:16, bottom:36, left:120};
  const svg = d3.select("#hm-chart");
  const container = svg.node().parentElement;
  const tooltip = d3.select("#hm-tooltip");

  /** Clears the SVG and redraws the whole chart at the container's current width. */
  function render() {
    svg.selectAll("*").remove();
    const W = container.clientWidth;
    // Never negative, so a hidden / very narrow container cannot yield invalid rect sizes.
    const cellW = Math.max(0, Math.floor((W - margin.left - margin.right) / subtasks.length));
    const cellH = Math.max(34, Math.min(44, cellW * 0.7));
    const rectW = Math.max(0, cellW - CELL_GAP);
    const rectH = cellH - CELL_GAP;
    const H = data.length * cellH + margin.top + margin.bottom;
    svg.attr("width", W).attr("height", H);
    const g = svg.append("g").attr("transform", `translate(${margin.left},${margin.top})`);
    const gridW = cellW * subtasks.length;
    const gridH = cellH * data.length;

    // Column labels (below the grid)
    g.selectAll(".col-lbl").data(subtasks).join("text")
      .attr("x", (_, i) => i * cellW + cellW / 2).attr("y", gridH + 22)
      .attr("text-anchor", "middle").attr("fill", "#8b8fa8").attr("font-size", 11)
      .text((d) => d);

    // Row labels + series stripe
    data.forEach((d, ri) => {
      // Series colour stripe on the far left of the margin
      g.append("rect")
        .attr("x", -margin.left + 2).attr("y", ri * cellH + 2)
        .attr("width", 4).attr("height", cellH - 4).attr("rx", 2)
        .attr("fill", seriesColor(d.series)).attr("opacity", 0.9);
      g.append("text")
        .attr("x", -8).attr("y", ri * cellH + cellH / 2)
        .attr("text-anchor", "end").attr("fill", "#e8eaf0").attr("font-size", 10).attr("font-weight", "500")
        .text(d.label);
      g.append("text")
        .attr("x", -8).attr("y", ri * cellH + cellH / 2 + 11)
        .attr("text-anchor", "end").attr("fill", "#8b8fa8").attr("font-size", 8)
        .text(d.total_sr + "% SR");
    });

    // Cells
    data.forEach((d, ri) => {
      d.times.forEach((val, ci) => {
        const cellG = g.append("g").attr("transform", `translate(${ci * cellW},${ri * cellH})`);

        // `== null` also covers a missing (undefined) entry, which would
        // otherwise throw on val.toFixed below.
        if (val == null) {
          cellG.append("rect").attr("width", rectW).attr("height", rectH).attr("rx", 4)
            .attr("fill", "#1a1d27").attr("stroke", "#2a2d3a");
          cellG.append("text").attr("x", cellW / 2 - 1.5).attr("y", cellH / 2 + 3)
            .attr("text-anchor", "middle").attr("fill", "#3a3d4a").attr("font-size", 10).text("—");
          return;
        }

        const col = colorScale(Math.min(val, SLOW_CAP_S));
        cellG.append("rect").attr("width", rectW).attr("height", rectH).attr("rx", 4)
          .attr("fill", col).style("cursor", "pointer")
          .on("mousemove", (event) => {
            tooltip.style("opacity", 1).html(`
<strong>Experiment ${d.label} · ${subtasks[ci]}</strong>
<div class="tooltip-row"><span>Avg time</span><span>${val.toFixed(1)}s</span></div>
<div class="tooltip-row"><span>Total SR</span><span>${d.total_sr}%</span></div>
<div class="tooltip-row"><span>Series</span><span>${d.series}</span></div>
`);
            const bx = container.getBoundingClientRect();
            const ex = event.clientX - bx.left;
            const ey = event.clientY - bx.top;
            // Clamp with the tooltip's measured width (not a guessed constant)
            // so it stays inside the container on the right-most columns.
            const tipW = tooltip.node().offsetWidth;
            const left = Math.max(0, Math.min(ex + TOOLTIP_OFFSET_X, W - tipW));
            const top = Math.max(ey - TOOLTIP_OFFSET_Y, 0);
            tooltip.style("left", left + "px").style("top", top + "px");
          })
          .on("mouseleave", () => tooltip.style("opacity", 0));

        // Dark text on light cells, light text on dark cells (HSL lightness threshold).
        const lum = d3.hsl(col).l;
        cellG.append("text")
          .attr("x", cellW / 2 - 1.5).attr("y", cellH / 2 + 4)
          .attr("text-anchor", "middle")
          .attr("fill", lum > 0.52 ? "#1a1d27" : "#e8eaf0")
          .attr("font-size", Math.max(8, Math.min(12, cellW * 0.21))).attr("font-weight", "600")
          // Three-digit values drop the decimal.
          .text(val >= 100 ? Math.round(val) + "s" : val.toFixed(1) + "s");
      });
    });

    // Sorted annotation (top-right corner of the grid)
    g.append("text").attr("x", gridW).attr("y", -4).attr("text-anchor", "end")
      .attr("fill", "#8b8fa8").attr("font-size", 9)
      .text("↓ sorted: best → worst");
  }

  render();
  window.addEventListener("resize", render);

  // Per-experiment descriptions, keyed by the same labels as rawData.
  // NOTE(review): nothing in this function reads EXPERIMENTS — either surface it
  // (e.g. in the tooltip) or drop it; kept here so the descriptions are not lost.
  // NOTE(review): the "2.4 HQ chunk45" entry says chunk=45 in `desc` but
  // chunk=50 in `note` — confirm which is correct.
  const EXPERIMENTS = {
    "1.1 π0": { desc:"π0 · all data · 200k steps · MEAN_STD", note:"Base pi0 policy trained from scratch on the full dataset." },
    "1.2 π0.5": { desc:"π0.5 · all data · 200k steps · MEAN_STD", note:"Upgraded to pi0.5 architecture, same data and steps." },
    "1.3 Relative": { desc:"π0.5 · all data · 200k steps · Relative Actions · QUANTILES", note:"Adds Relative Actions on top of 1.2 — actions expressed relative to current state." },
    "1.4 RABC low": { desc:"π0.5 · all data · 200k steps · RABC κ=0.01", note:"Selective Action Reward Model with low κ (≈ mean threshold, not very selective)." },
    "1.5 RABC high": { desc:"π0.5 · all data · 200k steps · RABC κ=0.0215", note:"SARM with κ = mean + ½ std — more selective filtering than 1.4." },
    "1.7 Rel+RABC": { desc:"π0.5 · all data · 200k steps · Relative Actions + RABC κ=0.0215 · QUANTILES", note:"Best of Series 1. Base checkpoint for 2.5." },
    "2.1 HQ": { desc:"π0.5 · HQ data · 100k steps · fine-tune from 1.3", note:"Fine-tunes 1.3 on curated high-quality data only." },
    "2.2 HQ+RABC+Rel": { desc:"π0.5 · HQ data · 100k steps · fine-tune from 1.3 + RABC κ=0.0265 + Relative Actions", note:"Adds RABC on high-quality fine-tune from 1.3." },
    "2.3 HQ+mirror": { desc:"π0.5 · HQ + mirrored · 100k steps · fine-tune from 1.3 + Relative Actions + mirroring", note:"Augments the high-quality dataset with mirrored trajectories." },
    "2.4 HQ chunk45": { desc:"π0.5 · HQ data · 100k steps · fine-tune from 1.3 · chunk=45", note:"Explores chunked action prediction (chunk=50, RTC size=50, execution horizon=35)." },
    "2.5 HQ+RABC+Rel★": { desc:"π0.5 · HQ data · 100k steps · fine-tune from 1.7 + RABC κ=0.0265 + Relative Actions (best)", note:"Top performer. Best overall result." },
  };
}
// Bootstrap: draw immediately when a global `d3` is already present, otherwise
// load D3 from the CDN and draw once it has finished loading.
if (typeof d3 !== "undefined") {
  _initSubtaskHeatmap();
} else {
  // `const` inside this block keeps the element reference out of the global
  // scope (a top-level `var s` in a classic script becomes `window.s`).
  const d3Script = document.createElement("script");
  d3Script.src = "https://cdnjs.cloudflare.com/ajax/libs/d3/7.9.0/d3.min.js";
  d3Script.onload = _initSubtaskHeatmap;
  // Surface a failed CDN load instead of silently leaving the chart blank.
  d3Script.onerror = () => {
    console.error("subtask-heatmap: failed to load d3 from " + d3Script.src);
  };
  document.head.appendChild(d3Script);
}
</script>
</body>
</html>