echoboi Claude Sonnet 4.6 commited on
Commit
5931b5f
·
1 Parent(s): 86a79eb

Add cell_errors, error_regions, common_error_patterns to submit_rule response

Browse files

Tracks worst-performing test state during evaluation and computes:
- cell_errors: up to 10 wrong cells with 8-neighbor context
- error_regions: quadrant error distribution + worst state accuracy
- common_error_patterns: top-5 (predicted, expected, count) pairs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +54 -1
app.py CHANGED
@@ -14,6 +14,7 @@ import random
14
  import threading
15
  import time
16
  import uuid
 
17
  from typing import Any
18
 
19
  import numpy as np
@@ -269,18 +270,67 @@ def submit_rule(session_id: str, body: SubmitRequest, _: None = Depends(_auth)):
269
  test_states = _generate_test_states(env, n=500, seed=eval_seed)
270
  total_cell_acc = 0.0
271
  exact_matches = 0
 
 
 
 
272
  for state in test_states:
273
  expected = env.get_true_next(state)
274
  try:
275
  predicted = fn(state.copy())
276
  if isinstance(predicted, np.ndarray) and predicted.shape == expected.shape:
277
  # Partial credit: fraction of cells predicted correctly
278
- total_cell_acc += float(np.mean(predicted == expected))
 
 
 
 
 
 
279
  if np.array_equal(predicted, expected):
280
  exact_matches += 1
281
  except Exception:
282
  pass
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  # functional_accuracy = mean cell-level accuracy across all test states
285
  accuracy = total_cell_acc / len(test_states)
286
 
@@ -326,6 +376,9 @@ def submit_rule(session_id: str, body: SubmitRequest, _: None = Depends(_auth)):
326
  "agent_dl": agent_dl,
327
  "reference_dl": ref_dl,
328
  "delta_dl": delta_dl,
 
 
 
329
  })
330
 
331
 
 
14
  import threading
15
  import time
16
  import uuid
17
+ from collections import Counter
18
  from typing import Any
19
 
20
  import numpy as np
 
270
  test_states = _generate_test_states(env, n=500, seed=eval_seed)
271
  total_cell_acc = 0.0
272
  exact_matches = 0
273
+ worst_state_acc = 1.0
274
+ worst_pred = None
275
+ worst_exp = None
276
+ worst_inp = None
277
  for state in test_states:
278
  expected = env.get_true_next(state)
279
  try:
280
  predicted = fn(state.copy())
281
  if isinstance(predicted, np.ndarray) and predicted.shape == expected.shape:
282
  # Partial credit: fraction of cells predicted correctly
283
+ state_acc = float(np.mean(predicted == expected))
284
+ total_cell_acc += state_acc
285
+ if state_acc < worst_state_acc:
286
+ worst_state_acc = state_acc
287
+ worst_pred = predicted
288
+ worst_exp = expected
289
+ worst_inp = state
290
  if np.array_equal(predicted, expected):
291
  exact_matches += 1
292
  except Exception:
293
  pass
294
 
295
+ # Compute cell-level diagnostics from worst-performing state
296
+ cell_errors = []
297
+ error_regions = {}
298
+ common_error_patterns = []
299
+ if worst_pred is not None:
300
+ rows, cols = worst_pred.shape
301
+ wrong_mask = worst_pred != worst_exp
302
+ wrong_indices = np.argwhere(wrong_mask)
303
+ for idx in wrong_indices[:10]:
304
+ r, c = int(idx[0]), int(idx[1])
305
+ cell_errors.append({
306
+ "pos": [r, c],
307
+ "center": int(worst_inp[r, c]),
308
+ "N": int(worst_inp[(r - 1) % rows, c]),
309
+ "S": int(worst_inp[(r + 1) % rows, c]),
310
+ "E": int(worst_inp[r, (c + 1) % cols]),
311
+ "W": int(worst_inp[r, (c - 1) % cols]),
312
+ "NW": int(worst_inp[(r - 1) % rows, (c - 1) % cols]),
313
+ "NE": int(worst_inp[(r - 1) % rows, (c + 1) % cols]),
314
+ "SW": int(worst_inp[(r + 1) % rows, (c - 1) % cols]),
315
+ "SE": int(worst_inp[(r + 1) % rows, (c + 1) % cols]),
316
+ "predicted": int(worst_pred[r, c]),
317
+ "expected": int(worst_exp[r, c]),
318
+ })
319
+ mid_r, mid_c = rows // 2, cols // 2
320
+ error_regions = {
321
+ "top_left": int(wrong_mask[:mid_r, :mid_c].sum()),
322
+ "top_right": int(wrong_mask[:mid_r, mid_c:].sum()),
323
+ "bottom_left": int(wrong_mask[mid_r:, :mid_c].sum()),
324
+ "bottom_right": int(wrong_mask[mid_r:, mid_c:].sum()),
325
+ "total": int(wrong_mask.sum()),
326
+ "worst_state_accuracy": round(worst_state_acc, 4),
327
+ }
328
+ wrong_pairs = list(zip(worst_pred[wrong_mask].tolist(), worst_exp[wrong_mask].tolist()))
329
+ common_error_patterns = [
330
+ {"predicted": p, "expected": e, "count": c}
331
+ for (p, e), c in Counter(wrong_pairs).most_common(5)
332
+ ]
333
+
334
  # functional_accuracy = mean cell-level accuracy across all test states
335
  accuracy = total_cell_acc / len(test_states)
336
 
 
376
  "agent_dl": agent_dl,
377
  "reference_dl": ref_dl,
378
  "delta_dl": delta_dl,
379
+ "cell_errors": cell_errors,
380
+ "error_regions": error_regions,
381
+ "common_error_patterns": common_error_patterns,
382
  })
383
 
384