Spaces:
Sleeping
Sleeping
echoboi Claude Sonnet 4.6 committed on
Commit ·
5931b5f
1
Parent(s): 86a79eb
Add cell_errors, error_regions, common_error_patterns to submit_rule response
Browse files
Tracks the worst-performing test state during evaluation and computes:
- cell_errors: up to 10 wrong cells with 8-neighbor context
- error_regions: quadrant error distribution + worst state accuracy
- common_error_patterns: the 5 most frequent (predicted, expected) pairs, each with its count
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
app.py
CHANGED
|
@@ -14,6 +14,7 @@ import random
|
|
| 14 |
import threading
|
| 15 |
import time
|
| 16 |
import uuid
|
|
|
|
| 17 |
from typing import Any
|
| 18 |
|
| 19 |
import numpy as np
|
|
@@ -269,18 +270,67 @@ def submit_rule(session_id: str, body: SubmitRequest, _: None = Depends(_auth)):
|
|
| 269 |
test_states = _generate_test_states(env, n=500, seed=eval_seed)
|
| 270 |
total_cell_acc = 0.0
|
| 271 |
exact_matches = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
for state in test_states:
|
| 273 |
expected = env.get_true_next(state)
|
| 274 |
try:
|
| 275 |
predicted = fn(state.copy())
|
| 276 |
if isinstance(predicted, np.ndarray) and predicted.shape == expected.shape:
|
| 277 |
# Partial credit: fraction of cells predicted correctly
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
if np.array_equal(predicted, expected):
|
| 280 |
exact_matches += 1
|
| 281 |
except Exception:
|
| 282 |
pass
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
# functional_accuracy = mean cell-level accuracy across all test states
|
| 285 |
accuracy = total_cell_acc / len(test_states)
|
| 286 |
|
|
@@ -326,6 +376,9 @@ def submit_rule(session_id: str, body: SubmitRequest, _: None = Depends(_auth)):
|
|
| 326 |
"agent_dl": agent_dl,
|
| 327 |
"reference_dl": ref_dl,
|
| 328 |
"delta_dl": delta_dl,
|
|
|
|
|
|
|
|
|
|
| 329 |
})
|
| 330 |
|
| 331 |
|
|
|
|
| 14 |
import threading
|
| 15 |
import time
|
| 16 |
import uuid
|
| 17 |
+
from collections import Counter
|
| 18 |
from typing import Any
|
| 19 |
|
| 20 |
import numpy as np
|
|
|
|
| 270 |
test_states = _generate_test_states(env, n=500, seed=eval_seed)
|
| 271 |
total_cell_acc = 0.0
|
| 272 |
exact_matches = 0
|
| 273 |
+
worst_state_acc = 1.0
|
| 274 |
+
worst_pred = None
|
| 275 |
+
worst_exp = None
|
| 276 |
+
worst_inp = None
|
| 277 |
for state in test_states:
|
| 278 |
expected = env.get_true_next(state)
|
| 279 |
try:
|
| 280 |
predicted = fn(state.copy())
|
| 281 |
if isinstance(predicted, np.ndarray) and predicted.shape == expected.shape:
|
| 282 |
# Partial credit: fraction of cells predicted correctly
|
| 283 |
+
state_acc = float(np.mean(predicted == expected))
|
| 284 |
+
total_cell_acc += state_acc
|
| 285 |
+
if state_acc < worst_state_acc:
|
| 286 |
+
worst_state_acc = state_acc
|
| 287 |
+
worst_pred = predicted
|
| 288 |
+
worst_exp = expected
|
| 289 |
+
worst_inp = state
|
| 290 |
if np.array_equal(predicted, expected):
|
| 291 |
exact_matches += 1
|
| 292 |
except Exception:
|
| 293 |
pass
|
| 294 |
|
| 295 |
+
# Compute cell-level diagnostics from worst-performing state
|
| 296 |
+
cell_errors = []
|
| 297 |
+
error_regions = {}
|
| 298 |
+
common_error_patterns = []
|
| 299 |
+
if worst_pred is not None:
|
| 300 |
+
rows, cols = worst_pred.shape
|
| 301 |
+
wrong_mask = worst_pred != worst_exp
|
| 302 |
+
wrong_indices = np.argwhere(wrong_mask)
|
| 303 |
+
for idx in wrong_indices[:10]:
|
| 304 |
+
r, c = int(idx[0]), int(idx[1])
|
| 305 |
+
cell_errors.append({
|
| 306 |
+
"pos": [r, c],
|
| 307 |
+
"center": int(worst_inp[r, c]),
|
| 308 |
+
"N": int(worst_inp[(r - 1) % rows, c]),
|
| 309 |
+
"S": int(worst_inp[(r + 1) % rows, c]),
|
| 310 |
+
"E": int(worst_inp[r, (c + 1) % cols]),
|
| 311 |
+
"W": int(worst_inp[r, (c - 1) % cols]),
|
| 312 |
+
"NW": int(worst_inp[(r - 1) % rows, (c - 1) % cols]),
|
| 313 |
+
"NE": int(worst_inp[(r - 1) % rows, (c + 1) % cols]),
|
| 314 |
+
"SW": int(worst_inp[(r + 1) % rows, (c - 1) % cols]),
|
| 315 |
+
"SE": int(worst_inp[(r + 1) % rows, (c + 1) % cols]),
|
| 316 |
+
"predicted": int(worst_pred[r, c]),
|
| 317 |
+
"expected": int(worst_exp[r, c]),
|
| 318 |
+
})
|
| 319 |
+
mid_r, mid_c = rows // 2, cols // 2
|
| 320 |
+
error_regions = {
|
| 321 |
+
"top_left": int(wrong_mask[:mid_r, :mid_c].sum()),
|
| 322 |
+
"top_right": int(wrong_mask[:mid_r, mid_c:].sum()),
|
| 323 |
+
"bottom_left": int(wrong_mask[mid_r:, :mid_c].sum()),
|
| 324 |
+
"bottom_right": int(wrong_mask[mid_r:, mid_c:].sum()),
|
| 325 |
+
"total": int(wrong_mask.sum()),
|
| 326 |
+
"worst_state_accuracy": round(worst_state_acc, 4),
|
| 327 |
+
}
|
| 328 |
+
wrong_pairs = list(zip(worst_pred[wrong_mask].tolist(), worst_exp[wrong_mask].tolist()))
|
| 329 |
+
common_error_patterns = [
|
| 330 |
+
{"predicted": p, "expected": e, "count": c}
|
| 331 |
+
for (p, e), c in Counter(wrong_pairs).most_common(5)
|
| 332 |
+
]
|
| 333 |
+
|
| 334 |
# functional_accuracy = mean cell-level accuracy across all test states
|
| 335 |
accuracy = total_cell_acc / len(test_states)
|
| 336 |
|
|
|
|
| 376 |
"agent_dl": agent_dl,
|
| 377 |
"reference_dl": ref_dl,
|
| 378 |
"delta_dl": delta_dl,
|
| 379 |
+
"cell_errors": cell_errors,
|
| 380 |
+
"error_regions": error_regions,
|
| 381 |
+
"common_error_patterns": common_error_patterns,
|
| 382 |
})
|
| 383 |
|
| 384 |
|