"""Reward scoring for multi-value cell predictions.

Provides :func:`score_prediction_text`, which parses a model's text output,
compares the predicted values for one target cell against the stage-consistent
target values, and returns a scalar reward together with diagnostic metrics.
"""

from __future__ import annotations

import os
import sys
from typing import Dict, List

import numpy as np

# Make the parent directory importable before the package import below; the
# import is deliberately placed after this sys.path adjustment.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(CURRENT_DIR)
if PARENT_DIR not in sys.path:
    sys.path.insert(0, PARENT_DIR)

from multi_output_cell_policy.shared_multi_output_policy import (  # noqa: E402
    compute_set_precision_recall,
    parse_values_json,
    stage_i_consistent_values,
)


def triangular_number(n: int) -> float:
    """Return the n-th triangular number ``n * (n + 1) / 2`` as a float.

    ``n`` is coerced with ``int``; zero and negative inputs yield ``0.0``.
    """
    count = int(n)
    if count <= 0:
        return 0.0
    return float(count * (count + 1) // 2)


def _parse_failure_result(
    penalty_malformed: float, target_values
) -> Dict[str, float | List[int] | str]:
    """Build the metrics payload returned when the prediction text fails to parse."""
    return {
        "reward": -float(penalty_malformed),
        "parse_ok": 0.0,
        "strict_canonical": 0.0,
        "num_predicted_values": 0.0,
        "num_i_consistent_values": 0.0,
        "num_non_i_consistent_values": 0.0,
        "num_missing_values": float(len(target_values)),
        "includes_ground_truth": 0.0,
        "value_precision": 0.0,
        "value_recall": 0.0,
        "exact_set_match": 0.0,
        "predicted_values": [],
        "target_values": [int(v) for v in target_values],
        "format_error": "parse_failed",
    }


def score_prediction_text(
    *,
    text: str,
    grid: np.ndarray,
    solved: np.ndarray,
    target_cell: tuple[int, int],
    stage_i: int,
    reward_good_value: float,
    penalty_bad_value: float,
    penalty_malformed: float,
    penalty_empty: float,
    penalty_singleton: float,
    penalty_missing: float = 0.0,
    exact_match_bonus: float = 0.0,
    cardinality_mismatch_penalty: float = 0.0,
) -> Dict[str, float | List[int] | str]:
    """Score one predicted value set for ``target_cell`` and report metrics.

    Args:
        text: Raw model output, handed to ``parse_values_json``.
        grid: Puzzle grid passed through to ``stage_i_consistent_values``.
        solved: Solved grid; must be reshapeable to 9x9. Only the entry at
            ``target_cell`` is read.
        target_cell: ``(row, col)`` index of the cell being scored.
        stage_i: Stage index; selects the target set and gates the singleton
            penalty (disabled when ``stage_i >= 2``).
        reward_good_value: Multiplier on the triangular number of in-target
            predictions.
        penalty_bad_value: Cost per predicted value outside the target set.
        penalty_malformed: Cost returned (negated) when parsing fails.
        penalty_empty: Cost for a parsed but empty prediction.
        penalty_singleton: Cost for a single-value prediction against a
            multi-value target (stage 1 only).
        penalty_missing: Cost per target value left unpredicted.
        exact_match_bonus: Bonus when the non-empty predicted set equals the
            target set.
        cardinality_mismatch_penalty: Cost when fewer values are predicted
            than a multi-value target contains, at any stage.

    Returns:
        A dict with the scalar ``"reward"``, float-valued diagnostic metrics,
        the ``"predicted_values"`` / ``"target_values"`` lists and a
        ``"format_error"`` string (``"parse_failed"`` or ``""``).
    """
    parsed = parse_values_json(text)
    target_values = stage_i_consistent_values(grid, target_cell=target_cell, stage_i=stage_i)

    solved_grid = np.asarray(solved, dtype=int).reshape(9, 9)
    row, col = int(target_cell[0]), int(target_cell[1])
    solved_value = int(solved_grid[row, col])

    # Legacy gating: from stage 2 onwards the original singleton penalty is
    # switched off. Under-prediction pressure at those stages comes from
    # cardinality_mismatch_penalty (when > 0); at stage 1 both may stack.
    if int(stage_i) >= 2:
        singleton_penalty = 0.0
    else:
        singleton_penalty = float(penalty_singleton)

    if not parsed.parse_ok:
        return _parse_failure_result(penalty_malformed, target_values)

    predicted_values = [int(v) for v in parsed.values]
    target_set = {int(v) for v in target_values}

    # NOTE(review): counts below are per occurrence. If parse_values_json can
    # return repeated values, a repeated in-target value is counted more than
    # once (hence the clamp on num_missing) -- confirm the parser de-duplicates.
    num_good = len([v for v in predicted_values if v in target_set])
    num_bad = len(predicted_values) - num_good
    num_missing = max(0, len(target_set) - num_good)
    is_exact = bool(predicted_values) and set(predicted_values) == target_set

    # Base reward grows super-linearly with the number of good values, while
    # every wrong extra value carries a flat cost.
    reward = triangular_number(num_good) * float(reward_good_value)
    reward -= float(num_bad) * float(penalty_bad_value)

    # Charge for each unpredicted target value so recall feeds the signal.
    if num_missing > 0:
        reward -= float(num_missing) * float(penalty_missing)

    # The bonus applies only to a non-empty exact match, so the optimum
    # strictly dominates partial supersets.
    if is_exact:
        reward += float(exact_match_bonus)

    predicted_count = len(predicted_values)
    target_count = len(target_values)
    multi_value_target = target_count > 1

    if predicted_count == 0:
        reward -= float(penalty_empty)

    if predicted_count == 1 and multi_value_target:
        reward -= singleton_penalty

    # Stage-agnostic pressure against predicting strictly fewer values than a
    # multi-value target contains.
    if multi_value_target and predicted_count < target_count:
        reward -= float(cardinality_mismatch_penalty)

    precision, recall = compute_set_precision_recall(predicted_values, target_values)

    return {
        "reward": float(reward),
        "parse_ok": 1.0,
        "strict_canonical": float(bool(parsed.strict_canonical)),
        "num_predicted_values": float(predicted_count),
        "num_i_consistent_values": float(num_good),
        "num_non_i_consistent_values": float(num_bad),
        "num_missing_values": float(num_missing),
        "includes_ground_truth": float(solved_value in predicted_values),
        "value_precision": float(precision),
        "value_recall": float(recall),
        "exact_set_match": float(is_exact),
        "predicted_values": predicted_values,
        "target_values": [int(v) for v in target_values],
        "format_error": "",
    }