EntropyEnv / server /router.py
immortalindeed's picture
Fix dep_hard Counter bug, add fatal error handling, update README with 14-model benchmark
3466d21
# server/router.py
# Central dispatcher. Routes validated actions to the correct domain grader.
#
# KEY FIX: The _check_done() mastery condition was firing after just 2 steps
# if avg_reward >= 0.90. This caused:
# - sec_easy: identify_vulnerability scores 0.99 → avg = 0.99 → done=True immediately
# - dep_easy, cli_easy: same problem — 1-step episodes ending with 0.99
#
# The mastery condition is now DISABLED. Done is determined by:
# 1. max_steps reached (hard limit)
# 2. required_sequence fully completed (all actions in sequence done)
# 3. completion_threshold met AND min_actions satisfied
#
# This forces multi-step tasks to actually run all required steps,
# and prevents easy tasks from short-circuiting at step 1.
from typing import Dict
from .session import SessionState
from .graders import security_grader, dependency_grader, clinical_grader
GRADERS = {
'security': security_grader,
'dependency': dependency_grader,
'clinical': clinical_grader,
}
def route_step(session: SessionState, action: Dict) -> Dict:
"""Route a validated action to the correct grader and return enriched result."""
grader = GRADERS.get(session.task_type)
if not grader:
return {
'reward': 0.01,
'done': True,
'observation': {'error': f'Unknown task_type: {session.task_type}'},
}
reward = grader.grade(action, session)
case = session.task_case
max_steps = case.get('max_steps', 8)
done = _check_done(session, action, reward, max_steps)
obs = _build_step_obs(session, action, reward, done)
score_details = _compute_score_details(action, session)
obs['score_breakdown'] = score_details
return {
'episode_id': session.episode_id,
'step_count': session.step_count + 1,
'reward': round(float(reward), 4),
'done': bool(done),
'observation': obs,
'score_details': score_details,
}
def _check_done(session: SessionState, action: Dict, reward: float, max_steps: int) -> bool:
"""
Determine if the episode should end.
Rules (in priority order):
1. Hard limit: max_steps reached → always done
2. min_actions not yet reached → never done early
3. Required sequence: each action in required_sequence must appear
at least as many times as it appears in the list → done
(e.g. ['migrate_api', 'migrate_api'] requires 2 migrate_api calls)
4. Single-step tasks (min_actions=1, no required_sequence): threshold met → done
5. Otherwise: not done
BUG FIX: Previously used `all(a in all_actions ...)` which treated
['migrate_api', 'migrate_api'] as satisfied after just 1 migrate_api call
because Python `in` checks set membership, not count.
Now uses Counter to check that each action appears enough times.
"""
next_step = session.step_count + 1
case = session.task_case
done_conditions = case.get('done_conditions', {})
min_actions = done_conditions.get('min_actions', 1)
required_seq = done_conditions.get('required_sequence', [])
# Rule 1: Hard limit — always terminates
if next_step >= max_steps:
return True
# Build the full action history including the current action
all_actions = session.last_actions + [action.get('action_type', '')]
# Rule 2: min_actions guard — episode cannot end before this many steps
if next_step < min_actions:
return False
# Rule 3: Required sequence check using COUNTS not set membership
# This correctly handles repeated actions like ['migrate_api', 'migrate_api']
if required_seq:
from collections import Counter
required_counts = Counter(required_seq)
actual_counts = Counter(all_actions)
# Every required action must appear at least as many times as required
seq_complete = all(
actual_counts[act] >= count
for act, count in required_counts.items()
)
if seq_complete:
return True
return False # required_seq defined but not complete → keep going
# Rule 4: Single-step tasks with no required sequence — threshold met
if min_actions == 1:
threshold = case.get('completion_threshold', 0.85)
if reward >= threshold:
return True
return False
def build_initial_obs(session: SessionState) -> dict:
"""Build the initial observation returned by /reset."""
case = session.task_case
task_type = session.task_type
task_id = session.task_id
obs = {
'task_type': task_type,
'task_id': task_id,
'task_subtype': case.get('task_subtype', 'standard'),
'task_description': case.get('task_description', ''),
'turn': 0,
'done': False,
}
if task_type == 'security':
obs['code_snippet'] = case.get('tool_call', '')
obs['reviewer_feedback'] = None
obs['available_actions'] = [
{'name': 'identify_vulnerability',
'params': ['vuln_type:str', 'cvss_score:float', 'severity:str', 'affected_line:int']},
{'name': 'propose_fix',
'params': ['fix_code:str', 'explanation:str']},
{'name': 'revise_fix',
'params': ['fix_code:str', 'addressed_feedback:str']},
]
elif task_type == 'dependency':
obs['code_snippet'] = case.get('code_snippet', '')
subtype = case.get('task_subtype', '')
if subtype == 'flag':
obs['requirements'] = case.get('requirements', {})
obs['available_actions'] = [
{'name': 'flag_outdated',
'params': ['packages:dict', 'deprecated_api:str|null', 'replacement:str|null']},
]
elif subtype == 'resolve':
obs['conflict_packages'] = case.get('conflict_packages', [])
obs['compatibility_matrix'] = case.get('compatibility_matrix', {})
obs['current_requirements'] = case.get('requirements', {})
obs['available_actions'] = [
{'name': 'resolve_conflict',
'params': ['packages:dict', 'reasoning:str']},
]
elif subtype == 'migrate':
obs['graph_break_report'] = case.get('graph_break_report', case.get('break_descriptions', []))
obs['available_actions'] = [
{'name': 'migrate_api',
'params': ['completed_items:list', 'code_changes:dict']},
{'name': 'validate_tree',
'params': ['completed_items:list']},
]
elif task_type == 'clinical':
obs['patient_id'] = case.get('patient_id', '')
obs['events'] = case.get('events', case.get('patient_events', []))
obs['available_steps'] = case.get('available_steps', [])
if task_id in ('cli_medium', 'cli_hard'):
obs['dependency_graph'] = case.get('dependency_graph', {})
obs['available_actions'] = [
{'name': 'detect_gap',
'params': ['missing_steps:list', 'risk_level:str']},
{'name': 'rank_issues',
'params': ['priority_order:list']},
{'name': 'order_steps',
'params': ['recovery_steps:list']},
]
return obs
def _build_step_obs(session: SessionState, action: Dict, reward: float, done: bool) -> Dict:
"""Build observation returned after each step()."""
case = session.task_case
task_type = session.task_type
obs = {
'task_type': task_type,
'task_id': session.task_id,
'task_subtype': case.get('task_subtype', 'standard'),
'turn': session.step_count + 1,
'done': done,
'last_reward': round(reward, 4),
}
if done:
obs['message'] = 'Episode complete.'
return obs
if task_type == 'security':
obs['task_description'] = case.get('task_description', '')
obs['code_snippet'] = case.get('tool_call', '')
atype = action.get('action_type', '')
if atype == 'propose_fix':
fb = case.get('reviewer_feedback', '')
if fb:
obs['reviewer_feedback'] = fb
elif atype == 'revise_fix':
fb_seq = case.get('reviewer_feedback_sequence', [])
if fb_seq:
fb_idx = min(len(session.history), len(fb_seq) - 1)
if fb_idx >= 0:
obs['reviewer_feedback'] = fb_seq[fb_idx]
obs['available_actions'] = [
{'name': 'identify_vulnerability',
'params': ['vuln_type:str', 'cvss_score:float', 'severity:str', 'affected_line:int']},
{'name': 'propose_fix',
'params': ['fix_code:str', 'explanation:str']},
{'name': 'revise_fix',
'params': ['fix_code:str', 'addressed_feedback:str']},
]
elif task_type == 'dependency':
obs['task_description'] = case.get('task_description', '')
obs['code_snippet'] = case.get('code_snippet', '')
subtype = case.get('task_subtype', '')
if subtype == 'migrate':
obs['graph_break_report'] = case.get('graph_break_report', case.get('break_descriptions', []))
obs['available_actions'] = [
{'name': 'migrate_api', 'params': ['completed_items:list', 'code_changes:dict']},
{'name': 'validate_tree', 'params': ['completed_items:list']},
]
elif subtype == 'resolve':
obs['conflict_packages'] = case.get('conflict_packages', [])
obs['compatibility_matrix'] = case.get('compatibility_matrix', {})
obs['available_actions'] = [
{'name': 'resolve_conflict', 'params': ['packages:dict', 'reasoning:str']},
]
else:
obs['available_actions'] = [
{'name': 'flag_outdated',
'params': ['packages:dict', 'deprecated_api:str|null', 'replacement:str|null']},
]
elif task_type == 'clinical':
obs['task_description'] = case.get('task_description', '')
obs['patient_id'] = case.get('patient_id', '')
obs['events'] = case.get('events', case.get('patient_events', []))
obs['available_steps'] = case.get('available_steps', [])
if session.task_id in ('cli_medium', 'cli_hard'):
obs['dependency_graph'] = case.get('dependency_graph', {})
obs['available_actions'] = [
{'name': 'detect_gap', 'params': ['missing_steps:list', 'risk_level:str']},
{'name': 'rank_issues', 'params': ['priority_order:list']},
{'name': 'order_steps', 'params': ['recovery_steps:list']},
]
return obs
def _compute_score_details(action: Dict, session: SessionState) -> Dict[str, float]:
"""Compute per-component score breakdown for UI display."""
atype = action.get('action_type', '')
case = session.task_case
details = {}
if session.task_type == 'security':
if atype == 'identify_vulnerability':
details['vuln_type_match'] = 1.0 if action.get('vuln_type') == case.get('expected_vuln_type') else 0.0
lo, hi = case.get('cvss_range', [0, 10])
try:
v = float(action.get('cvss_score', -1))
details['cvss_in_range'] = 1.0 if lo <= v <= hi else (0.4 if abs(v - (lo + hi) / 2) <= 1.5 else 0.0)
except (TypeError, ValueError):
details['cvss_in_range'] = 0.0
details['severity_match'] = 1.0 if action.get('severity') == case.get('expected_severity') else 0.0
elif atype == 'propose_fix':
tokens = case.get('required_fix_tokens', [])
if isinstance(tokens, dict):
tokens = tokens.get(case.get('expected_vuln_type', ''), [])
tokens = [t for t in tokens if isinstance(t, str)]
fix = action.get('fix_code', '')
details['token_coverage'] = sum(1 for t in tokens if t.lower() in fix.lower()) / max(len(tokens), 1) if fix else 0.0
key_id = case.get('must_preserve_identifier', '')
details['id_preserved'] = 1.0 if key_id and key_id in fix else 0.0
elif atype == 'revise_fix':
kws = case.get('current_feedback_keywords', [])
addressed = action.get('addressed_feedback', '')
details['feedback_addressed'] = sum(1 for kw in kws if kw.lower() in addressed.lower()) / max(len(kws), 1) if addressed else 0.0
elif session.task_type == 'dependency':
if atype == 'flag_outdated':
expected = set(case.get('expected_outdated_packages', []))
provided = set(action.get('packages', {}).keys())
if expected:
tp = len(expected & provided)
p = tp / max(len(provided), 1)
r = tp / max(len(expected), 1)
details['pkg_f1'] = round(2 * p * r / max(p + r, 0.001), 4)
details['api_match'] = 1.0 if action.get('deprecated_api') == case.get('expected_deprecated_api') else 0.0
elif atype == 'resolve_conflict':
proposed = action.get('packages', {})
conflict = case.get('conflict_packages', [])
details['packages_proposed'] = len(proposed)
details['conflict_count'] = len(conflict)
elif atype in ('migrate_api', 'validate_tree'):
checklist = case.get('graph_breaks', [])
completed = action.get('completed_items', [])
details['items_completed'] = len(completed)
details['total_items'] = len(checklist)
elif session.task_type == 'clinical':
if atype == 'detect_gap':
expected = set(case.get('expected_missing_steps', []))
provided = set(action.get('missing_steps', []))
if expected:
tp = len(expected & provided)
p = tp / max(len(provided), 1)
r = tp / max(len(expected), 1)
details['step_f1'] = round(2 * p * r / max(p + r, 0.001), 4)
details['risk_match'] = 1.0 if action.get('risk_level') == case.get('expected_risk') else 0.0
elif atype == 'rank_issues':
expected = case.get('priority_order', [])
provided = action.get('priority_order', [])
details['ranking_overlap'] = len(set(expected) & set(provided)) / max(len(expected), 1) if expected else 0.0
elif atype == 'order_steps':
expected = case.get('required_steps', case.get('expected_missing_steps', []))
provided = action.get('recovery_steps', [])
details['steps_overlap'] = len(set(expected) & set(provided)) / max(len(expected), 1) if expected else 0.0
return details