File size: 2,485 Bytes
bc5030f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from acre.env.refactor_env import RefactorEnv
@dataclass(frozen=True)
class TrainConfig:
    """Immutable bundle of settings consumed by `train`.

    Attributes:
        total_steps: Number of timesteps handed to the learner.
        seed: Random seed forwarded to the environment and the PPO model;
            `None` leaves seeding to those components.
        model_path: Destination the trained model is saved to.
    """

    # Training budget, in timesteps.
    total_steps: int = 5000
    # Optional seed; None means "do not fix a seed".
    seed: Optional[int] = None
    # Where the trained model is written.
    model_path: str = "acre_agent.zip"
def train(*, env: Optional[RefactorEnv] = None, config: Optional[TrainConfig] = None) -> None:
    """
    Train a PPO agent on `RefactorEnv` using Stable-Baselines3.

    This is intentionally lightweight (hackathon-friendly) and focuses on a
    working demo: basic training loop, simple logging, and saving the model.

    Args:
        env: Environment to train on. When omitted, a `RefactorEnv` seeded with
            `config.seed` is created for this run and closed once training
            finishes or fails. An environment supplied by the caller is never
            closed here.
        config: Training settings. Defaults to `TrainConfig()` when omitted.

    Returns:
        Always `None`. If Stable-Baselines3 cannot be imported, an install
        hint and the import error are printed and no training takes place.
    """
    # Compare against None explicitly so that a caller-supplied object which
    # happens to be falsy (e.g. defines __len__/__bool__) is not replaced.
    _config = TrainConfig() if config is None else config

    # Imported lazily so that importing this module does not require SB3.
    # Kept broad on purpose: a broken install of a heavy optional dependency
    # can fail with errors other than ImportError.
    try:
        from stable_baselines3 import PPO
        from stable_baselines3.common.callbacks import BaseCallback
        from stable_baselines3.common.monitor import Monitor
        from stable_baselines3.common.vec_env import DummyVecEnv
    except Exception as e:  # pragma: no cover
        print("Stable-Baselines3 is required for training. Install with `pip install -r requirements.txt`.")
        print(f"Import error: {e}")
        return None

    class EpisodeRewardPrinter(BaseCallback):
        """Print episode reward when an episode ends (via Monitor)."""

        def __init__(self) -> None:
            super().__init__()
            self.episode_count = 0

        def _on_step(self) -> bool:
            """Report every finished episode found in this step's infos."""
            infos = self.locals.get("infos", [])
            for info in infos:
                ep = info.get("episode") if isinstance(info, dict) else None
                if isinstance(ep, dict) and "r" in ep:
                    self.episode_count += 1
                    print(f"episode={self.episode_count} reward={ep['r']:.2f} length={int(ep.get('l', 0))}")
            # Returning True tells SB3 to keep training.
            return True

    # Only build the environment once we know training can actually run, and
    # remember whether we own it so we do not close a caller's environment.
    owns_env = env is None
    _env = RefactorEnv(seed=_config.seed) if owns_env else env

    # Wrap with Monitor so SB3 can compute episode stats and expose them in `info["episode"]`.
    def make_env() -> Monitor:
        return Monitor(_env)

    vec_env = DummyVecEnv([make_env])
    try:
        model = PPO(
            policy="MlpPolicy",
            env=vec_env,
            verbose=0,
            seed=_config.seed,
            n_steps=64,
            batch_size=64,
        )

        total_timesteps = int(_config.total_steps)
        print(f"Training PPO for {total_timesteps} timesteps...")
        model.learn(total_timesteps=total_timesteps, callback=EpisodeRewardPrinter())

        model.save(_config.model_path)
        print(f"Saved model to {_config.model_path!r}")
    finally:
        # Closing the vectorized wrapper closes the wrapped environment too,
        # so only do it for an environment created inside this function.
        if owns_env:
            vec_env.close()
    return None
|