Spaces:
Sleeping
Sleeping
File size: 7,293 Bytes
71dc210 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 | """
Anti-Exploit Protections for Data-Centric RL Environment.
Centralised module for all anti-hacking checks:
1. Input truncation (>200 chars → truncate, -0.02 penalty)
2. Validate spam prevention (cooldown + diminishing returns)
3. Recommendation ID staleness check
4. Ground truth immutability assertion
5. Catastrophic data loss detection
6. Duplicate apply prevention
7. Max applies per session (3)
8. Episode wall-clock timeout (5 min → forced submit, -0.10)
9. Step timeout (5 sec → timeout observation, -0.05)
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Optional, Set
# Module-level logger shared by every check in this file.
logger = logging.getLogger(__name__)

# Longest action string accepted as-is; check_and_truncate_input cuts longer ones.
MAX_ACTION_CHARS = 200
# Upper bound on applies within one query session (see check_apply_allowed).
MAX_APPLIES_PER_SESSION = 3
# Number of validate calls that still earn the positive reward
# (see get_validate_reward / validate_calls_remaining).
FREE_VALIDATES = 3
# Must take this many non-validate actions before the next validate.
VALIDATE_COOLDOWN = 2
# Episode wall-clock budget in seconds (5 minutes), used by check_episode_timeout.
EPISODE_TIMEOUT_SECS = 5 * 60
# Per-step budget in seconds. NOTE(review): no function visible in this
# section reads this constant -- confirm where it is enforced.
STEP_TIMEOUT_SECS = 5
# ── Exploit tracker (per-episode state) ──────────────────────────────────────
@dataclass
class AntiExploitState:
    """Mutable counters and timestamps consulted by the anti-exploit checks."""

    # --- validate bookkeeping -------------------------------------------
    # Number of validate calls recorded so far.
    validate_call_count: int = 0
    # Non-validate steps taken since the most recent validate (cooldown counter).
    steps_since_last_validate: int = 0

    # --- apply bookkeeping ----------------------------------------------
    # Recommendation IDs already applied in the current query session.
    applied_ids_this_session: Set[int] = field(default_factory=set)
    # How many applies were recorded in the current query session.
    applies_this_session: int = 0

    # --- timing ---------------------------------------------------------
    # Wall-clock timestamp (time.time()) captured when the state is created.
    episode_start_time: float = field(default_factory=time.time)

    # --- ground truth ---------------------------------------------------
    # Ground truth row count; per the original note this is set at reset.
    ground_truth_row_count: int = 0
# ── 1. Input truncation ──────────────────────────────────────────────────────
def check_and_truncate_input(action: str) -> tuple[str, float, bool]:
    """
    Clamp an action string to at most MAX_ACTION_CHARS characters.

    Returns (action, penalty, was_truncated): the input unchanged with a 0.0
    penalty when it fits, otherwise its first MAX_ACTION_CHARS characters with
    a -0.02 penalty. A warning is logged whenever truncation happens.
    """
    original_len = len(action)

    # Fast path: nothing to cut, nothing to penalise.
    if original_len <= MAX_ACTION_CHARS:
        return action, 0.0, False

    logger.warning(
        "Input truncated: original length %d > %d", original_len, MAX_ACTION_CHARS
    )
    clipped = action[:MAX_ACTION_CHARS]
    return clipped, -0.02, True
# ── 2. Validate cooldown ─────────────────────────────────────────────────────
def check_validate_cooldown(state: AntiExploitState) -> tuple[bool, str]:
    """
    Decide whether a validate action may run right now.

    Returns (allowed, error_message). The very first validate is always
    allowed; afterwards a validate is blocked until at least
    VALIDATE_COOLDOWN non-validate steps have been recorded since the
    previous one. The error message is empty when the call is allowed.
    """
    on_cooldown = (
        state.validate_call_count > 0
        and state.steps_since_last_validate < VALIDATE_COOLDOWN
    )
    if not on_cooldown:
        return True, ""

    steps_needed = VALIDATE_COOLDOWN - state.steps_since_last_validate
    notice = (
        f"Validate on cooldown. Take {steps_needed} "
        "more action(s) before validating again."
    )
    return False, notice
def get_validate_reward(state: AntiExploitState) -> float:
    """Reward for a validate call: +0.02 while fewer than FREE_VALIDATES
    calls have been recorded, -0.01 once that quota is used up."""
    within_free_quota = state.validate_call_count < FREE_VALIDATES
    return 0.02 if within_free_quota else -0.01
def record_validate(state: AntiExploitState):
    """Register one validate call: restart the cooldown counter and bump
    the running validate total."""
    state.steps_since_last_validate = 0
    state.validate_call_count = state.validate_call_count + 1
def record_non_validate_step(state: AntiExploitState):
    """Advance the cooldown counter by one after a non-validate action."""
    state.steps_since_last_validate = state.steps_since_last_validate + 1
# ── 3. Recommendation staleness ──────────────────────────────────────────────
def check_recommendation_staleness(
    rec_id: int,
    current_session_id: str,
    recommendation_session_id: str,
) -> tuple[bool, str]:
    """
    Compare the session a recommendation came from with the current one.

    Returns (is_fresh, error_message). The recommendation is fresh only when
    both session IDs are equal; otherwise the message names the stale
    recommendation ID and asks for a re-query.
    """
    is_stale = current_session_id != recommendation_session_id
    if not is_stale:
        return True, ""

    stale_notice = (
        f"Stale recommendation ID {rec_id}. "
        "Please re-query for fresh recommendations."
    )
    return False, stale_notice
# ── 4. Ground truth immutability ─────────────────────────────────────────────
def assert_ground_truth_intact(
    ground_truth_len: int,
    original_gt_len: int,
) -> tuple[bool, str]:
    """
    Check that the ground truth still has its original number of rows.

    Despite the name, nothing is raised: the result is returned as
    (is_intact, message). On a mismatch the message describes the change
    (original count → current count) and is also logged at CRITICAL level;
    when the counts match the message is empty.
    """
    if ground_truth_len == original_gt_len:
        return True, ""

    # The arrow below replaces a mis-encoded character that had leaked into
    # the emitted message.
    msg = (
        "INTEGRITY VIOLATION: ground_truth row count changed "
        f"({original_gt_len} → {ground_truth_len}). This should never happen."
    )
    logger.critical(msg)
    return False, msg
# ── 5. Catastrophic data loss ────────────────────────────────────────────────
def check_catastrophic_data_loss(
    current_rows: int,
    original_rows: int,
) -> tuple[bool, str]:
    """
    Detect whether more than half of the original rows have been lost.

    Returns (is_catastrophic, message). The loss is catastrophic when
    current_rows / original_rows is strictly below 0.50; the message then
    reports the remaining rows and percentage and is logged at ERROR level.
    Otherwise the message is empty.

    A non-positive original_rows is never catastrophic: an empty baseline has
    no rows to lose, and dividing by a clamped denominator would otherwise
    report 0/0 rows as a total loss.
    """
    if original_rows <= 0:
        return False, ""

    ratio = current_rows / original_rows
    if ratio < 0.50:
        msg = (
            f"CATASTROPHIC DATA LOSS: only {current_rows}/{original_rows} rows remain "
            f"({ratio*100:.1f}%). Episode terminated."
        )
        logger.error(msg)
        return True, msg
    return False, ""
# ── 6 & 7. Duplicate apply and session limit ─────────────────────────────────
def check_apply_allowed(
    rec_id: int,
    state: AntiExploitState,
) -> tuple[bool, str]:
    """
    Decide whether recommendation ``rec_id`` may be applied.

    Returns (allowed, error_message). The session-wide apply limit is checked
    first, then whether this ID was already applied in the session; the
    message is empty when the apply is allowed.
    """
    denial = ""
    if state.applies_this_session >= MAX_APPLIES_PER_SESSION:
        denial = (
            f"Max {MAX_APPLIES_PER_SESSION} applies per query session reached. "
            "Please re-query for more options."
        )
    elif rec_id in state.applied_ids_this_session:
        denial = (
            f"Recommendation {rec_id} has already been applied this session. "
            "Duplicate apply not allowed."
        )
    # An empty denial text means neither rule fired.
    return (not denial), denial
def record_apply(rec_id: int, state: AntiExploitState):
    """Remember that ``rec_id`` was applied and count it against the
    session's apply total."""
    already_applied = state.applied_ids_this_session
    already_applied.add(rec_id)
    state.applies_this_session = state.applies_this_session + 1
def reset_session_apply_state(state: AntiExploitState):
    """Start a fresh apply session; per the original note, call this whenever
    a new query_X command resets the session."""
    state.applies_this_session = 0
    # Bind a brand-new set rather than clearing the old one in place.
    state.applied_ids_this_session = set()
# ── 8. Episode timeout ───────────────────────────────────────────────────────
def check_episode_timeout(state: AntiExploitState) -> tuple[bool, str]:
    """
    Report whether the episode has outlived its wall-clock budget.

    Returns (timed_out, message). Elapsed time is measured with time.time()
    against state.episode_start_time; once it is strictly greater than
    EPISODE_TIMEOUT_SECS the message is logged as a warning and returned.
    Otherwise the message is empty.
    """
    elapsed = time.time() - state.episode_start_time
    timed_out = elapsed > EPISODE_TIMEOUT_SECS
    if not timed_out:
        return False, ""

    msg = (
        f"Episode wall-clock timeout ({elapsed:.0f}s > {EPISODE_TIMEOUT_SECS}s). "
        "Forcing submit. Penalty: -0.10."
    )
    logger.warning(msg)
    return True, msg
# ── 9. Step timeout context manager ──────────────────────────────────────────
class StepTimeoutError(Exception):
    """Exception type for step timeouts.

    It adds nothing to ``Exception``; no code in this section raises it.
    """
def validate_calls_remaining(state: AntiExploitState) -> int:
    """Number of free validate calls left, never going below zero."""
    unused = FREE_VALIDATES - state.validate_call_count
    return unused if unused > 0 else 0
|