File size: 4,635 Bytes
"""Run self-play training with adaptive curriculum.

Usage:
    python scripts/run_self_play.py                         # Default: 30 rounds
    python scripts/run_self_play.py --rounds 50             # Custom round count
    python scripts/run_self_play.py --output results.json   # Custom output path
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
# Force UTF-8 output on Windows.
# NOTE: PYTHONIOENCODING is read by the interpreter at start-up, so setting it
# here cannot change the streams of this already-running process; it is only
# inherited by child processes.  The reconfigure() call below is what actually
# switches this process's stdout.
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
# sys.stdout is None when no console is attached (e.g. pythonw.exe), and a
# replaced stream may not expose ``encoding`` at all.  getattr() with a
# "utf-8" default makes both cases skip the reconfiguration instead of
# raising AttributeError at import time.
if getattr(sys.stdout, "encoding", "utf-8") != "utf-8":
    try:
        sys.stdout.reconfigure(encoding="utf-8")
    except (AttributeError, ValueError, OSError):
        # Best effort only: the stream is not a reconfigurable TextIOWrapper
        # (AttributeError) or refuses the change (io.UnsupportedOperation is
        # both a ValueError and an OSError).  Keep the existing encoding.
        pass
# Ensure project root is on path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from server.self_play.arena import SelfPlayArena
from server.baseline.heuristic_agent import heuristic_policy
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the self-play training script."""
    parser = argparse.ArgumentParser(
        description="Run self-play training with adaptive curriculum"
    )
    parser.add_argument(
        "--rounds", type=int, default=30,
        help="Number of training rounds (default: 30)",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed (default: 42)",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.78,
        help="Pass threshold for mastery gating (default: 0.78)",
    )
    parser.add_argument(
        "--mastery-window", type=int, default=3,
        help="Consecutive passes needed for difficulty advance (default: 3)",
    )
    parser.add_argument(
        "--output", type=str, default="self_play_results.json",
        help="Output JSON path (default: self_play_results.json)",
    )
    parser.add_argument(
        "--no-graphs", action="store_true",
        help="Skip graph generation after training",
    )
    return parser


def _print_banner(title: str, rule: str, width: int = 70) -> None:
    """Print *title* in an ASCII box, surrounded by one blank line each side.

    Args:
        title: Text shown inside the box.
        rule: Single character used to draw the top and bottom borders.
        width: Number of columns between the two corner/edge characters.
    """
    border = "+" + rule * width + "+"
    print()
    print(border)
    # Pad with ljust() rather than a hand-counted run of spaces so the title
    # row is always exactly as wide as the border rows.
    print("|" + ("  " + title).ljust(width) + "|")
    print(border)
    print()


def _print_assessment(results) -> None:
    """Print the Elo-growth and pass-rate verdicts for a finished run.

    Prints nothing when *results* is empty.  Each result is expected to
    expose ``elo``, ``elo_delta`` and ``passed`` attributes.
    """
    if not results:
        return

    first, last = results[0], results[-1]
    # The rating before round one is recovered by undoing round one's delta.
    start_elo = first.elo - first.elo_delta
    growth = last.elo - start_elo
    pass_rate = sum(1 for r in results if r.passed) / len(results)

    print()
    if growth > 50:
        print(f" Agent showed SIGNIFICANT skill growth ({growth:+.0f} Elo)")
    elif growth > 0:
        print(f" Agent showed MODERATE skill growth ({growth:+.0f} Elo)")
    else:
        print(f" Agent did NOT improve ({growth:+.0f} Elo) — policy may need updating")

    if pass_rate > 0.8:
        print(f" Pass rate {pass_rate:.0%} — agent handles adaptive curriculum well")
    elif pass_rate > 0.5:
        print(f" Pass rate {pass_rate:.0%} — room for policy improvement")
    else:
        print(f" Pass rate {pass_rate:.0%} — agent struggles with generated challenges")


def _generate_graphs(output_path: Path) -> None:
    """Render performance graphs from the saved history (best effort).

    Any failure is reported as a warning and does not abort the script.
    """
    _print_banner("GENERATING PERFORMANCE GRAPHS", "-")
    try:
        # Imported inside the try so that a missing or broken graph module
        # only produces the warning below instead of failing the whole run.
        from scripts.generate_performance_matrix import generate_graphs

        generate_graphs(
            input_json=str(output_path.resolve()),
            output_dir=str(Path("output").resolve()),
        )
    except Exception as e:
        print(f" [GRAPHS] Warning: Could not generate graphs: {e}")


def main(argv=None) -> int:
    """Run self-play training from the command line.

    Parses the options, trains ``heuristic_policy`` in a ``SelfPlayArena``,
    saves the training history to ``--output``, prints a short assessment,
    and generates graphs unless ``--no-graphs`` is given.

    Args:
        argv: Argument list to parse.  ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which keeps a plain ``main()`` call working
            exactly as before.

    Returns:
        The process exit status; always ``0`` when the run completes.
    """
    args = _build_parser().parse_args(argv)

    _print_banner("AI FIREWALL — SELF-PLAY ADAPTIVE CURRICULUM TRAINING", "=")
    print(" Config:")
    print(f" Rounds: {args.rounds}")
    print(f" Seed: {args.seed}")
    print(f" Pass threshold: {args.threshold}")
    print(f" Mastery window: {args.mastery_window} consecutive passes")
    print(f" Output: {args.output}")
    print(" Policy: heuristic (8-rule baseline)")

    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments during the run.
    started = time.perf_counter()
    arena = SelfPlayArena(
        seed=args.seed,
        mastery_window=args.mastery_window,
        pass_threshold=args.threshold,
    )
    results = arena.train(
        policy=heuristic_policy,
        num_rounds=args.rounds,
        verbose=True,
    )

    # Save results
    output_path = Path(args.output)
    arena.save_history(output_path)
    elapsed = time.perf_counter() - started
    print(f" Results saved to: {output_path.resolve()}")
    print(f" Total training time: {elapsed:.1f}s")

    # Final assessment
    _print_assessment(results)

    # ── Generate performance graphs ──
    if args.no_graphs:
        print(" Skipping graph generation (--no-graphs)")
    else:
        _generate_graphs(output_path)
    return 0
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|