"""Comprehensive accuracy and parameter validation for the Firewall environment. Runs all tasks x all policies and validates: 1. New parameter effects (noise, stealth, bursts, false flags, TTL, escalation) 2. Grading accuracy across difficulty tiers 3. Policy ordering: heuristic > block-all > random 4. Feature distribution separation (benign vs malicious) 5. Environment invariants (budget, observations, session lifecycle) """ from __future__ import annotations import os import sys import time from pathlib import Path # Force UTF-8 output on Windows os.environ.setdefault("PYTHONIOENCODING", "utf-8") if sys.stdout.encoding != "utf-8": try: sys.stdout.reconfigure(encoding="utf-8") except Exception: pass import numpy as np # Ensure project root is on path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from server.firewall_environment import FirewallEnvironment, TASK_CONFIGS from server.graders import run_deterministic_grade, grade_stats, TASK_SPECS from server.baseline.random_agent import random_policy, block_all_policy from server.baseline.heuristic_agent import heuristic_policy from server.utils.data_loader import TrafficGenerator, FEATURE_ORDER from server.utils.threat_engine import ThreatEngine # =================================================================== # Helpers # =================================================================== OK = "[OK]" FAIL = "[FAIL]" WARN = "[WARN]" def allow_all_policy(env, session_ids): return {sid: 0 for sid in session_ids} def separator(title): w = 72 print() print("=" * w) print(" " + title) print("=" * w) def subsection(title): print("\n -- {} --".format(title)) # =================================================================== # Test 1: Parameter configuration validation # =================================================================== def check_config_parameters(): separator("1. 
TASK CONFIG PARAMETER VALIDATION") required_keys = [ "max_steps", "benign_ratio", "threat_probability", "traffic_lambda", "budget", "noise_level", "stealth_multiplier", "session_ttl_benign", "session_ttl_malicious", "escalation_rate_mod", "false_flag_prob", "burst_prob", "burst_size_mult", ] all_ok = True for task, config in TASK_CONFIGS.items(): missing = [k for k in required_keys if k not in config] if missing: print(" {} {}: missing keys {}".format(FAIL, task, missing)) all_ok = False else: print(" {} {}: all {} parameters present".format(OK, task, len(required_keys))) # Validate ranges assert 0.0 <= config["noise_level"] <= 0.5, "noise_level out of range for {}".format(task) assert 0.5 <= config["stealth_multiplier"] <= 3.0, "stealth_multiplier out of range for {}".format(task) assert 1 <= config["session_ttl_benign"] <= 10, "session_ttl_benign out of range for {}".format(task) assert 1 <= config["session_ttl_malicious"] <= 10, "session_ttl_malicious out of range for {}".format(task) assert 0.5 <= config["escalation_rate_mod"] <= 3.0, "escalation_rate_mod out of range for {}".format(task) assert 0.0 <= config["false_flag_prob"] <= 0.5, "false_flag_prob out of range for {}".format(task) assert 0.0 <= config["burst_prob"] <= 0.5, "burst_prob out of range for {}".format(task) assert 1.0 <= config["burst_size_mult"] <= 5.0, "burst_size_mult out of range for {}".format(task) # Check difficulty progression subsection("Difficulty progression check") tasks = ["easy", "medium", "hard"] for param in ["noise_level", "stealth_multiplier", "escalation_rate_mod", "false_flag_prob", "burst_prob", "burst_size_mult"]: values = [TASK_CONFIGS[t][param] for t in tasks] monotonic = all(values[i] <= values[i + 1] for i in range(len(values) - 1)) status = OK if monotonic else WARN print(" {} {:25s}: easy={:.3f} med={:.3f} hard={:.3f}".format( status, param, values[0], values[1], values[2])) return all_ok # =================================================================== # Test 2: 
# ===================================================================
# Test 2: Run all policies across all tasks
# ===================================================================

def _make_policies():
    """Build a fresh ``{name: policy}`` mapping of the baseline policies.

    ``random_policy(seed=42)`` is a seeded factory, so its returned policy
    presumably carries RNG state.  It is rebuilt for every task so the random
    baseline's result on one task does not depend on how many calls earlier
    tasks made.
    """
    return {
        "random": random_policy(seed=42),
        "block_all": block_all_policy,
        "allow_all": allow_all_policy,
        "heuristic": heuristic_policy,
    }


def run_full_evaluation():
    """Grade every baseline policy on every task and print a score table.

    Returns:
        ``{task: {policy_name: {"score", "passed", "stats", "elapsed"}}}``
        where ``stats`` is ``env.get_network_stats()`` after grading.
    """
    separator("2. POLICY EVALUATION ACROSS ALL TASKS")
    tasks = ["easy", "medium", "hard"]
    results = {}
    for task in tasks:
        subsection("Task: {} (threshold={})".format(task.upper(), TASK_SPECS[task].threshold))
        results[task] = {}
        for pname, policy in _make_policies().items():
            env = FirewallEnvironment(seed=TASK_SPECS[task].seed)
            t0 = time.perf_counter()
            grade = run_deterministic_grade(env, task=task, policy=policy)
            elapsed = time.perf_counter() - t0
            stats = env.get_network_stats()
            results[task][pname] = {
                "score": grade["score"],
                "passed": grade["passed"],
                "stats": stats,
                "elapsed": elapsed,
            }
            passed_str = "PASS " + OK if grade["passed"] else "FAIL " + FAIL
            print(" {:12s} score={:.4f} {} det={:.3f} fp={:.3f} eff={:.3f} ({:.2f}s)".format(
                pname, grade["score"], passed_str, stats["detection_rate"],
                stats["false_positive_rate"], stats["efficiency"], elapsed))
            # Print extended metrics
            ff_seen = stats.get("false_flags_seen", 0)
            ff_acc = stats.get("false_flag_accuracy", 0)
            stealth_seen = stats.get("stealth_attacks_seen", 0)
            stealth_det = stats.get("stealth_detection_rate", 0)
            bursts = stats.get("burst_ticks", 0)
            print(" {:12s} false_flags={} (acc={:.3f}) stealth={} (det={:.3f}) bursts={}".format(
                "", ff_seen, ff_acc, stealth_seen, stealth_det, bursts))
    return results


# ===================================================================
# Test 3: Policy ordering validation
# ===================================================================

def check_policy_ordering(results):
    """Check that the heuristic policy outscores the random and block-all baselines.

    The comparison against allow-all is printed for reference but does not
    affect the result, so a miss there is labelled ``[WARN]`` (it used to be
    labelled ``[FAIL]`` while still letting the check pass).

    Args:
        results: mapping produced by ``run_full_evaluation``.

    Returns:
        True when heuristic > random and heuristic > block_all on every task.
    """
    separator("3. POLICY ORDERING VALIDATION")
    all_ok = True
    for task in ["easy", "medium", "hard"]:
        scores = results[task]
        h_score = scores["heuristic"]["score"]
        r_score = scores["random"]["score"]
        b_score = scores["block_all"]["score"]
        a_score = scores["allow_all"]["score"]
        beats_random = h_score > r_score
        beats_block = h_score > b_score
        beats_allow = h_score > a_score
        print(" {:8s} heuristic({:.4f}) > random({:.4f}) {}".format(
            task, h_score, r_score, OK if beats_random else FAIL))
        print(" {:8s} heuristic({:.4f}) > block_all({:.4f}) {}".format(
            "", h_score, b_score, OK if beats_block else FAIL))
        print(" {:8s} heuristic({:.4f}) > allow_all({:.4f}) {}".format(
            "", h_score, a_score, OK if beats_allow else WARN))
        if not (beats_random and beats_block):
            all_ok = False
    return all_ok


# ===================================================================
# Test 4: Feature separation analysis
# ===================================================================

def check_feature_separation():
    """Compare per-feature means of benign vs malicious observation vectors.

    Samples 200 benign sessions plus 10 malicious sessions for every
    (scenario, attack phase 0-3) pair from a ``TrafficGenerator`` seeded with
    777, then ranks features by the absolute difference of class means.

    Returns:
        True when at least 5 features have |mean difference| > 0.08.
    """
    separator("4. FEATURE SEPARATION (BENIGN vs MALICIOUS)")
    gen = TrafficGenerator(seed=777)

    # Sessions are drawn one at a time in a fixed order so the seeded
    # generator yields a reproducible sample.
    benign_vecs = []
    for _ in range(200):
        s = gen.generate_benign_sessions(tick=0, count=1)[0]
        benign_vecs.append(gen.to_observation_vector(s))

    mal_vecs = []
    for scenario in ["port_scan_exploit_c2", "credential_stuffing_lateral",
                     "supply_chain_compromise", "low_and_slow_apt", "ddos_amplification"]:
        for phase in range(4):
            for _ in range(10):
                s = gen.generate_malicious_sessions(
                    tick=0, count=1, attack_phase=phase, scenario=scenario,
                )[0]
                mal_vecs.append(gen.to_observation_vector(s))

    benign_arr = np.array(benign_vecs)
    mal_arr = np.array(mal_vecs)
    mean_diff = np.abs(benign_arr.mean(axis=0) - mal_arr.mean(axis=0))
    top_features = np.argsort(mean_diff)[::-1]
    n_features = len(mean_diff)

    print(" Feature separation (|mean_diff|):")
    print(" {:30s} {:>10s} {:>12s} {:>8s} {:>8s}".format(
        "Feature", "Benign u", "Malicious u", "|D|", "Signal"))
    print(" " + "-" * 75)

    significant_count = 0
    for idx in top_features:
        name = FEATURE_ORDER[idx]
        b_mean = benign_arr[:, idx].mean()
        m_mean = mal_arr[:, idx].mean()
        delta = mean_diff[idx]
        signal = "STRONG" if delta > 0.15 else ("MODERATE" if delta > 0.08 else "WEAK")
        if delta > 0.08:
            significant_count += 1
        print(" {:30s} {:10.4f} {:12.4f} {:8.4f} {:>8s}".format(
            name, b_mean, m_mean, delta, signal))

    # Denominator is the real vector length instead of a hard-coded 22.
    print("\n Significant features (|D| > 0.08): {}/{}".format(significant_count, n_features))
    ok = significant_count >= 5
    print(" {} Minimum 5 significant features: {}".format(OK if ok else FAIL, significant_count))
    return ok


# ===================================================================
# Test 5: Noise effect validation
# ===================================================================

def check_noise_effect():
    """Check that observation noise is applied when a task's noise_level > 0.

    After ``reset`` the first queued session's clean vector (from
    ``generator.to_observation_vector``) is compared with
    ``env._current_observation()``; for a positive ``noise_level`` the two
    must differ (non-zero L2 distance).

    Returns:
        True when every task with positive noise produced a perturbed observation.
    """
    separator("5. OBSERVATION NOISE VALIDATION")
    all_ok = True
    for task in ["easy", "medium", "hard"]:
        noise_level = TASK_CONFIGS[task]["noise_level"]
        env = FirewallEnvironment(seed=42)
        env.reset(task=task, seed=42)

        if not env._session_queue:
            print(" {} No sessions for {}".format(WARN, task))
            continue

        sid = env._session_queue[0]
        session = env.pending_sessions.get(sid)
        if session is None:
            session = env.inspected_sessions.get(sid)
        if session is None:
            # Previously this crashed inside to_observation_vector(None).
            print(" {} {:8s}: queued session {} not found".format(FAIL, task, sid))
            all_ok = False
            continue

        clean_obs = env.generator.to_observation_vector(session)
        noisy_obs = env._current_observation()

        diff = np.array(clean_obs) - np.array(noisy_obs)
        l2 = np.linalg.norm(diff)
        max_diff = np.max(np.abs(diff))

        ok = l2 > 0.0 if noise_level > 0.0 else True
        all_ok = all_ok and ok
        status = OK if ok else FAIL
        print(" {} {:8s}: noise_s={:.3f} L2_diff={:.4f} max_abs_diff={:.4f}".format(
            status, task, noise_level, l2, max_diff))
    return all_ok


# ===================================================================
# Test 6: Burst & False flag effect
# ===================================================================

def check_burst_and_false_flags():
    """Check that configured false-flag and burst probabilities show up in stats.

    Runs the heuristic policy through ``run_deterministic_grade`` on the
    medium and hard tasks; a positive probability must yield a positive
    ``false_flags_seen`` / ``burst_ticks`` count.

    Returns:
        True when every positive probability produced at least one event.
    """
    separator("6. BURST & FALSE FLAG EFFECT VALIDATION")
    all_ok = True
    for task in ["medium", "hard"]:
        env = FirewallEnvironment(seed=42)
        # Only the environment's accumulated stats are inspected afterwards.
        run_deterministic_grade(env, task=task, policy=heuristic_policy)
        stats = env.get_network_stats()

        ff_prob = TASK_CONFIGS[task]["false_flag_prob"]
        burst_prob = TASK_CONFIGS[task]["burst_prob"]
        ff_seen = stats.get("false_flags_seen", 0)
        bursts = stats.get("burst_ticks", 0)

        ff_ok = ff_seen > 0 if ff_prob > 0 else True
        burst_ok = bursts > 0 if burst_prob > 0 else True
        ok = ff_ok and burst_ok
        all_ok = all_ok and ok
        status = OK if ok else FAIL
        print(" {} {:8s}: false_flag_prob={:.2f} -> seen={:3d} | burst_prob={:.2f} -> ticks={:3d}".format(
            status, task, ff_prob, ff_seen, burst_prob, bursts))
    return all_ok


# ===================================================================
# Test 7: Escalation rate effect
# ===================================================================

def check_escalation_effect():
    """Check that a larger escalation_rate_mod pushes attackers to later phases.

    Each modifier value is simulated with identically seeded ``ThreatEngine``
    and ``TrafficGenerator`` instances, so the modifier is the only input that
    differs between runs.

    Returns:
        True when the average attacker phase is non-decreasing across the
        modifiers 1.0 -> 1.5 -> 2.0.
    """
    separator("7. ESCALATION RATE MODIFIER VALIDATION")
    results = {}

    # Use only 10 ticks to avoid phase saturation (max phase = 3)
    # which would mask the escalation rate differences
    for mod_label, mod_val in [("normal (1.0)", 1.0), ("fast (1.5)", 1.5), ("very_fast (2.0)", 2.0)]:
        # A single shared generator would enter each run in a different RNG
        # state and confound the comparison, so it is re-seeded per run.
        gen = TrafficGenerator(seed=100)
        engine = ThreatEngine(seed=200)
        for _ in range(50):
            engine.maybe_spawn_attacker(0.5)
        for tick in range(10):
            engine.generate_attack_sessions(
                tick=tick, generator=gen, blocked_attackers=set(),
                escalation_rate_mod=mod_val,
            )

        attackers = [*engine._active_attackers.values(), *engine._dead_attackers]
        n_attackers = len(attackers)
        avg_phase = sum(a.phase for a in attackers) / max(n_attackers, 1)
        results[mod_label] = avg_phase
        print(" escalation_mod={:20s} attackers={:3d} avg_phase={:.2f}".format(
            mod_label, n_attackers, avg_phase))

    averages = list(results.values())
    ok = all(a <= b for a, b in zip(averages, averages[1:]))
    # This result gates the overall summary, so a miss is reported as FAIL.
    status = OK if ok else FAIL
    print(" {} Phase progression monotonic with escalation rate".format(status))
    return ok


# ===================================================================
# Test 8: Budget & efficiency tracking
# ===================================================================

def check_budget_invariants():
    """Run full heuristic episodes and check budget and rate invariants.

    Verifies ``0 <= budget_remaining <= initial_budget`` and that efficiency,
    detection rate and false-positive rate all lie in [0, 1].  An episode
    that does not report ``done`` within a generous step cap fails the check
    instead of hanging the run.

    Returns:
        True when every task terminates and satisfies all invariants.
    """
    separator("8. BUDGET & EFFICIENCY INVARIANTS")
    all_ok = True
    for task in ["easy", "medium", "hard"]:
        env = FirewallEnvironment(seed=42)
        env.reset(task=task, seed=42)
        initial_budget = env.initial_budget

        # Safety net against a non-terminating episode.  Assumes each step()
        # consumes one of config["max_steps"] -- TODO confirm; the margin is
        # deliberately generous so it never trips on a healthy episode.
        step_cap = 10 * int(TASK_CONFIGS[task].get("max_steps", 1000)) + 100
        steps = 0
        done = False
        while not done and steps < step_cap:
            sids = list(env.inspected_sessions.keys()) + list(env.pending_sessions.keys())
            actions = heuristic_policy(env, sids)
            resp = env.step(actions)
            done = resp["done"]
            steps += 1
        if not done:
            print(" {} {:8s}: episode did not finish within {} steps".format(FAIL, task, step_cap))

        stats = env.get_network_stats()
        budget_ok = 0.0 <= env.budget_remaining <= initial_budget
        eff_ok = 0.0 <= stats["efficiency"] <= 1.0
        det_ok = 0.0 <= stats["detection_rate"] <= 1.0
        fp_ok = 0.0 <= stats["false_positive_rate"] <= 1.0

        ok = done and budget_ok and eff_ok and det_ok and fp_ok
        all_ok = all_ok and ok
        status = OK if ok else FAIL
        print(" {} {:8s}: budget={:.1f}/{:.1f} eff={:.4f} det={:.4f} fp={:.4f}".format(
            status, task, env.budget_remaining, initial_budget,
            stats["efficiency"], stats["detection_rate"], stats["false_positive_rate"]))
    return all_ok


# ===================================================================
# Test 9: Determinism check
# ===================================================================

def check_determinism():
    """Check that two identically seeded grading runs give identical results.

    Each run gets its own environment and its own ``random_policy(seed=99)``
    so no RNG state is shared between them; both the score and the pass flag
    must match.

    Returns:
        True when every task reproduces exactly.
    """
    separator("9. DETERMINISM VALIDATION")
    all_ok = True
    for task in ["easy", "medium", "hard"]:
        env1 = FirewallEnvironment(seed=42)
        env2 = FirewallEnvironment(seed=42)
        policy1 = random_policy(seed=99)
        policy2 = random_policy(seed=99)
        g1 = run_deterministic_grade(env1, task=task, policy=policy1)
        g2 = run_deterministic_grade(env2, task=task, policy=policy2)
        ok = g1["score"] == g2["score"] and g1["passed"] == g2["passed"]
        all_ok = all_ok and ok
        status = OK if ok else FAIL
        print(" {} {:8s}: run1={:.6f} run2={:.6f}".format(status, task, g1["score"], g2["score"]))
    return all_ok


# ===================================================================
# Main
# ===================================================================

def main():
    """Run every check, print a summary, and return a process exit code.

    Returns:
        0 when every check passed, 1 otherwise.
    """
    print()
    print("+" + "=" * 70 + "+")
    print("| AI FIREWALL ENVIRONMENT -- ACCURACY & PARAMETER CHECKER |")
    print("+" + "=" * 70 + "+")

    t0 = time.perf_counter()
    checks = {}
    checks["config_params"] = check_config_parameters()
    results = run_full_evaluation()
    checks["policy_ordering"] = check_policy_ordering(results)
    checks["feature_separation"] = check_feature_separation()
    checks["noise_effect"] = check_noise_effect()
    checks["burst_false_flags"] = check_burst_and_false_flags()
    checks["escalation_effect"] = check_escalation_effect()
    checks["budget_invariants"] = check_budget_invariants()
    checks["determinism"] = check_determinism()

    # -- Summary --
    separator("SUMMARY")
    total_elapsed = time.perf_counter() - t0
    passed = sum(1 for v in checks.values() if v)
    total = len(checks)

    for name, ok in checks.items():
        status = OK + " PASS" if ok else FAIL + " FAIL"
        print(" {} {}".format(status, name))

    print("\n {}/{} checks passed ({:.1f}s total)".format(passed, total, total_elapsed))
    if passed == total:
        print("\n ALL CHECKS PASSED -- Environment is accurate and well-calibrated!")
    else:
        print("\n {} check(s) need attention".format(total - passed))

    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())