File size: 16,962 Bytes
"""Comprehensive accuracy and parameter validation for the Firewall environment.
Runs all tasks x all policies and validates:
1. New parameter effects (noise, stealth, bursts, false flags, TTL, escalation)
2. Grading accuracy across difficulty tiers
3. Policy ordering: heuristic must outscore both random and block-all (allow-all is reported for reference)
4. Feature distribution separation (benign vs malicious)
5. Environment invariants (budget, observations, session lifecycle)
6. Determinism: identically seeded runs reproduce identical scores
"""
from __future__ import annotations
import os
import sys
import time
from pathlib import Path
# Force UTF-8 output on Windows (best effort).
# Setting PYTHONIOENCODING here only influences child processes started from
# this script; the already-open streams of this process are switched below.
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
for _stream in (sys.stdout, sys.stderr):
    # Encoding names are case-insensitive ("UTF-8" == "utf-8"); skip streams
    # that are already UTF-8. stderr is included so tracebacks/warnings with
    # non-ASCII text do not fail on legacy console code pages.
    _encoding = (getattr(_stream, "encoding", None) or "").lower().replace("_", "-")
    if _encoding in ("utf-8", "utf8"):
        continue
    try:
        _stream.reconfigure(encoding="utf-8")
    except Exception:
        # Deliberate best effort: replaced streams (IDEs, test capture, None
        # under pythonw) may lack reconfigure() or refuse it; keep running with
        # whatever encoding they already have.
        pass
del _stream, _encoding
import numpy as np
# Ensure project root is on path: the directory two levels above this file is
# placed at the very front of the import search path so ``server.*`` resolves
# to the local checkout.
sys.path[:0] = [os.fspath(Path(__file__).resolve().parent.parent)]
from server.firewall_environment import FirewallEnvironment, TASK_CONFIGS
from server.graders import run_deterministic_grade, grade_stats, TASK_SPECS
from server.baseline.random_agent import random_policy, block_all_policy
from server.baseline.heuristic_agent import heuristic_policy
from server.utils.data_loader import TrafficGenerator, FEATURE_ORDER
from server.utils.threat_engine import ThreatEngine
# ===================================================================
# Helpers
# ===================================================================
# Plain-ASCII status markers prefixed to every reported result line.
OK, FAIL, WARN = "[OK]", "[FAIL]", "[WARN]"
def allow_all_policy(env, session_ids):
    """Baseline policy mapping every session id to action 0.

    Action 0 presumably means "allow", per the function name — confirm against
    the environment's action encoding. ``env`` is accepted only to match the
    policy call signature and is not used.

    Returns:
        A dict ``{session_id: 0}`` covering each id in ``session_ids``.
    """
    return dict.fromkeys(session_ids, 0)
def separator(title):
    """Print a section header: a blank line, then *title* between two 72-wide ``=`` rules."""
    rule = "=" * 72
    print("", rule, sep="\n")
    print(" " + title, rule, sep="\n")
def subsection(title):
    """Print a blank line followed by *title* as an indented ``-- title --`` heading."""
    print(f"\n -- {title} --")
# ===================================================================
# Test 1: Parameter configuration validation
# ===================================================================
def check_config_parameters():
    """Validate presence and ranges of every task's configuration parameters.

    For each entry in ``TASK_CONFIGS`` this checks that all required keys are
    present and that each range-checked value lies inside its inclusive
    bounds. It prints one status line per task plus one line per violation.
    It then prints a difficulty-progression report for easy -> medium -> hard.

    Returns:
        True if every task has all required keys and all checked values are in
        range, False otherwise. A non-monotonic difficulty progression is only
        reported as a warning and does not affect the result.
    """
    separator("1. TASK CONFIG PARAMETER VALIDATION")
    required_keys = [
        "max_steps", "benign_ratio", "threat_probability", "traffic_lambda",
        "budget", "noise_level", "stealth_multiplier", "session_ttl_benign",
        "session_ttl_malicious", "escalation_rate_mod", "false_flag_prob",
        "burst_prob", "burst_size_mult",
    ]
    # Inclusive (low, high) bounds for each range-checked parameter.
    value_ranges = {
        "noise_level": (0.0, 0.5),
        "stealth_multiplier": (0.5, 3.0),
        "session_ttl_benign": (1, 10),
        "session_ttl_malicious": (1, 10),
        "escalation_rate_mod": (0.5, 3.0),
        "false_flag_prob": (0.0, 0.5),
        "burst_prob": (0.0, 0.5),
        "burst_size_mult": (1.0, 5.0),
    }
    all_ok = True
    for task, config in TASK_CONFIGS.items():
        missing = [k for k in required_keys if k not in config]
        if missing:
            print(" {} {}: missing keys {}".format(FAIL, task, missing))
            all_ok = False
        else:
            print(" {} {}: all {} parameters present".format(OK, task, len(required_keys)))
        # Validate ranges. Violations are reported and fail the check instead
        # of using `assert`, which is stripped under `python -O` and, when it
        # fires, aborts the whole script before the summary is printed.
        for key, (low, high) in value_ranges.items():
            if key not in config:
                continue  # already reported as a missing key above
            value = config[key]
            if not low <= value <= high:
                print(" {} {} out of range for {} (got {!r}, expected [{}, {}])".format(
                    FAIL, key, task, value, low, high))
                all_ok = False
    # Check difficulty progression: each parameter is expected to be
    # non-decreasing from easy to hard (reported as WARN only).
    subsection("Difficulty progression check")
    tasks = ["easy", "medium", "hard"]
    for param in ["noise_level", "stealth_multiplier", "escalation_rate_mod",
                  "false_flag_prob", "burst_prob", "burst_size_mult"]:
        if any(param not in TASK_CONFIGS.get(t, {}) for t in tasks):
            print(" {} {:25s}: not defined for every tier".format(WARN, param))
            continue
        values = [TASK_CONFIGS[t][param] for t in tasks]
        monotonic = all(a <= b for a, b in zip(values, values[1:]))
        status = OK if monotonic else WARN
        print(" {} {:25s}: easy={:.3f} med={:.3f} hard={:.3f}".format(
            status, param, values[0], values[1], values[2]))
    return all_ok
# ===================================================================
# Test 2: Run all policies across all tasks
# ===================================================================
def run_full_evaluation():
    """Grade every baseline policy on every task and print per-run metrics.

    Returns:
        Nested dict ``results[task][policy_name]`` with keys ``score`` and
        ``passed`` (taken from the grade), ``stats`` (``env.get_network_stats()``
        after the run) and ``elapsed`` (seconds spent in the graded run).
    """
    separator("2. POLICY EVALUATION ACROSS ALL TASKS")
    # Policies are obtained per task through factories. The random policy is
    # rebuilt for every task, so each task starts from seed 42 instead of
    # inheriting RNG state consumed by the previous task. check_determinism
    # likewise builds a fresh random_policy for each run.
    # NOTE(review): assumes random_policy keeps internal RNG state — confirm.
    policy_factories = {
        "random": lambda: random_policy(seed=42),
        "block_all": lambda: block_all_policy,
        "allow_all": lambda: allow_all_policy,
        "heuristic": lambda: heuristic_policy,
    }
    tasks = ["easy", "medium", "hard"]
    results = {}
    for task in tasks:
        spec = TASK_SPECS[task]
        subsection("Task: {} (threshold={})".format(task.upper(), spec.threshold))
        results[task] = {}
        for pname, make_policy in policy_factories.items():
            env = FirewallEnvironment(seed=spec.seed)
            policy = make_policy()
            # perf_counter is monotonic and intended for measuring durations.
            t0 = time.perf_counter()
            grade = run_deterministic_grade(env, task=task, policy=policy)
            elapsed = time.perf_counter() - t0
            stats = env.get_network_stats()
            results[task][pname] = {
                "score": grade["score"],
                "passed": grade["passed"],
                "stats": stats,
                "elapsed": elapsed,
            }
            passed_str = ("PASS " + OK) if grade["passed"] else ("FAIL " + FAIL)
            print(" {:12s} score={:.4f} {} det={:.3f} fp={:.3f} eff={:.3f} ({:.2f}s)".format(
                pname, grade["score"], passed_str,
                stats["detection_rate"], stats["false_positive_rate"],
                stats["efficiency"], elapsed))
            # Print extended metrics (absent keys default to 0).
            ff_seen = stats.get("false_flags_seen", 0)
            ff_acc = stats.get("false_flag_accuracy", 0)
            stealth_seen = stats.get("stealth_attacks_seen", 0)
            stealth_det = stats.get("stealth_detection_rate", 0)
            bursts = stats.get("burst_ticks", 0)
            print(" {:12s} false_flags={} (acc={:.3f}) stealth={} (det={:.3f}) bursts={}".format(
                "", ff_seen, ff_acc, stealth_seen, stealth_det, bursts))
    return results
# ===================================================================
# Test 3: Policy ordering validation
# ===================================================================
def check_policy_ordering(results):
    """Check that the heuristic policy outscores the random and block-all baselines.

    Args:
        results: Output of ``run_full_evaluation()``. ``results[task][policy]``
            dicts must contain a ``score`` entry for "heuristic", "random",
            "block_all" and "allow_all" on each of easy/medium/hard.

    Returns:
        True if, on every task, the heuristic score is strictly greater than
        both the random and block-all scores. The allow-all comparison is
        printed for reference only. It is marked WARN, not FAIL, when it does
        not hold, because it does not affect the returned verdict.
    """
    separator("3. POLICY ORDERING VALIDATION")
    # (baseline name, whether beating it is required for the check to pass)
    baselines = [("random", True), ("block_all", True), ("allow_all", False)]
    all_ok = True
    for task in ["easy", "medium", "hard"]:
        task_results = results[task]
        h_score = task_results["heuristic"]["score"]
        for i, (name, required) in enumerate(baselines):
            b_score = task_results[name]["score"]
            if h_score > b_score:
                status = OK
            elif required:
                status = FAIL
                all_ok = False
            else:
                status = WARN
            # Only the first line of each task group carries the task label.
            label = task if i == 0 else ""
            print(" {:8s} heuristic({:.4f}) > {}({:.4f}) {}".format(
                label, h_score, name, b_score, status))
    return all_ok
# ===================================================================
# Test 4: Feature separation analysis
# ===================================================================
def check_feature_separation(*, strong_threshold=0.15, significant_threshold=0.08,
                             min_significant=5):
    """Compare per-feature means of benign vs. malicious observation vectors.

    Generates 200 benign sessions, plus 10 malicious sessions for every
    (scenario, phase) pair over 5 scenarios x 4 phases. It converts them to
    observation vectors and prints the features sorted by absolute
    mean difference.

    Args:
        strong_threshold: |mean diff| above which a feature is labelled STRONG.
        significant_threshold: |mean diff| above which a feature counts as
            significant (and is labelled at least MODERATE).
        min_significant: Minimum number of significant features needed to pass.

    Returns:
        True if at least ``min_significant`` features exceed
        ``significant_threshold``.
    """
    separator("4. FEATURE SEPARATION (BENIGN vs MALICIOUS)")
    gen = TrafficGenerator(seed=777)
    benign_vecs = [
        gen.to_observation_vector(gen.generate_benign_sessions(tick=0, count=1)[0])
        for _ in range(200)
    ]
    mal_vecs = []
    for scenario in ["port_scan_exploit_c2", "credential_stuffing_lateral",
                     "supply_chain_compromise", "low_and_slow_apt", "ddos_amplification"]:
        for phase in range(4):
            for _ in range(10):
                s = gen.generate_malicious_sessions(
                    tick=0, count=1, attack_phase=phase, scenario=scenario,
                )[0]
                mal_vecs.append(gen.to_observation_vector(s))
    benign_arr = np.array(benign_vecs)
    mal_arr = np.array(mal_vecs)
    benign_means = benign_arr.mean(axis=0)
    mal_means = mal_arr.mean(axis=0)
    mean_diff = np.abs(benign_means - mal_means)
    # Feature indices ordered from largest to smallest separation.
    top_features = np.argsort(mean_diff)[::-1]
    print(" Feature separation (|mean_diff|):")
    print(" {:30s} {:>10s} {:>12s} {:>8s} {:>8s}".format(
        "Feature", "Benign u", "Malicious u", "|D|", "Signal"))
    print(" " + "-" * 75)
    for idx in top_features:
        delta = mean_diff[idx]
        if delta > strong_threshold:
            signal = "STRONG"
        elif delta > significant_threshold:
            signal = "MODERATE"
        else:
            signal = "WEAK"
        print(" {:30s} {:10.4f} {:12.4f} {:8.4f} {:>8s}".format(
            FEATURE_ORDER[idx], benign_means[idx], mal_means[idx], delta, signal))
    significant_count = int(np.count_nonzero(mean_diff > significant_threshold))
    # Report against the actual vector length instead of a hard-coded count.
    n_features = mean_diff.size
    print("\n Significant features (|D| > {}): {}/{}".format(
        significant_threshold, significant_count, n_features))
    ok = significant_count >= min_significant
    print(" {} Minimum {} significant features: {}".format(
        OK if ok else FAIL, min_significant, significant_count))
    return ok
# ===================================================================
# Test 5: Noise effect validation
# ===================================================================
def check_noise_effect():
    """Check that observation noise perturbs the observation when it is enabled.

    For each task, resets a seed-42 environment and takes the first queued
    session. It compares the generator's clean observation vector with the
    environment's current (possibly noisy) observation.

    Returns:
        True if every task with ``noise_level > 0`` produced a non-zero
        difference and every queued session could be found; False otherwise.
        Tasks with an empty session queue are skipped with a warning.
    """
    separator("5. OBSERVATION NOISE VALIDATION")
    all_ok = True
    for task in ["easy", "medium", "hard"]:
        noise_level = TASK_CONFIGS[task]["noise_level"]
        env = FirewallEnvironment(seed=42)
        env.reset(task=task, seed=42)
        if not env._session_queue:
            print(" {} No sessions for {}".format(WARN, task))
            continue
        sid = env._session_queue[0]
        # Look the session up explicitly: with `a or b`, a sid present in
        # neither map would be passed on as None to to_observation_vector.
        session = env.pending_sessions.get(sid)
        if session is None:
            session = env.inspected_sessions.get(sid)
        if session is None:
            print(" {} {:8s}: queued session {!r} not found in pending/inspected sessions".format(
                FAIL, task, sid))
            all_ok = False
            continue
        clean_obs = env.generator.to_observation_vector(session)
        noisy_obs = env._current_observation()
        diff = np.array(clean_obs) - np.array(noisy_obs)
        l2 = np.linalg.norm(diff)
        max_diff = np.max(np.abs(diff))
        # With noise disabled there is nothing to detect; otherwise the noisy
        # observation must differ from the clean one.
        ok = True
        if noise_level > 0.0:
            ok = l2 > 0.0
        all_ok = all_ok and ok
        status = OK if ok else FAIL
        print(" {} {:8s}: noise_s={:.3f} L2_diff={:.4f} max_abs_diff={:.4f}".format(
            status, task, noise_level, l2, max_diff))
    return all_ok
# ===================================================================
# Test 6: Burst & False flag effect
# ===================================================================
def check_burst_and_false_flags(*, tasks=("medium", "hard")):
    """Check that bursts and false flags actually occur when their probability is non-zero.

    Runs one graded heuristic-policy episode per task and reads
    ``false_flags_seen`` and ``burst_ticks`` from the environment's network
    stats. Missing stats count as 0.

    Args:
        tasks: Task names to check (default: medium and hard).

    Returns:
        True if, for every task, each effect with probability > 0 was observed
        at least once.
    """
    separator("6. BURST & FALSE FLAG EFFECT VALIDATION")
    all_ok = True
    for task in tasks:
        env = FirewallEnvironment(seed=42)
        # The graded episode is run only to populate the environment's
        # network stats; the grade itself is not needed here.
        run_deterministic_grade(env, task=task, policy=heuristic_policy)
        stats = env.get_network_stats()
        ff_prob = TASK_CONFIGS[task]["false_flag_prob"]
        burst_prob = TASK_CONFIGS[task]["burst_prob"]
        ff_seen = stats.get("false_flags_seen", 0)
        bursts = stats.get("burst_ticks", 0)
        # An effect with zero probability passes trivially.
        ff_ok = ff_seen > 0 if ff_prob > 0 else True
        burst_ok = bursts > 0 if burst_prob > 0 else True
        ok = ff_ok and burst_ok
        all_ok = all_ok and ok
        status = OK if ok else FAIL
        print(" {} {:8s}: false_flag_prob={:.2f} -> seen={:3d} | burst_prob={:.2f} -> ticks={:3d}".format(
            status, task, ff_prob, ff_seen, burst_prob, bursts))
    return all_ok
# ===================================================================
# Test 7: Escalation rate effect
# ===================================================================
def check_escalation_effect():
    """Check that a larger escalation-rate modifier advances attack phases further.

    For each modifier (1.0, 1.5, 2.0), a seed-200 ThreatEngine is created and
    ``maybe_spawn_attacker(0.5)`` is called 50 times. Attack sessions are then
    generated for 10 ticks. The mean ``phase`` over active and dead attackers
    is recorded.

    Returns:
        True if the mean phase is non-decreasing as the modifier grows. A
        violation is printed as WARN but still returns False.
    """
    separator("7. ESCALATION RATE MODIFIER VALIDATION")
    results = {}
    # Use only 10 ticks to avoid phase saturation (max phase = 3)
    # which would mask the escalation rate differences
    for mod_label, mod_val in [("normal (1.0)", 1.0), ("fast (1.5)", 1.5), ("very_fast (2.0)", 2.0)]:
        # The generator is rebuilt with the same seed for every condition,
        # just like the engine. Sharing one generator would carry RNG state
        # consumed in one condition into the next, so the runs would differ
        # by more than escalation_rate_mod.
        gen = TrafficGenerator(seed=100)
        engine = ThreatEngine(seed=200)
        for _ in range(50):
            engine.maybe_spawn_attacker(0.5)
        for tick in range(10):
            engine.generate_attack_sessions(
                tick=tick, generator=gen,
                blocked_attackers=set(),
                escalation_rate_mod=mod_val,
            )
        attackers = list(engine._active_attackers.values()) + list(engine._dead_attackers)
        n_attackers = len(attackers)
        avg_phase = sum(a.phase for a in attackers) / max(n_attackers, 1)
        results[mod_label] = avg_phase
        print(" escalation_mod={:20s} attackers={:3d} avg_phase={:.2f}".format(
            mod_label, n_attackers, avg_phase))
    avg_phases = list(results.values())
    ok = all(lo <= hi for lo, hi in zip(avg_phases, avg_phases[1:]))
    status = OK if ok else WARN
    print(" {} Phase progression monotonic with escalation rate".format(status))
    return ok
# ===================================================================
# Test 8: Budget & efficiency tracking
# ===================================================================
def check_budget_invariants(*, max_iterations=100_000):
    """Play one heuristic-policy episode per task and check end-of-episode invariants.

    Checks that the remaining budget lies in ``[0, initial_budget]``. It also
    checks that the efficiency, detection-rate and false-positive-rate stats
    lie in ``[0, 1]``.

    Args:
        max_iterations: Safety cap on ``env.step`` calls per episode. An
            episode that has not reported ``done`` by then fails the check
            instead of hanging the script.

    Returns:
        True if every episode terminated and satisfied all invariants.
    """
    separator("8. BUDGET & EFFICIENCY INVARIANTS")
    all_ok = True
    for task in ["easy", "medium", "hard"]:
        env = FirewallEnvironment(seed=42)
        env.reset(task=task, seed=42)
        initial_budget = env.initial_budget
        done = False
        steps = 0
        while not done and steps < max_iterations:
            sids = list(env.inspected_sessions.keys()) + list(env.pending_sessions.keys())
            actions = heuristic_policy(env, sids)
            resp = env.step(actions)
            done = resp["done"]
            steps += 1
        if not done:
            print(" {} {:8s}: episode did not finish within {} steps".format(
                FAIL, task, max_iterations))
        stats = env.get_network_stats()
        budget_ok = 0.0 <= env.budget_remaining <= initial_budget
        eff_ok = 0.0 <= stats["efficiency"] <= 1.0
        det_ok = 0.0 <= stats["detection_rate"] <= 1.0
        fp_ok = 0.0 <= stats["false_positive_rate"] <= 1.0
        ok = bool(done) and budget_ok and eff_ok and det_ok and fp_ok
        all_ok = all_ok and ok
        status = OK if ok else FAIL
        print(" {} {:8s}: budget={:.1f}/{:.1f} eff={:.4f} det={:.4f} fp={:.4f}".format(
            status, task, env.budget_remaining, initial_budget,
            stats["efficiency"], stats["detection_rate"], stats["false_positive_rate"]))
    return all_ok
# ===================================================================
# Test 9: Determinism check
# ===================================================================
def check_determinism():
    """Verify that identically seeded runs produce identical grading scores.

    For each task, two environments built with ``seed=42`` are graded with two
    separately constructed ``random_policy(seed=99)`` instances. Both
    environments and both policies are created before either run starts. The
    two scores must compare equal.

    Returns:
        True if every task produced equal scores across both runs.
    """
    separator("9. DETERMINISM VALIDATION")
    verdicts = []
    for task in ("easy", "medium", "hard"):
        envs = (FirewallEnvironment(seed=42), FirewallEnvironment(seed=42))
        policies = (random_policy(seed=99), random_policy(seed=99))
        grades = [
            run_deterministic_grade(run_env, task=task, policy=run_policy)
            for run_env, run_policy in zip(envs, policies)
        ]
        first, second = (g["score"] for g in grades)
        same = first == second
        verdicts.append(same)
        marker = OK if same else FAIL
        print(" {} {:8s}: run1={:.6f} run2={:.6f}".format(marker, task, first, second))
    return all(verdicts)
# ===================================================================
# Main
# ===================================================================
def main():
    """Run every validation check, print a summary, and return a process exit code.

    A check that raises is reported with its traceback and counted as failed.
    The remaining checks still run and the summary is always printed.

    Returns:
        0 if all checks passed, 1 otherwise.
    """
    import traceback

    def _guarded(func, *args):
        # Isolate each check: an exception becomes a failed (None) result
        # instead of aborting the whole validation run.
        try:
            return func(*args)
        except Exception:
            traceback.print_exc()
            return None

    print()
    print("+" + "=" * 70 + "+")
    print("| AI FIREWALL ENVIRONMENT -- ACCURACY & PARAMETER CHECKER |")
    print("+" + "=" * 70 + "+")
    # perf_counter is monotonic and intended for measuring elapsed time.
    t0 = time.perf_counter()
    checks = {}
    checks["config_params"] = _guarded(check_config_parameters)
    results = _guarded(run_full_evaluation)
    # Policy ordering consumes the evaluation results; without them it fails.
    checks["policy_ordering"] = (
        _guarded(check_policy_ordering, results) if results is not None else False
    )
    for name, check in (
        ("feature_separation", check_feature_separation),
        ("noise_effect", check_noise_effect),
        ("burst_false_flags", check_burst_and_false_flags),
        ("escalation_effect", check_escalation_effect),
        ("budget_invariants", check_budget_invariants),
        ("determinism", check_determinism),
    ):
        checks[name] = _guarded(check)
    # -- Summary --
    separator("SUMMARY")
    total_elapsed = time.perf_counter() - t0
    passed = sum(1 for v in checks.values() if v)
    total = len(checks)
    for name, ok in checks.items():
        status = (OK + " PASS") if ok else (FAIL + " FAIL")
        print(" {} {}".format(status, name))
    print("\n {}/{} checks passed ({:.1f}s total)".format(passed, total, total_elapsed))
    if passed == total:
        print("\n ALL CHECKS PASSED -- Environment is accurate and well-calibrated!")
    else:
        print("\n {} check(s) need attention".format(total - passed))
    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())
|