# Avra98's picture
# Initial code dump (rebuttal-ready snapshot)
# 76de008 verified
from __future__ import annotations
from typing import Dict, List
import numpy as np
from sudoku4x4_11empty.shared_multi_output_policy import (
compute_set_precision_recall,
parse_values_json,
stage_i_consistent_values,
)
def triangular_number(n: int) -> float:
    """Return the n-th triangular number (1 + 2 + ... + n) as a float.

    The argument is coerced with ``int()``; any value that is zero or
    negative after coercion yields ``0.0``.
    """
    count = int(n)
    if count <= 0:
        return 0.0
    # Integer arithmetic first so the result is exact before the float cast.
    return float(count * (count + 1) // 2)
def score_prediction_text(
    *,
    text: str,
    grid: np.ndarray,
    solved: np.ndarray,
    target_cell: tuple[int, int],
    stage_i: int,
    reward_good_value: float,
    penalty_bad_value: float,
    penalty_malformed: float,
    penalty_empty: float,
    penalty_singleton: float,
) -> Dict[str, float | List[int] | str]:
    """Score a raw prediction string for one target cell and report metrics.

    ``text`` is parsed with ``parse_values_json``; the reference value set
    comes from ``stage_i_consistent_values(grid, target_cell=..., stage_i=...)``.
    ``solved`` is reshaped to 4x4 and read at ``target_cell`` to obtain the
    ground-truth value.

    Reward rules (successful parse):
      * ``triangular_number(hits) * reward_good_value`` where ``hits`` counts
        predicted entries found in the reference set (counted per occurrence
        in the predicted list),
      * minus ``penalty_bad_value`` for every predicted entry not in the
        reference set,
      * minus ``penalty_empty`` when nothing was predicted,
      * minus ``penalty_singleton`` when exactly one value was predicted while
        the reference holds more than one; this penalty is zeroed when
        ``stage_i >= 2``.

    When parsing fails the reward is ``-penalty_malformed``, all numeric
    metrics are ``0.0`` and ``format_error`` is ``'parse_failed'``.

    Returns:
        A dict with the reward, parse flags, counts, precision/recall, the
        predicted and reference value lists, and a ``format_error`` string
        (empty on success).
    """
    parse_result = parse_values_json(text)
    consistent = stage_i_consistent_values(grid, target_cell=target_cell, stage_i=stage_i)

    solved_grid = np.asarray(solved, dtype=int).reshape(4, 4)
    truth = int(solved_grid[int(target_cell[0]), int(target_cell[1])])

    # The lone-value penalty only applies below stage 2.
    if int(stage_i) >= 2:
        lone_value_cost = 0.0
    else:
        lone_value_cost = float(penalty_singleton)

    if not parse_result.parse_ok:
        zeroed_metrics = (
            'parse_ok',
            'strict_canonical',
            'num_predicted_values',
            'num_i_consistent_values',
            'num_non_i_consistent_values',
            'includes_ground_truth',
            'value_precision',
            'value_recall',
            'exact_set_match',
        )
        failure: Dict[str, float | List[int] | str] = {'reward': -float(penalty_malformed)}
        failure.update(dict.fromkeys(zeroed_metrics, 0.0))
        failure['predicted_values'] = []
        failure['target_values'] = [int(v) for v in consistent]
        failure['format_error'] = 'parse_failed'
        return failure

    guesses = [int(v) for v in parse_result.values]
    allowed = {int(v) for v in consistent}

    # Every predicted entry is either a hit or a miss, counted per occurrence.
    hits = sum(guess in allowed for guess in guesses)
    misses = len(guesses) - hits

    total = triangular_number(hits) * float(reward_good_value) - float(misses) * float(penalty_bad_value)
    if len(guesses) == 0:
        total -= float(penalty_empty)
    elif len(guesses) == 1 and len(consistent) > 1:
        total -= lone_value_cost

    precision, recall = compute_set_precision_recall(guesses, consistent)

    return {
        'reward': float(total),
        'parse_ok': 1.0,
        'strict_canonical': float(bool(parse_result.strict_canonical)),
        'num_predicted_values': float(len(guesses)),
        'num_i_consistent_values': float(hits),
        'num_non_i_consistent_values': float(misses),
        'includes_ground_truth': float(truth in guesses),
        'value_precision': float(precision),
        'value_recall': float(recall),
        'exact_set_match': float(set(guesses) == allowed),
        'predicted_values': guesses,
        'target_values': [int(v) for v in consistent],
        'format_error': '',
    }