Spaces:
Running
Running
File size: 1,153 Bytes
90bdd23 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | from __future__ import annotations
import json
from pathlib import Path
from runway_zero.qlearning import train_stage
# Number of learning-curve points averaged at each end of the curve for the
# summary. The "first_20_avg" / "last_20_avg" summary keys are named after it.
_SUMMARY_WINDOW = 20

# Episode budget per stage used when main() is called without arguments.
_DEFAULT_EPISODES_BY_STAGE = {1: 140, 2: 180, 3: 220}


def _rounded_mean(values: list[float]) -> float:
    """Return the arithmetic mean of *values* rounded to 3 decimals (0.0 if empty)."""
    # max(1, ...) keeps an empty learning curve from raising ZeroDivisionError.
    return round(sum(values) / max(1, len(values)), 3)


def main(
    out_dir: Path | str = "results/trained",
    episodes_by_stage: dict[int, int] | None = None,
) -> None:
    """Train one policy per stage, save each artifact, and write a summary.

    For every ``stage -> episodes`` pair, calls ``train_stage`` and writes the
    returned artifact as indented JSON to
    ``<out_dir>/q_policy_stage<stage>.json``. A summary list with one entry per
    stage is written to ``<out_dir>/training_summary.json`` and also printed to
    stdout. Each entry holds the stage, the episode count, the mean reward over
    the first and the last ``_SUMMARY_WINDOW`` learning-curve points, and the
    artifact path.

    The artifact is assumed to be a JSON-serializable mapping whose
    ``"learning_curve"`` entry is a sequence of points, each carrying a
    ``"reward"`` number. These are the only fields read here.

    Args:
        out_dir: Directory for all outputs; it is created with its parents if
            missing.
        episodes_by_stage: Mapping of stage number to episode count, processed
            in iteration order. Defaults to ``{1: 140, 2: 180, 3: 220}``.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    stages = _DEFAULT_EPISODES_BY_STAGE if episodes_by_stage is None else episodes_by_stage

    summary = []
    for stage, episodes in stages.items():
        artifact = train_stage(stage=stage, episodes=episodes)
        path = out_path / f"q_policy_stage{stage}.json"
        path.write_text(json.dumps(artifact, indent=2), encoding="utf-8")

        rewards = [point["reward"] for point in artifact["learning_curve"]]
        summary.append(
            {
                "stage": stage,
                "episodes": episodes,
                # Slices shorter than the window (short curves) are averaged over
                # their actual length.
                "first_20_avg": _rounded_mean(rewards[:_SUMMARY_WINDOW]),
                "last_20_avg": _rounded_mean(rewards[-_SUMMARY_WINDOW:]),
                "artifact": str(path),
            }
        )

    # Serialize once so the saved file and the console output are identical.
    summary_text = json.dumps(summary, indent=2)
    (out_path / "training_summary.json").write_text(summary_text, encoding="utf-8")
    print(summary_text)


if __name__ == "__main__":
    main()
|