File size: 4,517 Bytes
76de008
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from __future__ import annotations

import os
import sys
from typing import Dict, List

import numpy as np

# Put this file's parent directory at the front of sys.path (once), presumably so
# that the `multi_output_cell_policy` package imported right after this block
# resolves regardless of the working directory -- TODO confirm the package layout.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(CURRENT_DIR)
if PARENT_DIR not in sys.path:
    sys.path.insert(0, PARENT_DIR)

from multi_output_cell_policy.shared_multi_output_policy import (
    compute_set_precision_recall,
    parse_values_json,
    stage_i_consistent_values,
)


def triangular_number(n: int) -> float:
    """Return the n-th triangular number (0 + 1 + ... + n) as a float.

    ``n`` is truncated with ``int``; anything at or below zero yields ``0.0``.
    """
    count = int(n)
    if count <= 0:
        return 0.0
    # count * (count + 1) is always even, so the floor division is exact.
    return float((count * count + count) // 2)


def score_prediction_text(
    *,
    text: str,
    grid: np.ndarray,
    solved: np.ndarray,
    target_cell: tuple[int, int],
    stage_i: int,
    reward_good_value: float,
    penalty_bad_value: float,
    penalty_malformed: float,
    penalty_empty: float,
    penalty_singleton: float,
    penalty_missing: float = 0.0,
    exact_match_bonus: float = 0.0,
    cardinality_mismatch_penalty: float = 0.0,
) -> Dict[str, float | List[int] | str]:
    """Score a model response that lists candidate values for one grid cell.

    The response is parsed with ``parse_values_json`` and compared against the
    target values returned by ``stage_i_consistent_values`` for ``target_cell``.
    Set-based quantities (good count, missing count, exact match, singleton and
    cardinality checks) are computed over the *distinct* predicted values, so
    repeating a value can neither earn extra triangular reward nor mask missing
    target values. Wrong values are still charged once per occurrence. If the
    parser already returns unique values this is identical to counting the raw
    list.

    Args:
        text: Raw model output handed to ``parse_values_json``.
        grid: Puzzle grid forwarded to ``stage_i_consistent_values``.
        solved: Solved grid; reshaped to 9x9 and read at ``target_cell``.
        target_cell: ``(row, col)`` index of the cell being scored.
        stage_i: Stage index forwarded to ``stage_i_consistent_values``; when
            ``>= 2`` the singleton penalty is disabled.
        reward_good_value: Multiplier on ``triangular_number(num_good)``.
        penalty_bad_value: Amount subtracted per predicted value not in the target.
        penalty_malformed: Returned (negated) as the reward when parsing fails.
        penalty_empty: Subtracted when the parsed prediction is empty.
        penalty_singleton: Subtracted (stage 1 only) when a single distinct value
            is predicted for a multi-value target.
        penalty_missing: Amount subtracted per target value that was not predicted.
        exact_match_bonus: Added when the non-empty predicted set equals the target.
        cardinality_mismatch_penalty: Subtracted when fewer distinct values than
            target values are predicted for a multi-value target.

    Returns:
        A dict with the scalar ``reward``, diagnostic counters and flags as
        floats, the ``predicted_values`` / ``target_values`` lists, and
        ``format_error`` (``"parse_failed"`` or ``""``).
    """
    parsed = parse_values_json(text)
    target_values = stage_i_consistent_values(grid, target_cell=target_cell, stage_i=stage_i)
    solved_value = int(np.asarray(solved, dtype=int).reshape(9, 9)[int(target_cell[0]), int(target_cell[1])])
    # Legacy gating preserved: at stage>=2 the original singleton penalty is off by default.
    # Under-prediction pressure at stage>=2 is supplied by the new cardinality_mismatch_penalty
    # below (if > 0). At stage 1, both penalties may stack.
    singleton_penalty = 0.0 if int(stage_i) >= 2 else float(penalty_singleton)

    if not parsed.parse_ok:
        return {
            "reward": -float(penalty_malformed),
            "parse_ok": 0.0,
            "strict_canonical": 0.0,
            "num_predicted_values": 0.0,
            "num_i_consistent_values": 0.0,
            "num_non_i_consistent_values": 0.0,
            "num_missing_values": float(len(target_values)),
            "includes_ground_truth": 0.0,
            "value_precision": 0.0,
            "value_recall": 0.0,
            "exact_set_match": 0.0,
            "predicted_values": [],
            "target_values": [int(v) for v in target_values],
            "format_error": "parse_failed",
        }

    predicted_values = [int(v) for v in parsed.values]
    # Distinct view of the prediction: repeats of one value are a single set member.
    predicted_set = set(predicted_values)
    target_set = {int(v) for v in target_values}
    num_good = len(predicted_set & target_set)
    # Wrong values stay occurrence-based so duplicated garbage is not made cheaper.
    num_bad = sum(1 for v in predicted_values if v not in target_set)
    num_missing = len(target_set - predicted_set)
    is_exact = bool(predicted_set) and predicted_set == target_set

    # Base reward: encourage larger all-good sets while making extra wrong values expensive.
    reward = triangular_number(num_good) * float(reward_good_value) - float(num_bad) * float(
        penalty_bad_value
    )
    # Directly penalize missing target values so recall is part of the optimization signal.
    if num_missing > 0:
        reward -= float(num_missing) * float(penalty_missing)
    # Bonus only when the predicted set exactly matches the target (and is non-empty),
    # so the optimum strictly dominates partial supersets.
    if is_exact:
        reward += float(exact_match_bonus)
    if not predicted_values:
        reward -= float(penalty_empty)

    multi_value_target = len(target_values) > 1
    if len(predicted_set) == 1 and multi_value_target:
        reward -= singleton_penalty
    # Stage-agnostic cardinality-mismatch pressure for multi-value targets.
    # Fires whenever the prediction has strictly fewer distinct values than the
    # target set (the dominant failure mode for stage>=2 multi-value cells).
    if len(predicted_set) < len(target_values) and multi_value_target:
        reward -= float(cardinality_mismatch_penalty)

    precision, recall = compute_set_precision_recall(predicted_values, target_values)
    return {
        "reward": float(reward),
        "parse_ok": 1.0,
        "strict_canonical": 1.0 if parsed.strict_canonical else 0.0,
        "num_predicted_values": float(len(predicted_values)),
        "num_i_consistent_values": float(num_good),
        "num_non_i_consistent_values": float(num_bad),
        "num_missing_values": float(num_missing),
        "includes_ground_truth": 1.0 if solved_value in predicted_set else 0.0,
        "value_precision": float(precision),
        "value_recall": float(recall),
        "exact_set_match": 1.0 if is_exact else 0.0,
        "predicted_values": predicted_values,
        "target_values": [int(v) for v in target_values],
        "format_error": "",
    }