# skill-invocation-env / test_env.py
# Uploaded via huggingface_hub (commit ac627d5, verified) by mpnikhil
#!/usr/bin/env python3
"""
Local test script for the Skill Invocation Environment.
Tests the environment directly (no server) to verify:
- reset() works and returns proper observation
- list/load/unload/submit actions work correctly
- context budget enforcement
- precision/recall/bloat reward computation
- verifier tests for static and procedural tasks
"""
import sys
import os
# Add parent dir so imports work
sys.path.insert(0, os.path.dirname(__file__))
from models import SkillInvocationAction, SkillInvocationObservation, SkillInvocationState
from task_bank import TASK_BANK, SKILL_BANK
from server.skill_invocation_env_environment import SkillInvocationEnvironment
from task_generator import TaskGenerator
# ---------------------------------------------------------------------------
# Core environment tests
# ---------------------------------------------------------------------------
def test_reset():
    """Reset must yield a well-formed initial observation."""
    environment = SkillInvocationEnvironment()
    observation = environment.reset(seed=42)
    assert isinstance(observation, SkillInvocationObservation)
    assert observation.task_description != ""
    # Catalog mixes the relevant skill(s) with distractors (now 5-8 entries).
    assert len(observation.skill_catalog) >= 5
    # Episode-start invariants: nothing loaded, nothing earned, not done.
    assert observation.done is False
    assert observation.reward == 0.0
    assert observation.skill_content is None
    assert observation.loaded_skills == []
    assert observation.context_budget_used == 0
    assert observation.context_budget_total == 5
    assert len(observation.messages) > 0
    print("[PASS] test_reset")
def test_load_skill():
    """Loading a skill should place it into the context window."""
    environment = SkillInvocationEnvironment()
    initial = environment.reset(seed=42)
    target = initial.skill_catalog[0]["id"]
    after = environment.step(
        SkillInvocationAction(action_type="load", skill_id=target)
    )
    # Content comes back and the skill is tracked in the live context.
    assert after.skill_content is not None
    assert len(after.skill_content) > 0
    assert target in after.loaded_skills
    assert after.context_budget_used == 1
    assert target in after.loaded_skill_contents
    assert after.done is False
    print("[PASS] test_load_skill")
def test_invoke_backward_compat():
    """The legacy 'invoke' action must behave exactly like 'load'."""
    environment = SkillInvocationEnvironment()
    initial = environment.reset(seed=42)
    target = initial.skill_catalog[0]["id"]
    after = environment.step(
        SkillInvocationAction(action_type="invoke", skill_id=target)
    )
    assert after.skill_content is not None
    assert target in after.loaded_skills
    assert after.context_budget_used == 1
    print("[PASS] test_invoke_backward_compat")
def test_unload_skill():
    """Unloading removes a skill from context but keeps it in history."""
    environment = SkillInvocationEnvironment()
    initial = environment.reset(seed=42)
    target = initial.skill_catalog[0]["id"]
    # Load, then immediately unload the same skill.
    environment.step(SkillInvocationAction(action_type="load", skill_id=target))
    after = environment.step(
        SkillInvocationAction(action_type="unload", skill_id=target)
    )
    assert target not in after.loaded_skills
    assert after.context_budget_used == 0
    assert after.skill_content is None
    # The ever-loaded history is preserved even after unload.
    assert target in after.skills_invoked
    print("[PASS] test_unload_skill")
def test_load_already_loaded():
    """A repeated load of the same skill must not double-count budget."""
    environment = SkillInvocationEnvironment()
    initial = environment.reset(seed=42)
    target = initial.skill_catalog[0]["id"]
    environment.step(SkillInvocationAction(action_type="load", skill_id=target))
    after = environment.step(SkillInvocationAction(action_type="load", skill_id=target))
    assert after.context_budget_used == 1  # not 2
    assert after.loaded_skills.count(target) == 1
    assert after.skill_content is not None  # content is still returned
    print("[PASS] test_load_already_loaded")
def test_unload_not_loaded():
    """Unloading a skill that was never loaded leaves state untouched."""
    environment = SkillInvocationEnvironment()
    environment.reset(seed=42)
    after = environment.step(
        SkillInvocationAction(action_type="unload", skill_id="skill_001")
    )
    assert after.context_budget_used == 0
    print("[PASS] test_unload_not_loaded")
def test_context_budget():
    """A full context budget rejects further loads until a slot is freed.

    Fix: the original guarded the fill loop with ``min(3, len(catalog_ids))``
    but then indexed ``catalog_ids[3]`` unconditionally, which would raise an
    opaque IndexError on a short catalog. An explicit precondition assert now
    fails loudly with a clear message instead.
    """
    env = SkillInvocationEnvironment(context_budget=3)
    obs = env.reset(seed=42)
    catalog_ids = [s["id"] for s in obs.skill_catalog]
    # Precondition: we need a 4th skill to probe the full-budget rejection.
    assert len(catalog_ids) >= 4, f"Need >= 4 catalog skills, got {len(catalog_ids)}"
    # Fill the budget with the first three skills.
    for skill_id in catalog_ids[:3]:
        env.step(SkillInvocationAction(action_type="load", skill_id=skill_id))
    obs = env.step(SkillInvocationAction(action_type="load", skill_id=catalog_ids[3]))
    # Fourth load must be rejected — budget is full.
    assert obs.context_budget_used == 3
    assert catalog_ids[3] not in obs.loaded_skills
    # Freeing one slot makes the pending load succeed.
    env.step(SkillInvocationAction(action_type="unload", skill_id=catalog_ids[0]))
    obs2 = env.step(SkillInvocationAction(action_type="load", skill_id=catalog_ids[3]))
    assert catalog_ids[3] in obs2.loaded_skills
    assert obs2.context_budget_used == 3
    print("[PASS] test_context_budget")
def test_load_unknown_skill():
    """Loading an id that is not in the catalog yields no content."""
    environment = SkillInvocationEnvironment()
    environment.reset(seed=42)
    after = environment.step(
        SkillInvocationAction(action_type="load", skill_id="skill_999")
    )
    assert after.skill_content is None
    assert after.context_budget_used == 0
    print("[PASS] test_load_unknown_skill")
def test_submit_incorrect():
    """Submitting a wrong answer ends the episode with non-positive reward."""
    environment = SkillInvocationEnvironment()
    environment.reset(seed=42)
    after = environment.step(
        SkillInvocationAction(action_type="submit", answer="I don't know")
    )
    assert after.done is True
    assert after.reward <= 0.0
    assert after.verification_result is not None
    assert "INCORRECT" in after.verification_result
    print("[PASS] test_submit_incorrect")
def test_submit_after_done():
    """Any action after the episode ends keeps reporting the done state."""
    environment = SkillInvocationEnvironment()
    environment.reset(seed=42)
    environment.step(SkillInvocationAction(action_type="submit", answer="test"))
    after = environment.step(
        SkillInvocationAction(action_type="load", skill_id="skill_001")
    )
    assert after.done is True
    print("[PASS] test_submit_after_done")
def test_precision_reward():
    """Loading only the relevant skill plus a correct answer earns ~1.0."""
    env = SkillInvocationEnvironment()
    # Scan seeds until the environment lands on the static task_001.
    for seed in range(100):
        obs = env.reset(seed=seed)
        if env.state.task_id == "task_001":
            break
    else:
        print("[SKIP] test_precision_reward - couldn't find task_001")
        return
    # Load only the single relevant skill.
    env.step(SkillInvocationAction(action_type="load", skill_id="skill_001"))
    correct_answer = """
import hmac, hashlib, base64
def encode_zephyr_auth(api_key: str, timestamp: int) -> dict:
    signing_string = f"{api_key}:{timestamp}"
    digest = hmac.new(api_key.encode(), signing_string.encode(), hashlib.sha256).digest()
    b64 = base64.b64encode(digest).decode()
    return {"X-Zephyr-Auth": f"ZPH {api_key}:{b64}:{timestamp}"}
"""
    obs = env.step(SkillInvocationAction(action_type="submit", answer=correct_answer))
    assert obs.done is True
    assert "CORRECT" in obs.verification_result
    # 0.6 correctness + 0.3 precision (1/1) + 0.1 recall (1/1) = 1.0
    assert abs(obs.reward - 1.0) < 0.01, f"Expected ~1.0, got {obs.reward}"
    print(f"[PASS] test_precision_reward (reward={obs.reward})")
def test_bloat_penalty():
    """Loading every catalog skill dilutes the reward even when correct."""
    env = SkillInvocationEnvironment()
    # Scan seeds until the environment lands on task_001.
    for seed in range(100):
        obs = env.reset(seed=seed)
        if env.state.task_id == "task_001":
            break
    else:
        print("[SKIP] test_bloat_penalty - couldn't find task_001")
        return
    # Load the entire catalog (relevant skill plus all distractors).
    for entry in obs.skill_catalog:
        env.step(SkillInvocationAction(action_type="load", skill_id=entry["id"]))
    correct_answer = """
import hmac, hashlib, base64
def encode_zephyr_auth(api_key: str, timestamp: int) -> dict:
    signing_string = f"{api_key}:{timestamp}"
    digest = hmac.new(api_key.encode(), signing_string.encode(), hashlib.sha256).digest()
    b64 = base64.b64encode(digest).decode()
    return {"X-Zephyr-Auth": f"ZPH {api_key}:{b64}:{timestamp}"}
"""
    obs = env.step(SkillInvocationAction(action_type="submit", answer=correct_answer))
    assert obs.done is True
    assert "CORRECT" in obs.verification_result
    # With 6 total skills loaded (1 relevant + 5 distractors):
    # 0.6 + 0.3*(1/6) + 0.1*(1/1) - 0.15*5 = 0.6 + 0.05 + 0.1 - 0.75 = 0.0
    # so the reward should sit well below the 1.0 ceiling.
    assert obs.reward < 0.5, f"Bloat should reduce reward, got {obs.reward}"
    print(f"[PASS] test_bloat_penalty (reward={obs.reward})")
def test_load_unload_no_bloat():
    """A distractor unloaded before submit must not incur a bloat penalty."""
    env = SkillInvocationEnvironment()
    # Scan seeds until the environment lands on task_001.
    for seed in range(100):
        obs = env.reset(seed=seed)
        if env.state.task_id == "task_001":
            break
    else:
        print("[SKIP] test_load_unload_no_bloat - couldn't find task_001")
        return
    # Pick the first catalog entry that is NOT the relevant skill.
    distractor_id = next(
        (entry["id"] for entry in obs.skill_catalog if entry["id"] != "skill_001"),
        None,
    )
    assert distractor_id is not None
    # Load the distractor, then unload it before answering.
    env.step(SkillInvocationAction(action_type="load", skill_id=distractor_id))
    env.step(SkillInvocationAction(action_type="unload", skill_id=distractor_id))
    # Load the relevant skill.
    env.step(SkillInvocationAction(action_type="load", skill_id="skill_001"))
    correct_answer = """
import hmac, hashlib, base64
def encode_zephyr_auth(api_key: str, timestamp: int) -> dict:
    signing_string = f"{api_key}:{timestamp}"
    digest = hmac.new(api_key.encode(), signing_string.encode(), hashlib.sha256).digest()
    b64 = base64.b64encode(digest).decode()
    return {"X-Zephyr-Auth": f"ZPH {api_key}:{b64}:{timestamp}"}
"""
    obs = env.step(SkillInvocationAction(action_type="submit", answer=correct_answer))
    assert obs.done is True
    assert "CORRECT" in obs.verification_result
    # Only skill_001 is loaded at submit time → no bloat:
    # 0.6 + 0.3 + 0.1 = 1.0
    assert abs(obs.reward - 1.0) < 0.01, f"Expected ~1.0 after unload, got {obs.reward}"
    print(f"[PASS] test_load_unload_no_bloat (reward={obs.reward})")
def test_state_property():
    """env.state exposes episode metadata before and after stepping."""
    environment = SkillInvocationEnvironment()
    observation = environment.reset(seed=42)
    snapshot = environment.state
    assert isinstance(snapshot, SkillInvocationState)
    assert snapshot.episode_id is not None
    assert snapshot.step_count == 0
    assert snapshot.task_id != ""
    assert snapshot.done is False
    assert snapshot.loaded_skills == []
    assert snapshot.context_budget_total == 5
    # After a single load the counter advances and the skill is recorded.
    target = observation.skill_catalog[0]["id"]
    environment.step(SkillInvocationAction(action_type="load", skill_id=target))
    snapshot = environment.state
    assert snapshot.step_count == 1
    assert target in snapshot.loaded_skills
    print("[PASS] test_state_property")
def test_all_tasks_have_valid_skills():
    """Task-bank integrity: skills exist, sets are disjoint, catalogs >= 5."""
    for task in TASK_BANK:
        relevant = task["relevant_skills"]
        distractors = task["distractor_skills"]
        # Every referenced skill must exist in the skill bank.
        for sid in relevant:
            assert sid in SKILL_BANK, f"Task {task['id']}: missing relevant skill {sid}"
        for sid in distractors:
            assert sid in SKILL_BANK, f"Task {task['id']}: missing distractor skill {sid}"
        # Relevant and distractor sets must not intersect.
        overlap = set(relevant) & set(distractors)
        assert len(overlap) == 0, f"Task {task['id']}: overlap: {overlap}"
        # Each catalog should now carry at least 5 skills total.
        total = len(relevant) + len(distractors)
        assert total >= 5, f"Task {task['id']}: only {total} skills in catalog"
    print(f"[PASS] test_all_tasks_have_valid_skills ({len(TASK_BANK)} tasks verified)")
# ---------------------------------------------------------------------------
# Verifier tests (unchanged — these test verifier correctness, not env logic)
# ---------------------------------------------------------------------------
def test_verifier_task001_correct_code_passes():
    """The task_001 exec verifier accepts the reference implementation."""
    task = next(t for t in TASK_BANK if t["id"] == "task_001")
    reference = '''
import hmac, hashlib, base64
def encode_zephyr_auth(api_key: str, timestamp: int) -> dict:
    signing_string = f"{api_key}:{timestamp}"
    digest = hmac.new(api_key.encode(), signing_string.encode(), hashlib.sha256).digest()
    b64 = base64.b64encode(digest).decode()
    return {"X-Zephyr-Auth": f"ZPH {api_key}:{b64}:{timestamp}"}
'''
    assert task["verifier"](reference), "Reference implementation should pass"
    print("[PASS] test_verifier_task001_correct_code_passes")
def test_verifier_task001_keywords_only_fails():
    """The task_001 exec verifier rejects keyword-stuffed non-code."""
    task = next(t for t in TASK_BANK if t["id"] == "task_001")
    junk = "hmac sha256 x-zephyr-auth base64 zph encode_zephyr_auth"
    assert not task["verifier"](junk), "Keyword-stuffed garbage should fail"
    print("[PASS] test_verifier_task001_keywords_only_fails")
def test_verifier_task001_wrong_format_fails():
    """task_001 must reject an implementation using the wrong hash (md5)."""
    task = next(t for t in TASK_BANK if t["id"] == "task_001")
    bad_hash_code = '''
import hmac, hashlib, base64
def encode_zephyr_auth(api_key: str, timestamp: int) -> dict:
    signing_string = f"{api_key}:{timestamp}"
    digest = hmac.new(api_key.encode(), signing_string.encode(), hashlib.md5).digest()
    b64 = base64.b64encode(digest).decode()
    return {"X-Zephyr-Auth": f"ZPH {api_key}:{b64}:{timestamp}"}
'''
    assert not task["verifier"](bad_hash_code), "Wrong hash algorithm should fail"
    print("[PASS] test_verifier_task001_wrong_format_fails")
def test_verifier_task001_markdown_fenced():
    """task_001 exec verifier must strip markdown fences before running."""
    task = next(t for t in TASK_BANK if t["id"] == "task_001")
    fenced_answer = '''```python
import hmac, hashlib, base64
def encode_zephyr_auth(api_key: str, timestamp: int) -> dict:
    signing_string = f"{api_key}:{timestamp}"
    digest = hmac.new(api_key.encode(), signing_string.encode(), hashlib.sha256).digest()
    b64 = base64.b64encode(digest).decode()
    return {"X-Zephyr-Auth": f"ZPH {api_key}:{b64}:{timestamp}"}
```'''
    assert task["verifier"](fenced_answer), "Markdown-fenced correct code should pass"
    print("[PASS] test_verifier_task001_markdown_fenced")
def test_verifier_task002_correct_passes():
    """task_002: a correct NovaBin header parser satisfies the verifier."""
    task = next(t for t in TASK_BANK if t["id"] == "task_002")
    parser_code = '''
import struct
def parse_novabin_header(data: bytes) -> dict:
    magic = data[0:4]
    assert magic == b'NOVB', f"Invalid magic: {magic}"
    version = struct.unpack('>H', data[4:6])[0]
    record_count = struct.unpack('>I', data[6:10])[0]
    flags = struct.unpack('>H', data[10:12])[0]
    checksum = struct.unpack('>I', data[12:16])[0]
    return {
        "version": version, "record_count": record_count,
        "compressed": bool(flags & 1), "encrypted": bool(flags & 2),
        "checksummed": bool(flags & 4), "checksum": checksum
    }
'''
    assert task["verifier"](parser_code), "Correct NovaBin parser should pass"
    print("[PASS] test_verifier_task002_correct_passes")
def test_verifier_task002_keywords_only_fails():
    """task_002 must reject a keyword-stuffed non-answer."""
    task = next(t for t in TASK_BANK if t["id"] == "task_002")
    junk = "struct NOVB 0x4E4F5642 big-endian parse_novabin_header version record_count"
    assert not task["verifier"](junk), "Keyword-stuffed answer should fail"
    print("[PASS] test_verifier_task002_keywords_only_fails")
def test_verifier_task003_structural():
    """task_003's HelixLang verifier must check structure, not keywords."""
    task = next(t for t in TASK_BANK if t["id"] == "task_003")
    # Well-formed HelixLang pseudocode with the expected error-handling shape.
    structured = '''
fn fetch_user(db: Database, user_id: str) -> result<User> {
    let conn = try! db.connect().with_context("step", "connecting to database")
    let user = match try! conn.query_user(user_id).with_context("step", "fetching user") {
        Ok(u) => u,
        Err(e) => {
            if e.retryable {
                return retry_with_backoff(|| conn.query_user(user_id), max=3, backoff=100ms)
            }
            helix.log.error(e)
            return Err(HelixError.wrap(e, "HLX-DATA-2001", "user fetch failed"))
        }
    }
    Ok(user)
}
'''
    assert task["verifier"](structured), "Proper HelixLang pseudocode should pass"
    # The same vocabulary without structure must be rejected.
    junk = "HLX-DATA try! with_context retry backoff helix.log.error result Ok Err"
    assert not task["verifier"](junk), "Keywords without structure should fail"
    print("[PASS] test_verifier_task003_structural")
def test_verifier_task004_yaml_structure():
    """task_004's ArcDeploy verifier must validate YAML structure."""
    task = next(t for t in TASK_BANK if t["id"] == "task_004")
    # NOTE(review): inner YAML indentation reconstructed — phase entries
    # nested under canary.phases, rollback as a top-level sibling; confirm
    # against the verifier's expected schema.
    good_yaml = '''```yaml
canary:
  phases:
    - name: shadow
      traffic_pct: 0
      duration_min: 5
      metrics_gate: error_rate < 0.01
    - name: canary_1
      traffic_pct: 5
      duration_min: 10
      metrics_gate: p99_latency_ms < 200 AND error_rate < 0.005
    - name: canary_2
      traffic_pct: 25
      duration_min: 15
      metrics_gate: p99_latency_ms < 250 AND error_rate < 0.005
    - name: canary_3
      traffic_pct: 50
      duration_min: 20
      metrics_gate: p99_latency_ms < 300 AND error_rate < 0.01
    - name: full
      traffic_pct: 100
      duration_min: 0
rollback:
  auto: true
  on_metric_breach: immediate
  cooldown_min: 30
```'''
    assert task["verifier"](good_yaml), "Valid ArcDeploy YAML should pass"
    junk = "shadow canary_1 traffic_pct metrics_gate error_rate rollback auto: true"
    assert not task["verifier"](junk), "Keywords-only should fail YAML verifier"
    print("[PASS] test_verifier_task004_yaml_structure")
def test_verifier_task008_record_parser():
    """task_008: the NovaBin record parser passes the exec verifier."""
    task = next(t for t in TASK_BANK if t["id"] == "task_008")
    parser_code = '''
import struct
def parse_novabin_record(data: bytes, offset: int) -> tuple:
    fields = {}
    field_count = struct.unpack('>H', data[offset:offset+2])[0]
    offset += 2
    for _ in range(field_count):
        type_tag = data[offset]
        offset += 1
        name_len = struct.unpack('>H', data[offset:offset+2])[0]
        offset += 2
        field_name = data[offset:offset+name_len].decode('utf-8')
        offset += name_len
        val_len = struct.unpack('>I', data[offset:offset+4])[0]
        offset += 4
        val_data = data[offset:offset+val_len]
        offset += val_len
        if type_tag == 0x01:  # int32
            fields[field_name] = struct.unpack('>i', val_data)[0]
        elif type_tag == 0x02:  # float64
            fields[field_name] = struct.unpack('>d', val_data)[0]
        elif type_tag == 0x03:  # string
            fields[field_name] = val_data.decode('utf-8')
        elif type_tag == 0x04:  # bool
            fields[field_name] = val_data[0] != 0
    return (fields, offset)
'''
    assert task["verifier"](parser_code), "Correct record parser should pass"
    junk = "struct 0x01 0x02 0x03 0x04 uint16 utf-8 parse_novabin_record"
    assert not task["verifier"](junk), "Keywords should fail exec verifier"
    print("[PASS] test_verifier_task008_record_parser")
# ---------------------------------------------------------------------------
# SkillsBench-adapted task tests
# ---------------------------------------------------------------------------
def test_sb_001_flood_detection_correct():
    """task_sb_001: correct flood-day detection passes, keywords fail."""
    task = next(t for t in TASK_BANK if t["id"] == "task_sb_001")
    solution = '''
def detect_flood_days(daily_max_levels, flood_thresholds):
    result = {}
    for station_id, levels in daily_max_levels.items():
        if station_id not in flood_thresholds:
            continue
        threshold = flood_thresholds[station_id]
        flood_days = sum(1 for level in levels if level >= threshold)
        if flood_days > 0:
            result[station_id] = flood_days
    return result
'''
    assert task["verifier"](solution), "Correct flood detection should pass"
    junk = "detect_flood_days daily_max_levels flood_thresholds threshold"
    assert not task["verifier"](junk), "Keywords should fail"
    print("[PASS] test_sb_001_flood_detection_correct")
def test_sb_002_hp_filter_correct():
    """task_sb_002: correct HP-filter cycle correlation passes the verifier.

    Fix: the skip message claimed "scipy/statsmodels not installed" but the
    guarded imports are numpy and statsmodels — the diagnostic now names the
    packages actually being probed.
    """
    try:
        import numpy  # noqa: F401
        from statsmodels.tsa.filters.hp_filter import hpfilter  # noqa: F401
    except ImportError:
        print("[SKIP] test_sb_002_hp_filter_correct - numpy/statsmodels not installed")
        return
    task = next(t for t in TASK_BANK if t["id"] == "task_sb_002")
    solution = '''
import numpy as np
from statsmodels.tsa.filters.hp_filter import hpfilter
def hp_filter_correlation(series_a, series_b):
    log_a = np.log(series_a)
    log_b = np.log(series_b)
    cycle_a, _ = hpfilter(log_a, lamb=100)
    cycle_b, _ = hpfilter(log_b, lamb=100)
    corr = np.corrcoef(cycle_a, cycle_b)[0, 1]
    return round(float(corr), 5)
'''
    assert task["verifier"](solution), "Correct HP filter implementation should pass"
    junk = "hp_filter_correlation numpy hpfilter corrcoef lamb=100"
    assert not task["verifier"](junk), "Keywords should fail"
    print("[PASS] test_sb_002_hp_filter_correct")
def test_sb_003_dialogue_parser_correct():
    """task_sb_003: a full dialogue-script parser passes, keywords fail.

    NOTE(review): the leading indentation inside the embedded code string was
    reconstructed from the parser's control flow — verify it still matches
    what the verifier executes.
    """
    task = next(t for t in TASK_BANK if t["id"] == "task_sb_003")
    solution = (
        'import re\n'
        '\n'
        'def parse_dialogue(script):\n'
        '    nodes = []\n'
        '    edges = []\n'
        '    lines = script.strip().split("\\n")\n'
        '    current_node_id = None\n'
        '    current_lines = []\n'
        '    def flush_node():\n'
        '        nonlocal current_node_id, current_lines\n'
        '        if current_node_id is None:\n'
        '            return\n'
        '        content_lines = [l.strip() for l in current_lines if l.strip()]\n'
        '        is_choice = any(re.match(r"^\\d+\\.", l) for l in content_lines)\n'
        '        if is_choice:\n'
        '            nodes.append({"id": current_node_id, "text": "", "speaker": "", "type": "choice"})\n'
        '            for l in content_lines:\n'
        '                m = re.match(r"^(\\d+\\.\\s*.+?)\\s*->\\s*(\\w+)$", l)\n'
        '                if m:\n'
        '                    edges.append({"from": current_node_id, "to": m.group(2), "text": m.group(1).strip()})\n'
        '        else:\n'
        '            speaker = ""\n'
        '            text = ""\n'
        '            target = None\n'
        '            for l in content_lines:\n'
        '                m = re.match(r"^(\\w[\\w\\s]*):\\s*(.+?)\\s*->\\s*(\\w+)$", l)\n'
        '                if m:\n'
        '                    speaker = m.group(1)\n'
        '                    text = m.group(2).strip()\n'
        '                    target = m.group(3)\n'
        '                else:\n'
        '                    m2 = re.match(r"^(\\w[\\w\\s]*):\\s*(.+)$", l)\n'
        '                    if m2:\n'
        '                        speaker = m2.group(1)\n'
        '                        text = m2.group(2).strip()\n'
        '            nodes.append({"id": current_node_id, "text": text, "speaker": speaker, "type": "line"})\n'
        '            if target:\n'
        '                edges.append({"from": current_node_id, "to": target, "text": ""})\n'
        '        current_node_id = None\n'
        '        current_lines = []\n'
        '    for line in lines:\n'
        '        m = re.match(r"^\\[(\\w+)\\]$", line.strip())\n'
        '        if m:\n'
        '            flush_node()\n'
        '            current_node_id = m.group(1)\n'
        '            current_lines = []\n'
        '        else:\n'
        '            current_lines.append(line)\n'
        '    flush_node()\n'
        '    return {"nodes": nodes, "edges": edges}\n'
    )
    assert task["verifier"](solution), "Correct dialogue parser should pass"
    junk = "parse_dialogue nodes edges from to text speaker type"
    assert not task["verifier"](junk), "Keywords should fail"
    print("[PASS] test_sb_003_dialogue_parser_correct")
# ---------------------------------------------------------------------------
# Procedural task generator tests
# ---------------------------------------------------------------------------
def test_procedural_auth_100_seeds():
    """The auth_protocol template must generate valid tasks for 100 seeds."""
    generator = TaskGenerator(seed=0)
    for seed in range(100):
        bundle = generator.generate_with_seed(seed, template="auth_protocol")
        task = bundle["task"]
        skills = bundle["skills"]
        assert task["id"].startswith("task_proc_auth_")
        assert task["source"] == "procedural"
        assert task["template"] == "auth_protocol"
        # Exactly one relevant skill; several distractors.
        assert len(task["relevant_skills"]) == 1
        assert len(task["distractor_skills"]) >= 4
        # All referenced skills must actually be generated.
        for sid in task["relevant_skills"] + task["distractor_skills"]:
            assert sid in skills, f"Skill {sid} not in generated skills for seed {seed}"
        relevant = skills[task["relevant_skills"][0]]
        assert len(relevant["full_content"]) > 100
    print("[PASS] test_procedural_auth_100_seeds")
def test_procedural_binary_100_seeds():
    """The binary_format template must generate valid tasks for 100 seeds."""
    generator = TaskGenerator(seed=0)
    for seed in range(100):
        bundle = generator.generate_with_seed(seed, template="binary_format")
        task = bundle["task"]
        skills = bundle["skills"]
        assert task["id"].startswith("task_proc_bin_")
        assert task["source"] == "procedural"
        assert len(task["relevant_skills"]) == 1
        assert len(task["distractor_skills"]) >= 4
        for sid in task["relevant_skills"] + task["distractor_skills"]:
            assert sid in skills
    print("[PASS] test_procedural_binary_100_seeds")
def test_procedural_deterministic():
    """Identical seeds must reproduce identical generated tasks."""
    generator = TaskGenerator(seed=0)
    # Same seed + same template → identical auth task.
    first = generator.generate_with_seed(42, template="auth_protocol")
    second = generator.generate_with_seed(42, template="auth_protocol")
    assert first["task"]["id"] == second["task"]["id"]
    assert first["task"]["description"] == second["task"]["description"]
    assert first["task"]["relevant_skills"] == second["task"]["relevant_skills"]
    assert first["task"]["distractor_skills"] == second["task"]["distractor_skills"]
    # Same guarantee for the binary template.
    third = generator.generate_with_seed(42, template="binary_format")
    fourth = generator.generate_with_seed(42, template="binary_format")
    assert third["task"]["id"] == fourth["task"]["id"]
    assert third["task"]["description"] == fourth["task"]["description"]
    print("[PASS] test_procedural_deterministic")
def test_procedural_keyword_stuffing_fails():
    """Procedural verifiers must reject keyword-stuffed answers."""
    generator = TaskGenerator(seed=0)
    for seed in range(10):
        auth_task = generator.generate_with_seed(seed, template="auth_protocol")["task"]
        auth_junk = "HMAC SHA256 base64 signing API key authentication header"
        assert not auth_task["verifier"](auth_junk), f"Keyword stuffing passed for auth seed {seed}"
        bin_task = generator.generate_with_seed(seed, template="binary_format")["task"]
        bin_junk = "struct unpack CRC32 magic bytes header version flags"
        assert not bin_task["verifier"](bin_junk), f"Keyword stuffing passed for binary seed {seed}"
    print("[PASS] test_procedural_keyword_stuffing_fails")
def test_procedural_env_integration():
    """The environment must run end-to-end with use_procedural=True."""
    environment = SkillInvocationEnvironment(use_procedural=True, procedural_seed=42)
    observation = environment.reset(seed=100)
    assert isinstance(observation, SkillInvocationObservation)
    assert observation.task_description != ""
    assert len(observation.skill_catalog) >= 5
    assert observation.context_budget_total == 5
    assert observation.done is False
    # A load returns non-trivial content and consumes budget.
    target = observation.skill_catalog[0]["id"]
    after_load = environment.step(
        SkillInvocationAction(action_type="load", skill_id=target)
    )
    assert after_load.skill_content is not None
    assert len(after_load.skill_content) > 50
    assert after_load.context_budget_used == 1
    # Any submit terminates the episode with a reward attached.
    final = environment.step(SkillInvocationAction(action_type="submit", answer="test"))
    assert final.done is True
    assert final.reward is not None
    print("[PASS] test_procedural_env_integration")
def test_procedural_uniqueness():
    """Distinct seeds should yield a diverse set of task descriptions."""
    generator = TaskGenerator(seed=0)
    descriptions = {
        generator.generate_with_seed(seed, template="auth_protocol")["task"]["description"]
        for seed in range(50)
    }
    assert len(descriptions) >= 10, f"Only {len(descriptions)} unique tasks from 50 seeds"
    print("[PASS] test_procedural_uniqueness")
if __name__ == "__main__":
    # Simple self-contained runner: execute every test, tally outcomes,
    # and exit non-zero if anything failed (CI-friendly).
    banner = "=" * 60
    print(banner)
    print("Skill Invocation Environment - Local Tests")
    print(banner)
    tests = [
        # Core environment tests
        test_reset,
        test_load_skill,
        test_invoke_backward_compat,
        test_unload_skill,
        test_load_already_loaded,
        test_unload_not_loaded,
        test_context_budget,
        test_load_unknown_skill,
        test_submit_incorrect,
        test_submit_after_done,
        test_precision_reward,
        test_bloat_penalty,
        test_load_unload_no_bloat,
        test_state_property,
        test_all_tasks_have_valid_skills,
        # Verifier tests
        test_verifier_task001_correct_code_passes,
        test_verifier_task001_keywords_only_fails,
        test_verifier_task001_wrong_format_fails,
        test_verifier_task001_markdown_fenced,
        test_verifier_task002_correct_passes,
        test_verifier_task002_keywords_only_fails,
        test_verifier_task003_structural,
        test_verifier_task004_yaml_structure,
        test_verifier_task008_record_parser,
        # SkillsBench tests
        test_sb_001_flood_detection_correct,
        test_sb_002_hp_filter_correct,
        test_sb_003_dialogue_parser_correct,
        # Procedural generator tests
        test_procedural_auth_100_seeds,
        test_procedural_binary_100_seeds,
        test_procedural_deterministic,
        test_procedural_keyword_stuffing_fails,
        test_procedural_env_integration,
        test_procedural_uniqueness,
    ]
    passed = 0
    failed = 0
    for test_fn in tests:
        try:
            test_fn()
            passed += 1
        except Exception as e:
            print(f"[FAIL] {test_fn.__name__}: {e}")
            import traceback
            traceback.print_exc()
            failed += 1
    print(banner)
    print(f"Results: {passed} passed, {failed} failed")
    print(banner)
    sys.exit(1 if failed > 0 else 0)