#!/usr/bin/env python3
"""Phase 2 training job - runs in HF Jobs, resumes from Hub checkpoint."""
import os, sys, subprocess, numpy as np, torch, gymnasium
from gymnasium.spaces import Box, Discrete
# Install TIL environment from source
TIL_REPO = "e-rong/til-26-ae"
TIL_PATH = "/app/til-26-ae-repo/til-26-ae"
if not os.path.exists(TIL_PATH):
subprocess.run(["git", "clone", f"https://huggingface.co/spaces/{TIL_REPO}", "/app/til-26-ae-repo"], check=True)
subprocess.run(["pip", "install", "-e", "."], cwd=TIL_PATH, check=True)
sys.path.insert(0, TIL_PATH)
from til_environment.bomberman_env import Bomberman
from til_environment.config import default_config
from pettingzoo.utils.conversions import aec_to_parallel
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.monitor import Monitor
from huggingface_hub import HfApi, hf_hub_download
HUB_REPO = "E-Rong/til-26-ae-agent"
DATA_DIR = "/app/data"
os.makedirs(DATA_DIR, exist_ok=True)
def hub_push(local_path, repo_path):
    try:
        HfApi().upload_file(path_or_fileobj=local_path, path_in_repo=repo_path,
                            repo_id=HUB_REPO, repo_type="model")
        print(f" -> pushed {repo_path}")
    except Exception as e:
        print(f" -> push failed: {e}")
class BombermanSingleAgentEnv(gymnasium.Env):
    def __init__(self, cfg=None):
        super().__init__()
        self.cfg = cfg or default_config()
        self.cfg.env.render_mode = None
        raw = Bomberman(self.cfg)
        self._parallel_env = aec_to_parallel(raw)
        self.agent_id = "agent_0"
        self._episode_count = 0
        self.action_space = Discrete(6)
        self._last_action_mask = None
        self._obs_size = None
        self._last_obs_dict = None
        self._compute_obs_space()

    def _compute_obs_space(self):
        # Flattened observation: agent viewcone + base viewcone (25 channels
        # per tile) plus 11 scalar features (see _flatten for the exact order).
        cfg = self.cfg
        vl = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
        vw = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
        av = vl * vw * 25
        br = int(cfg.entities.base.vision_radius)
        bs = 2 * br + 1
        bv = bs * bs * 25
        self._obs_size = av + bv + 11
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)
    def reset(self, seed=None, options=None):
        # Deterministic per-episode seeding; the incoming `seed` is ignored.
        self._episode_count += 1
        obs_dict, info_dict = self._parallel_env.reset(seed=self._episode_count, options=options)
        self._last_obs_dict = obs_dict
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        return self._flatten(obs_dict[self.agent_id]), {}

    def step(self, action):
        actions = {self.agent_id: action}
        # All other agents pick uniformly at random from their valid actions.
        for aid, obs in self._last_obs_dict.items():
            if aid != self.agent_id:
                valid = np.where(obs["action_mask"] == 1)[0]
                actions[aid] = int(np.random.choice(valid)) if len(valid) > 0 else 0
        obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
        self._last_obs_dict = obs_dict
        # If our agent was removed from the parallel env, end the episode.
        if self.agent_id not in obs_dict:
            return np.zeros(self._obs_size, dtype=np.float32), 0.0, True, False, {}
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        obs = self._flatten(obs_dict[self.agent_id])
        r = float(rewards.get(self.agent_id, 0.0))
        done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
        return obs, r, done, False, infos.get(self.agent_id, {})

    def action_masks(self):
        return self._last_action_mask

    def _flatten(self, od):
        return np.concatenate([
            od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
            np.array([od["direction"]], dtype=np.float32),
            od["location"].flatten().astype(np.float32),
            od["base_location"].flatten().astype(np.float32),
            od["health"].flatten().astype(np.float32),
            np.array([od["frozen_ticks"]], dtype=np.float32),
            od["base_health"].flatten().astype(np.float32),
            od["team_resources"].flatten().astype(np.float32),
            np.array([od["team_bombs"]], dtype=np.float32),
            np.array([od["step"]], dtype=np.float32),
        ], dtype=np.float32)

    def close(self):
        self._parallel_env.close()
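# Shaping: shaped = raw + w * 1/(1 + visits(x, y)), where w anneals from
# base_explore_weight toward 10% of it as the agent starts scoring big:
# an end-of-episode reward > 20 nudges an EMA, and
# w = base_explore_weight * max(0.1, 1 - tanh(adaptive_k * EMA)).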
class RewardShapingWrapper(gymnasium.Wrapper):
"""Visit-count exploration with adaptive annealing."""
def __init__(self, env, adaptive_k=1.2, base_explore_weight=0.5):
super().__init__(env)
self.adaptive_k = adaptive_k
self.base_explore_weight = base_explore_weight
self._visit_counts = None
self._grid_size = 16
self._avg_enemy_deaths = 0.0
self._explore_weight = base_explore_weight
def reset(self, **kwargs):
self._visit_counts = np.zeros((self._grid_size, self._grid_size), dtype=np.int32)
return self.env.reset(**kwargs)
def step(self, action):
obs, reward, done, truncated, info = self.env.step(action)
pos = info.get("location", None)
bonus = 0.0
if pos is not None:
x, y = int(pos[0]), int(pos[1])
if 0 <= x < self._grid_size and 0 <= y < self._grid_size:
visits = self._visit_counts[x, y]
bonus = 1.0 / (1.0 + visits)
self._visit_counts[x, y] += 1
if done:
alpha = 1.0 - np.tanh(self.adaptive_k * self._avg_enemy_deaths)
self._explore_weight = self.base_explore_weight * max(0.1, alpha)
if reward > 20.0:
self._avg_enemy_deaths = 0.95 * self._avg_enemy_deaths + 0.05 * 1.0
shaped = reward + self._explore_weight * bonus
info["raw_reward"] = reward
info["explore_bonus"] = bonus
return obs, shaped, done, truncated, info
def action_masks(self):
return self.env.action_masks()
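# Note: the stock CheckpointCallback keys off n_calls; this override saves on
# num_timesteps multiples instead, presumably so resumed runs keep a
# consistent checkpoint cadence.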
class HubCheckpointCallback(CheckpointCallback):
"""Saves locally + pushes to Hub."""
def _on_step(self) -> bool:
if self.num_timesteps % self.save_freq == 0:
path = os.path.join(self.save_path, f"phase2_ckpt_{self.num_timesteps}.zip")
self.model.save(path)
hub_push(path, f"phase2_ckpt_{self.num_timesteps}.zip")
return True
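# Resume strategy: try the newest known Phase 2 checkpoint first, then older
# ones, then fall back to the Phase 1 final model.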
def main():
print("=" * 60)
print("PHASE 2: Adaptive Exploration Annealing")
print("=" * 60)
# Download latest checkpoint
latest = None
for ckpt in ["phase2_ckpt_600352.zip", "phase2_ckpt_550352.zip", "phase1_final.zip"]:
try:
latest = hf_hub_download(repo_id=HUB_REPO, filename=ckpt, repo_type="model", local_dir=DATA_DIR)
print(f"Downloaded checkpoint: {ckpt}")
break
except Exception:
print(f" {ckpt} not found, trying next...")
if latest is None:
raise RuntimeError("No checkpoint found on Hub!")
# Environment
cfg = default_config()
cfg.env.render_mode = None
base = BombermanSingleAgentEnv(cfg=cfg)
env = ActionMasker(RewardShapingWrapper(base), lambda e: e.action_masks())
env = Monitor(env)
# Load model
print(f"Loading model from {latest}...")
model = MaskablePPO.load(latest, env=env)
start_ts = model.num_timesteps
remaining = 1000000 - start_ts
print(f"Current: {start_ts}, remaining: {remaining}, target: 1,000,352")
# Train
cb = HubCheckpointCallback(save_freq=50000, save_path=DATA_DIR, name_prefix="phase2")
model.learn(total_timesteps=remaining, callback=cb, progress_bar=False, reset_num_timesteps=False)
# Save final
final = os.path.join(DATA_DIR, "phase2_final.zip")
model.save(final)
hub_push(final, "phase2_final.zip")
env.close()
print("\n=== Phase 2 COMPLETE ===")
print(f"Final timestep: {model.num_timesteps}")
# Evaluation
print("\n=== EVALUATION (100 eps vs Random) ===")
raw = Bomberman(default_config())
env = aec_to_parallel(raw)
wins = 0; total_r = 0; lens = []; bombs = 0
for ep in range(100):
obs, _ = env.reset(seed=ep+50000)
ep_r = 0; steps = 0; done = False; ep_bombs = 0
while not done:
if "agent_0" not in obs: break
ao = obs["agent_0"]
mask = np.array(ao.get("action_mask", [1]*6), dtype=bool)
vec = np.concatenate([
np.array(ao["agent_viewcone"], np.float32).flatten(),
np.array(ao["base_viewcone"], np.float32).flatten(),
np.array([ao["direction"]], np.float32),
np.array(ao["location"], np.float32).flatten(),
np.array(ao["base_location"], np.float32).flatten(),
np.array(ao["health"], np.float32).flatten(),
np.array([ao["frozen_ticks"]], np.float32),
np.array(ao["base_health"], np.float32).flatten(),
np.array(ao["team_resources"], np.float32).flatten(),
np.array([ao["team_bombs"]], np.float32),
np.array([ao["step"]], np.float32),
], dtype=np.float32)
action, _ = model.predict(vec, action_masks=mask, deterministic=True)
if int(action) == 5: ep_bombs += 1
acts = {"agent_0": int(action)}
for aid, o in obs.items():
if aid != "agent_0":
v = np.where(np.array(o["action_mask"]) == 1)[0]
acts[aid] = int(np.random.choice(v)) if len(v) > 0 else 4
obs, rewards, terminations, truncations, _ = env.step(acts)
ep_r += rewards.get("agent_0", 0)
steps += 1
done = terminations.get("agent_0", False) or truncations.get("agent_0", False) or "agent_0" not in obs
total_r += ep_r; lens.append(steps); bombs += ep_bombs
if ep_r > 10: wins += 1
env.close()
results = (
f"=== Phase 2 Evaluation ===\n"
f"Episodes: 100\n"
f"Win Rate: {wins/100:.1%}\n"
f"Avg Reward: {total_r/100:.1f}\n"
f"Avg Length: {sum(lens)/len(lens):.1f}\n"
f"Avg Bombs: {bombs/100:.1f}\n"
)
print(results)
with open("/app/phase2_eval.txt", "w") as f:
f.write(results)
hub_push("/app/phase2_eval.txt", "phase2_eval_results.txt")
print("\n✅ ALL DONE!")
if __name__ == "__main__":
    main()