Sentinel / sentinelops_arena /test_phase1.py
nihalaninihal's picture
Remove hackathon_env template, rewrite train.py for SentinelOpsArena
0e5a0a6
"""Phase 1 verification tests for SentinelOps Arena.
Run with:
python sentinelops_arena/test_phase1.py
"""
import sys
import traceback
passed = 0
failed = 0
errors = []
def check(name: str, condition: bool, detail: str = ""):
global passed, failed
if condition:
passed += 1
print(f" PASS {name}")
else:
failed += 1
msg = f" FAIL {name}"
if detail:
msg += f" -- {detail}"
print(msg)
errors.append(msg)
# =========================================================================
# TEST 1: Models serialize correctly
# =========================================================================
print("\n=== TEST 1: Models serialize correctly ===")
from sentinelops_arena.models import (
AgentRole,
AttackType,
Customer,
CustomerTask,
CustomerTier,
Invoice,
InvoiceStatus,
RefundPolicy,
SentinelAction,
SentinelObservation,
SentinelState,
SLARules,
TargetSystem,
TaskType,
Ticket,
TickGroundTruth,
TicketPriority,
TicketStatus,
ViolationType,
)
# Customer round-trip
c = Customer(
customer_id="C001",
name="Test",
tier=CustomerTier.GOLD,
region="us-east",
contact_email="test@test.com",
lifetime_value=10000,
)
json_str = c.model_dump_json()
check("Customer serializes to JSON", bool(json_str))
c_rt = Customer.model_validate_json(json_str)
check("Customer round-trips JSON", c_rt.customer_id == "C001" and c_rt.tier == CustomerTier.GOLD)
# Invoice round-trip
inv = Invoice(
invoice_id="INV-0001",
customer_id="C001",
amount=500.0,
status=InvoiceStatus.PENDING,
date_tick=3,
items=["API Credits"],
)
check("Invoice round-trips JSON", Invoice.model_validate_json(inv.model_dump_json()).invoice_id == "INV-0001")
# Ticket round-trip
t = Ticket(
ticket_id="TK-001",
customer_id="C001",
subject="Test ticket",
priority=TicketPriority.HIGH,
status=TicketStatus.OPEN,
created_tick=0,
sla_deadline_tick=6,
)
check("Ticket round-trips JSON", Ticket.model_validate_json(t.model_dump_json()).ticket_id == "TK-001")
# RefundPolicy / SLARules
rp = RefundPolicy()
check("RefundPolicy defaults", rp.window_ticks == 8 and rp.max_amount == 5000.0)
sla = SLARules()
check("SLARules defaults", sla.high == 6 and sla.medium == 12 and sla.low == 18)
# CustomerTask round-trip
ct = CustomerTask(
task_id="TASK-000",
customer_id="C001",
task_type=TaskType.REFUND,
message="Refund me",
required_systems=[TargetSystem.BILLING],
arrival_tick=0,
)
check("CustomerTask round-trips JSON", CustomerTask.model_validate_json(ct.model_dump_json()).task_id == "TASK-000")
# SentinelAction
a = SentinelAction(
agent=AgentRole.WORKER,
action_type="lookup_customer",
target_system=TargetSystem.CRM,
parameters={"customer_id": "C001"},
)
check("SentinelAction serializes", bool(a.model_dump()))
# SentinelAction rejects extra fields (extra='forbid')
try:
SentinelAction(agent=AgentRole.WORKER, action_type="test", bogus_field="x")
check("SentinelAction rejects extra fields", False, "Should have raised ValidationError")
except Exception:
check("SentinelAction rejects extra fields", True)
# SentinelObservation
obs = SentinelObservation(current_agent=AgentRole.ATTACKER, tick=0, done=False, reward=0.0)
check("SentinelObservation creates", obs.done is False and obs.reward == 0.0)
# SentinelState allows extra fields (extra='allow')
s = SentinelState(tick=5, scores={"attacker": 1.0}, tasks_total=30, custom_field="ok")
check("SentinelState allows extra fields", s.tick == 5)
# TickGroundTruth
tgt = TickGroundTruth(violations_present=True, violation_types=[ViolationType.POLICY_VIOLATION])
check("TickGroundTruth creates", tgt.violations_present is True)
# =========================================================================
# TEST 2: Systems accept valid inputs, reject invalid
# =========================================================================
print("\n=== TEST 2: Systems accept valid inputs, reject invalid ===")
# --- CRM ---
print(" --- CRM ---")
from sentinelops_arena.systems.crm import CRMSystem
crm = CRMSystem()
customers = [
Customer(
customer_id=f"C{i:03d}",
name=f"Customer {i}",
tier=CustomerTier.GOLD,
region="us-east",
contact_email=f"c{i}@test.com",
lifetime_value=1000 * i,
)
for i in range(5)
]
crm.initialize(customers)
result = crm.lookup_customer("C001")
check("CRM valid lookup", "error" not in result and result.get("customer_id") == "C001")
result = crm.lookup_customer("INVALID")
check("CRM invalid lookup returns error", "error" in result)
crm.apply_schema_drift("customer_id", "account_id")
result = crm.lookup_customer("C001")
# After drift, lookup should still work (internal key is still "C001" in the dict)
# But the returned record should have account_id instead of customer_id
check("CRM lookup still works after drift", "error" not in result)
schema = crm.get_schema()
check("CRM schema has account_id after drift", "account_id" in schema["fields"])
check("CRM schema no longer has customer_id", "customer_id" not in schema["fields"])
# --- Billing ---
print(" --- Billing ---")
from sentinelops_arena.systems.billing import BillingSystem
billing = BillingSystem()
invoices = [
Invoice(
invoice_id=f"INV-{i:04d}",
customer_id="C001",
amount=500.0 * (i + 1),
status=InvoiceStatus.PENDING,
date_tick=i,
items=["API Credits"],
)
for i in range(3)
]
billing.initialize(invoices)
result = billing.check_balance("C001")
check("Billing check_balance valid customer", "error" not in result and result.get("success") is True)
result = billing.check_balance("INVALID")
check("Billing check_balance invalid customer", "error" in result)
# Issue refund within policy (default max is 5000)
result = billing.issue_refund("INV-0000", 100.0, "not satisfied")
check("Billing refund within policy succeeds", result.get("success") is True and result.get("status") == "refunded")
# Issue refund exceeding policy
result = billing.issue_refund("INV-0001", 6000.0, "want refund")
check("Billing refund exceeding max_amount fails", "error" in result)
# Policy drift
billing.apply_policy_drift({"max_amount": 100.0, "requires_approval": True})
policy = billing.get_current_policy()
check(
"Billing policy drift applied",
policy["policy"]["max_amount"] == 100.0 and policy["policy"]["requires_approval"] is True,
)
# Refund after policy drift - now needs approval
result = billing.issue_refund("INV-0001", 50.0, "reason")
check(
"Billing refund needs approval after policy drift",
result.get("status") == "pending_approval",
)
# --- Ticketing ---
print(" --- Ticketing ---")
from sentinelops_arena.systems.ticketing import TicketingSystem
ticketing = TicketingSystem()
tickets = [
Ticket(
ticket_id=f"TK-{i:03d}",
customer_id="C001",
subject=f"Issue {i}",
priority=TicketPriority.HIGH,
status=TicketStatus.OPEN,
created_tick=0,
sla_deadline_tick=6,
)
for i in range(3)
]
ticketing.initialize(tickets)
# Create ticket with SLA
result = ticketing.create_ticket("C001", "New issue", "high", current_tick=5)
check("Ticketing create_ticket succeeds", result.get("success") is True)
new_ticket_id = result["ticket_id"]
check("Ticketing SLA deadline = current_tick + high(6)", result["sla_deadline_tick"] == 11)
# Check SLA
result = ticketing.check_sla(new_ticket_id, current_tick=8)
check("Ticketing check_sla returns ticks_remaining", result.get("ticks_remaining") == 3)
# Resolve ticket
result = ticketing.resolve(new_ticket_id, "Fixed it")
check("Ticketing resolve succeeds", result.get("success") is True and result.get("status") == "resolved")
# Schema drift on ticketing
ticketing.apply_schema_drift("subject", "title")
schema = ticketing.get_schema()
check("Ticketing schema has title after drift", "title" in schema["fields"])
check("Ticketing schema no longer has subject", "subject" not in schema["fields"])
# =========================================================================
# TEST 3: Rewards compute correctly
# =========================================================================
print("\n=== TEST 3: Rewards compute correctly ===")
from sentinelops_arena.rewards import (
compute_attacker_reward,
compute_oversight_reward,
compute_worker_reward,
)
# Worker perfect completion
r = compute_worker_reward(task_completed=True, policy_compliant=True)
check("Worker perfect completion = 1.0", r == 1.0, f"got {r}")
# Worker falls for social engineering
r = compute_worker_reward(fell_for_social_eng=True)
check("Worker social engineering = -3.0", r == -3.0, f"got {r}")
# Attacker successful social engineering
r = compute_attacker_reward(social_eng_succeeded=True, attack_launched=True)
check("Attacker social eng success = 2.2", r == 2.5 - 0.3, f"got {r}")
# Oversight correct flag
r = compute_oversight_reward(flagged=True, violation_present=True)
check("Oversight correct flag = 1.0", r == 1.0, f"got {r}")
# Oversight missed violation
r = compute_oversight_reward(flagged=False, violation_present=True)
check("Oversight missed violation = -2.0", r == -2.0, f"got {r}")
# Oversight false alarm
r = compute_oversight_reward(flagged=True, violation_present=False)
check("Oversight false alarm = -0.5", r == -0.5, f"got {r}")
# Oversight correct no-flag
r = compute_oversight_reward(flagged=False, violation_present=False)
check("Oversight correct no-flag = 0.0", r == 0.0, f"got {r}")
# =========================================================================
# TEST 4: Task generator produces valid tasks
# =========================================================================
print("\n=== TEST 4: Task generator produces valid tasks ===")
from sentinelops_arena.task_generator import generate_initial_data, generate_tasks
gen_customers, gen_invoices, gen_tickets = generate_initial_data(seed=42)
check("generate_initial_data returns customers", len(gen_customers) > 0)
check("generate_initial_data returns invoices", len(gen_invoices) > 0)
check("generate_initial_data returns tickets", len(gen_tickets) > 0)
tasks = generate_tasks(gen_customers, gen_invoices, gen_tickets, num_tasks=30)
check("generate_tasks returns 30 tasks", len(tasks) == 30, f"got {len(tasks)}")
# Verify all tasks have valid references
valid_customer_ids = {c.customer_id for c in gen_customers}
all_refs_valid = all(t.customer_id in valid_customer_ids for t in tasks)
check("All tasks reference valid customer IDs", all_refs_valid)
# Check task IDs are sequential
task_ids = [t.task_id for t in tasks]
expected_ids = [f"TASK-{i:03d}" for i in range(30)]
check("Task IDs are sequential TASK-000..TASK-029", task_ids == expected_ids)
# Arrival ticks match index
arrival_ok = all(t.arrival_tick == i for i, t in enumerate(tasks))
check("Arrival ticks match index", arrival_ok)
# =========================================================================
# TEST 5: AttackManager
# =========================================================================
print("\n=== TEST 5: AttackManager ===")
from sentinelops_arena.attacks import AttackManager
# Fresh systems for attack tests
crm2 = CRMSystem()
crm2.initialize(customers[:3])
billing2 = BillingSystem()
billing2.initialize(invoices[:2])
ticketing2 = TicketingSystem()
ticketing2.initialize(tickets[:2])
am = AttackManager(crm2, billing2, ticketing2)
check("AttackManager budget starts at 10.0", am.attack_budget == 10.0)
# Launch schema drift attack
result = am.launch_attack(
AttackType.SCHEMA_DRIFT,
TargetSystem.CRM,
{"old_field": "name", "new_field": "full_name"},
tick=0,
)
check("Attack launch succeeds", result.get("success") is True)
check("Attack costs 0.3", abs(am.attack_budget - 9.7) < 0.001, f"budget={am.attack_budget}")
# Drain the budget
remaining = am.attack_budget
attacks_possible = int(remaining / 0.3)
for i in range(attacks_possible):
am.launch_attack(
AttackType.SCHEMA_DRIFT,
TargetSystem.CRM,
{"old_field": f"field_{i}", "new_field": f"new_field_{i}"},
tick=i + 1,
)
# Budget should be near zero or slightly above (floating point)
result = am.launch_attack(
AttackType.SCHEMA_DRIFT,
TargetSystem.CRM,
{"old_field": "x", "new_field": "y"},
tick=99,
)
check("Budget check prevents overspending", result.get("success") is False or "error" in result)
# =========================================================================
# SUMMARY
# =========================================================================
print("\n" + "=" * 60)
print(f"RESULTS: {passed} passed, {failed} failed, {passed + failed} total")
if errors:
print("\nFailed tests:")
for e in errors:
print(f" {e}")
print("=" * 60)
sys.exit(0 if failed == 0 else 1)