Spaces:
Sleeping
Sleeping
CarolinePascal commited on
fix(lag computation): fixing lag computation by switching from delta states - raw actions correlation to delta states - delta actions correlation
Browse files
src/app/[org]/[dataset]/[episode]/fetch-data.ts
CHANGED
|
@@ -2082,14 +2082,12 @@ export async function loadCrossEpisodeActionVariance(
|
|
| 2082 |
|
| 2083 |
for (let pi = 0; pi < pairs.length; pi++) {
|
| 2084 |
const [ai, si] = pairs[pi];
|
| 2085 |
-
const
|
| 2086 |
-
const sDeltas = Array.from(
|
| 2087 |
-
|
| 2088 |
-
|
| 2089 |
-
);
|
| 2090 |
-
const
|
| 2091 |
-
const aM = aVals.slice(0, effN).reduce((a, b) => a + b, 0) / effN;
|
| 2092 |
-
const sM = sDeltas.slice(0, effN).reduce((a, b) => a + b, 0) / effN;
|
| 2093 |
|
| 2094 |
for (let li = 0; li < numLags; li++) {
|
| 2095 |
const lag = -maxLag + li;
|
|
@@ -2098,8 +2096,8 @@ export async function loadCrossEpisodeActionVariance(
|
|
| 2098 |
sV = 0;
|
| 2099 |
for (let t = 0; t < effN; t++) {
|
| 2100 |
const sIdx = t + lag;
|
| 2101 |
-
if (sIdx < 0 || sIdx >=
|
| 2102 |
-
const a =
|
| 2103 |
s = sDeltas[sIdx] - sM;
|
| 2104 |
sum += a * s;
|
| 2105 |
aV += a * a;
|
|
|
|
| 2082 |
|
| 2083 |
for (let pi = 0; pi < pairs.length; pi++) {
|
| 2084 |
const [ai, si] = pairs[pi];
|
| 2085 |
+
const aDeltas = Array.from({ length: n - 1 }, (_, t) => (actions[t + 1][ai] ?? 0) - (actions[t][ai] ?? 0));
|
| 2086 |
+
const sDeltas = Array.from({ length: n - 1 }, (_, t) => (states[t + 1][si] ?? 0) - (states[t][si] ?? 0));
|
| 2087 |
+
const effN = aDeltas.length;
|
| 2088 |
+
if (effN < 4) continue;
|
| 2089 |
+
const aM = aDeltas.reduce((a, b) => a + b, 0) / effN;
|
| 2090 |
+
const sM = sDeltas.reduce((a, b) => a + b, 0) / effN;
|
|
|
|
|
|
|
| 2091 |
|
| 2092 |
for (let li = 0; li < numLags; li++) {
|
| 2093 |
const lag = -maxLag + li;
|
|
|
|
| 2096 |
sV = 0;
|
| 2097 |
for (let t = 0; t < effN; t++) {
|
| 2098 |
const sIdx = t + lag;
|
| 2099 |
+
if (sIdx < 0 || sIdx >= effN) continue;
|
| 2100 |
+
const a = aDeltas[t] - aM,
|
| 2101 |
s = sDeltas[sIdx] - sM;
|
| 2102 |
sum += a * s;
|
| 2103 |
aV += a * a;
|
src/components/action-insights-panel.tsx
CHANGED
|
@@ -1281,15 +1281,14 @@ function StateActionAlignmentSection({
|
|
| 1281 |
}
|
| 1282 |
if (pairs.length === 0) return null;
|
| 1283 |
|
| 1284 |
-
// Per-pair cross-correlation
|
| 1285 |
const pairCorrs: number[][] = [];
|
| 1286 |
for (const [aKey, sKey] of pairs) {
|
| 1287 |
-
const
|
| 1288 |
-
const sDeltas = data
|
| 1289 |
-
|
| 1290 |
-
|
| 1291 |
-
const
|
| 1292 |
-
const aM = aVals.slice(0, n).reduce((a, b) => a + b, 0) / n;
|
| 1293 |
const sM = sDeltas.slice(0, n).reduce((a, b) => a + b, 0) / n;
|
| 1294 |
|
| 1295 |
const corrs: number[] = [];
|
|
@@ -1299,8 +1298,8 @@ function StateActionAlignmentSection({
|
|
| 1299 |
sV = 0;
|
| 1300 |
for (let t = 0; t < n; t++) {
|
| 1301 |
const sIdx = t + lag;
|
| 1302 |
-
if (sIdx < 0 || sIdx >=
|
| 1303 |
-
const a =
|
| 1304 |
s = sDeltas[sIdx] - sM;
|
| 1305 |
sum += a * s;
|
| 1306 |
aV += a * a;
|
|
@@ -1407,7 +1406,7 @@ function StateActionAlignmentSection({
|
|
| 1407 |
</h3>
|
| 1408 |
<InfoToggle>
|
| 1409 |
<p className="text-xs text-slate-400">
|
| 1410 |
-
Per-dimension cross-correlation between action<sub>d</sub>(t) and
|
| 1411 |
Δstate<sub>d</sub>(t+lag), aggregated as
|
| 1412 |
<span className="text-orange-400"> max</span>,{" "}
|
| 1413 |
<span className="text-slate-200">mean</span>, and
|
|
|
|
| 1281 |
}
|
| 1282 |
if (pairs.length === 0) return null;
|
| 1283 |
|
| 1284 |
+
// Per-pair cross-correlation (Δaction vs Δstate)
|
| 1285 |
const pairCorrs: number[][] = [];
|
| 1286 |
for (const [aKey, sKey] of pairs) {
|
| 1287 |
+
const aDeltas = data.slice(1).map((row, i) => (row[aKey] ?? 0) - (data[i][aKey] ?? 0));
|
| 1288 |
+
const sDeltas = data.slice(1).map((row, i) => (row[sKey] ?? 0) - (data[i][sKey] ?? 0));
|
| 1289 |
+
const n = Math.min(aDeltas.length, sDeltas.length);
|
| 1290 |
+
if (n < 4) { pairCorrs.push(Array(2 * maxLag + 1).fill(0)); continue; }
|
| 1291 |
+
const aM = aDeltas.slice(0, n).reduce((a, b) => a + b, 0) / n;
|
|
|
|
| 1292 |
const sM = sDeltas.slice(0, n).reduce((a, b) => a + b, 0) / n;
|
| 1293 |
|
| 1294 |
const corrs: number[] = [];
|
|
|
|
| 1298 |
sV = 0;
|
| 1299 |
for (let t = 0; t < n; t++) {
|
| 1300 |
const sIdx = t + lag;
|
| 1301 |
+
if (sIdx < 0 || sIdx >= n) continue;
|
| 1302 |
+
const a = aDeltas[t] - aM,
|
| 1303 |
s = sDeltas[sIdx] - sM;
|
| 1304 |
sum += a * s;
|
| 1305 |
aV += a * a;
|
|
|
|
| 1406 |
</h3>
|
| 1407 |
<InfoToggle>
|
| 1408 |
<p className="text-xs text-slate-400">
|
| 1409 |
+
Per-dimension cross-correlation between Δaction<sub>d</sub>(t) and
|
| 1410 |
Δstate<sub>d</sub>(t+lag), aggregated as
|
| 1411 |
<span className="text-orange-400"> max</span>,{" "}
|
| 1412 |
<span className="text-slate-200">mean</span>, and
|