File size: 7,293 Bytes
71dc210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
Anti-Exploit Protections for Data-Centric RL Environment.

Centralised module for all anti-hacking checks:
  1. Input truncation (>200 chars → truncate, -0.02 penalty)
  2. Validate spam prevention (cooldown + reward that turns into a penalty after the free quota)
  3. Recommendation ID staleness check
  4. Ground truth immutability check (row count must not change)
  5. Catastrophic data loss detection (fewer than 50% of original rows remain)
  6. Duplicate apply prevention
  7. Max applies per session (3)
  8. Episode wall-clock timeout (5 min → forced submit, -0.10)
  9. Step timeout (5 sec → timeout obs, -0.05)
"""

import logging
import time
from dataclasses import dataclass, field
from typing import Optional, Set

logger = logging.getLogger(__name__)

# ── Tunable limits used by the checks below ──────────────────────────────────
MAX_ACTION_CHARS = 200           # actions longer than this are truncated
MAX_APPLIES_PER_SESSION = 3      # applies allowed before a re-query is required
FREE_VALIDATES = 3               # validates that earn a reward before penalties start
VALIDATE_COOLDOWN = 2            # non-validate actions required between two validates
EPISODE_TIMEOUT_SECS = 300       # wall-clock budget per episode (5 minutes)
STEP_TIMEOUT_SECS = 5            # wall-clock budget per step, in seconds


# ── Exploit tracker (per episode state) ──────────────────────────────────────

@dataclass
class AntiExploitState:
    """Mutable per-episode bookkeeping read and updated by the anti-exploit checks."""

    # -- validate bookkeeping --------------------------------------------------
    validate_call_count: int = 0            # validates recorded so far this episode
    steps_since_last_validate: int = 0      # non-validate actions since the last validate

    # -- apply bookkeeping (per query session) ---------------------------------
    applied_ids_this_session: Set[int] = field(default_factory=set)   # fresh set per instance
    applies_this_session: int = 0

    # -- timing ----------------------------------------------------------------
    episode_start_time: float = field(default_factory=time.time)      # time.time() at creation

    # -- integrity -------------------------------------------------------------
    ground_truth_row_count: int = 0         # ground-truth row count, set at reset


# ── 1. Input truncation ───────────────────────────────────────────────────────

def check_and_truncate_input(action: str) -> tuple[str, float, bool]:
    """
    Clamp ``action`` to at most MAX_ACTION_CHARS characters.

    Returns an ``(action, penalty, was_truncated)`` triple. Over-long input is
    cut to its first MAX_ACTION_CHARS characters with a -0.02 penalty and a
    logged warning. Anything else passes through untouched with 0.0.
    """
    original_length = len(action)
    if original_length <= MAX_ACTION_CHARS:
        return action, 0.0, False

    logger.warning(
        "Input truncated: original length %d > %d", original_length, MAX_ACTION_CHARS
    )
    clipped = action[:MAX_ACTION_CHARS]
    return clipped, -0.02, True


# ── 2. Validate cooldown ──────────────────────────────────────────────────────

def check_validate_cooldown(state: AntiExploitState) -> tuple[bool, str]:
    """
    Decide whether a validate action may run right now.

    Returns ``(allowed, error_message)``. The first validate of an episode is
    always allowed. Afterwards, at least VALIDATE_COOLDOWN non-validate steps
    must have elapsed since the previous validate.
    """
    remaining = VALIDATE_COOLDOWN - state.steps_since_last_validate
    if remaining <= 0 or state.validate_call_count <= 0:
        return True, ""
    return False, (
        f"Validate on cooldown. Take {remaining} "
        f"more action(s) before validating again."
    )


def get_validate_reward(state: AntiExploitState) -> float:
    """
    Reward for a validate call, based on ``state.validate_call_count``.

    While the count is below FREE_VALIDATES the reward is +0.02; from then on
    each validate costs -0.01.
    """
    within_free_quota = state.validate_call_count < FREE_VALIDATES
    return 0.02 if within_free_quota else -0.01


def record_validate(state: AntiExploitState):
    """Register one validate call: bump the counter and restart the cooldown."""
    state.steps_since_last_validate = 0
    state.validate_call_count = state.validate_call_count + 1


def record_non_validate_step(state: AntiExploitState):
    """Advance the validate-cooldown counter by one non-validate action."""
    state.steps_since_last_validate = state.steps_since_last_validate + 1


# ── 3. Recommendation staleness ───────────────────────────────────────────────

def check_recommendation_staleness(
    rec_id: int,
    current_session_id: str,
    recommendation_session_id: str,
) -> tuple[bool, str]:
    """
    Confirm that a recommendation belongs to the current query session.

    Returns ``(is_fresh, error_message)``. The message is empty when the two
    session IDs are equal; otherwise it names ``rec_id`` and asks for a
    re-query.
    """
    if current_session_id == recommendation_session_id:
        return True, ""
    stale_msg = (
        f"Stale recommendation ID {rec_id}. "
        "Please re-query for fresh recommendations."
    )
    return False, stale_msg


# ── 4. Ground truth immutability ──────────────────────────────────────────────

def assert_ground_truth_intact(
    ground_truth_len: int,
    original_gt_len: int,
) -> tuple[bool, str]:
    """
    Check that the ground-truth row count has not changed since reset.

    Despite its name this function does not raise. It returns ``(True, "")``
    when the two counts match. Otherwise it logs a message at CRITICAL level
    and returns ``(False, message)``.
    """
    if ground_truth_len == original_gt_len:
        return True, ""
    # Plain ASCII arrow: the previous text held a mis-decoded "→" ("β†’"),
    # which showed up garbled in logs.
    msg = (
        f"INTEGRITY VIOLATION: ground_truth row count changed "
        f"({original_gt_len} -> {ground_truth_len}). This should never happen."
    )
    logger.critical(msg)
    return False, msg


# ── 5. Catastrophic data loss ─────────────────────────────────────────────────

def check_catastrophic_data_loss(
    current_rows: int,
    original_rows: int,
    threshold: float = 0.50,
) -> tuple[bool, str]:
    """
    Returns (is_catastrophic, message).

    Data loss is catastrophic when the fraction of original rows still
    present is strictly below ``threshold`` (default 0.50, i.e. 50%). On a
    catastrophic loss the message is logged at ERROR level. Otherwise the
    message is empty.

    If ``original_rows`` is zero (or negative), nothing could have been lost,
    so the result is never catastrophic.
    """
    if original_rows <= 0:
        # Without this guard, 0 remaining of 0 original evaluates to a 0% ratio
        # and an empty dataset would be flagged as catastrophic.
        return False, ""
    ratio = current_rows / original_rows
    if ratio < threshold:
        msg = (
            f"CATASTROPHIC DATA LOSS: only {current_rows}/{original_rows} rows remain "
            f"({ratio*100:.1f}%). Episode terminated."
        )
        logger.error(msg)
        return True, msg
    return False, ""


# ── 6 & 7. Duplicate apply and session limit ──────────────────────────────────

def check_apply_allowed(
    rec_id: int,
    state: AntiExploitState,
) -> tuple[bool, str]:
    """
    Gate an apply action on the per-session limits.

    Returns ``(allowed, error_message)``. The session cap
    (MAX_APPLIES_PER_SESSION) is checked before the duplicate-ID rule, so a
    session that has hit its cap reports the cap even for a repeated ID.
    """
    if state.applies_this_session >= MAX_APPLIES_PER_SESSION:
        reason = (
            f"Max {MAX_APPLIES_PER_SESSION} applies per query session reached. "
            "Please re-query for more options."
        )
    elif rec_id in state.applied_ids_this_session:
        reason = (
            f"Recommendation {rec_id} has already been applied this session. "
            "Duplicate apply not allowed."
        )
    else:
        return True, ""
    return False, reason


def record_apply(rec_id: int, state: AntiExploitState):
    """Remember that ``rec_id`` was applied and count it against the session cap."""
    seen_ids = state.applied_ids_this_session
    seen_ids.add(rec_id)  # added first: if this raises, the counter is untouched
    state.applies_this_session = state.applies_this_session + 1


def reset_session_apply_state(state: AntiExploitState):
    """
    Clear per-session apply bookkeeping. Call this whenever a new query_X
    command starts a fresh session.

    A brand-new set is bound to the state rather than emptying the old one in
    place, so any outside reference to the previous set is left intact.
    """
    state.applies_this_session = 0
    state.applied_ids_this_session = set()


# ── 8. Episode timeout ────────────────────────────────────────────────────────

def check_episode_timeout(state: AntiExploitState) -> tuple[bool, str]:
    """
    Report whether the episode has run longer than EPISODE_TIMEOUT_SECS.

    Elapsed time is ``time.time()`` minus ``state.episode_start_time``.
    Returns ``(timed_out, message)``. On timeout the message is logged as a
    warning. This function only reports: it does not itself force the submit
    or apply the -0.10 penalty that the message mentions.
    """
    elapsed = time.time() - state.episode_start_time
    if elapsed <= EPISODE_TIMEOUT_SECS:
        return False, ""
    timeout_msg = (
        f"Episode wall-clock timeout ({elapsed:.0f}s > {EPISODE_TIMEOUT_SECS}s). "
        "Forcing submit. Penalty: -0.10."
    )
    logger.warning(timeout_msg)
    return True, timeout_msg


# ── 9. Step timeout context manager ──────────────────────────────────────────

class StepTimeoutError(Exception):
    """
    Intended to signal that a single step overran its time budget
    (STEP_TIMEOUT_SECS). It adds nothing beyond :class:`Exception`.
    """


def validate_calls_remaining(state: AntiExploitState) -> int:
    """Number of rewarded (free) validates still available; never negative."""
    remaining = FREE_VALIDATES - state.validate_call_count
    return remaining if remaining > 0 else 0