File size: 5,732 Bytes
1659dd8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | #!/usr/bin/env python3
"""Smoke test: download TIL repo via snapshot_download, verify imports, run 100 steps, push dummy checkpoint."""
import os, sys, subprocess
print("=" * 60)
print("SMOKE TEST: HF Job private repo access + training basics")
print("=" * 60)

# 1. Fetch the private TIL Space into a fixed local directory. Access relies
#    on whatever Hub credentials the job environment provides.
print("\n[1/5] Downloading TIL repo via snapshot_download...")
from huggingface_hub import snapshot_download

repo_dir = "/app/til-26-ae-repo"
snapshot_download(repo_id="e-rong/til-26-ae", repo_type="space", local_dir=repo_dir)
print(" ✓ Downloaded")
print(" Listing repo root:")
# Print an indented tree of the download, capping each directory at five
# file names so large folders do not flood the job log.
for dirpath, _subdirs, filenames in os.walk(repo_dir):
    depth = dirpath.replace(repo_dir, "").count(os.sep)
    pad = " " * 2 * depth
    child_pad = " " * 2 * (depth + 1)
    print(f"{pad}{os.path.basename(dirpath)}/")
    for name in filenames[:5]:
        print(f"{child_pad}{name}")
    hidden_count = len(filenames) - 5
    if hidden_count > 0:
        print(f"{child_pad}... ({hidden_count} more files)")
# 2. Install the TIL environment package into *this* interpreter.
print("\n[2/5] Installing TIL environment...")


def _find_pkg_root(base):
    """Return the shallowest directory under *base* that holds a pyproject.toml.

    Hidden directories (e.g. the ``.cache`` metadata folder that
    ``snapshot_download`` may create inside ``local_dir``) are skipped.
    Candidates at equal depth are ordered alphabetically, so the result does
    not depend on filesystem enumeration order.

    Returns:
        The chosen directory path, or ``None`` if no pyproject.toml exists.
    """
    candidates = []
    for root, dirs, files in os.walk(base):
        # Prune hidden dirs in place; os.walk honours the edited list.
        dirs[:] = sorted(d for d in dirs if not d.startswith("."))
        if "pyproject.toml" in files:
            candidates.append(root)
    if not candidates:
        return None
    return min(candidates, key=lambda path: (path.count(os.sep), path))


PKG_ROOT = _find_pkg_root("/app/til-26-ae-repo")
if PKG_ROOT is None:
    raise RuntimeError("Could not find pyproject.toml in downloaded repo")
print(f" Package root found: {PKG_ROOT}")
# Run pip through the current interpreter. A bare "pip" on PATH can belong to
# a different Python, which would leave the package (and its dependencies)
# invisible to the imports in step 3.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-e", "."],
    cwd=PKG_ROOT,
    check=True,
)
print(" ✓ Installed")
# 3. Verify that the freshly installed package and its dependencies import.
print("\n[3/5] Verifying imports...")
import importlib

# Modules were installed after this interpreter started; invalidate the
# finders' cached directory listings so the new packages are noticed.
importlib.invalidate_caches()
# An editable install typically registers itself through a .pth file, which
# this already-running interpreter has not processed. Put the package source
# on sys.path directly, including ``src/`` for src-layout projects.
_src_dir = os.path.join(PKG_ROOT, "src")
if os.path.isdir(_src_dir):
    sys.path.insert(0, _src_dir)
sys.path.insert(0, PKG_ROOT)
from til_environment.bomberman_env import Bomberman
from til_environment.config import default_config
from pettingzoo.utils.conversions import aec_to_parallel
print(" ✓ Imports OK")
# 4. Short MaskablePPO run to check that the training stack works end to end.
#    The progress line is printed before the heavy ML imports so the log
#    shows which step was reached if one of them fails.
print("\n[4/5] Running 100 training steps...")
import numpy as np
import gymnasium
from gymnasium.spaces import Box, Discrete
from stable_baselines3.common.monitor import Monitor
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
class QuickEnv(gymnasium.Env):
    """Single-agent Gymnasium view of the multi-agent Bomberman environment.

    The learner controls ``agent_0``. Every other agent plays an action drawn
    uniformly from the legal actions in its own action mask. Each observation
    dict is flattened into one float32 vector so an SB3 ``MlpPolicy`` can
    consume it. ``action_masks()`` exposes the learner's legal actions for
    ``ActionMasker`` / ``MaskablePPO``.
    """

    # Observation keys, in the order ``_flatten`` concatenates them.
    _OBS_KEYS = (
        "agent_viewcone", "base_viewcone", "direction", "location",
        "base_location", "health", "frozen_ticks", "base_health",
        "team_resources", "team_bombs", "step",
    )

    def __init__(self):
        super().__init__()
        cfg = default_config()
        cfg.env.render_mode = None
        raw = Bomberman(cfg)
        self._parallel_env = aec_to_parallel(raw)
        self.agent_id = "agent_0"
        # Incremented on every reset and used as that episode's env seed, so
        # successive episodes differ while the sequence stays reproducible.
        self._episode_count = 0
        # NOTE(review): 6 actions is hard-coded; confirm it matches the
        # Bomberman action space.
        self.action_space = Discrete(6)
        vision = cfg.dynamics.vision
        view_len = int(vision.behind) + int(vision.ahead) + 1
        view_wid = int(vision.left) + int(vision.right) + 1
        base_side = 2 * int(cfg.entities.base.vision_radius) + 1
        # Assumes 25 feature channels per viewcone cell plus 11 scalar
        # features — TODO confirm against the env spec. _flatten raises if
        # the real observation size differs.
        self._obs_size = view_len * view_wid * 25 + base_side * base_side * 25 + 11
        self.observation_space = Box(
            low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32
        )
        self._last_action_mask = None
        self._last_obs_dict = None

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(flat_obs, {})`` for the learner.

        *seed* seeds ``self.np_random`` (used for the opponents' random
        actions), as the Gymnasium contract requires. The underlying env is
        always seeded with the running episode counter.
        """
        super().reset(seed=seed)
        self._episode_count += 1
        obs_dict, _ = self._parallel_env.reset(seed=self._episode_count, options=options)
        self._last_obs_dict = obs_dict
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        return self._flatten(obs_dict[self.agent_id]), {}

    def step(self, action):
        """Advance one tick: the learner plays *action*, opponents play randomly.

        Returns the Gymnasium 5-tuple ``(obs, reward, terminated, truncated,
        info)`` for the learner. Termination and truncation are reported
        separately so SB3 can bootstrap values on time-limit truncation.
        If the learner is no longer in the returned observations, a zero
        observation is returned and the episode ends, but the reward and info
        for this step are kept.
        """
        actions = {self.agent_id: action}
        # Opponents sample from the masks in the most recent observations;
        # an opponent with no legal action falls back to action 0.
        for aid, obs in self._last_obs_dict.items():
            if aid == self.agent_id:
                continue
            valid = np.where(obs["action_mask"] == 1)[0]
            actions[aid] = int(self.np_random.choice(valid)) if len(valid) > 0 else 0
        obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
        self._last_obs_dict = obs_dict
        reward = float(rewards.get(self.agent_id, 0.0))
        terminated = bool(terminations.get(self.agent_id, False))
        truncated = bool(truncations.get(self.agent_id, False))
        info = infos.get(self.agent_id, {})
        if self.agent_id not in obs_dict:
            # The learner is gone, so the episode must end. Treat it as a
            # termination unless the env explicitly reported a truncation.
            return (
                np.zeros(self._obs_size, dtype=np.float32),
                reward,
                terminated or not truncated,
                truncated,
                info,
            )
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        return self._flatten(obs_dict[self.agent_id]), reward, terminated, truncated, info

    def action_masks(self):
        """Return the learner's latest legal-action mask (None before reset)."""
        return self._last_action_mask

    def _flatten(self, od):
        """Concatenate the entries of observation dict *od* into one float32 vector.

        Every entry goes through ``np.asarray(..., float32).ravel()``, so
        scalars and arrays of any shape (including shape-(1,) scalars) are
        handled uniformly.

        Raises:
            ValueError: if the flattened size differs from ``observation_space``.
        """
        flat = np.concatenate(
            [np.asarray(od[key], dtype=np.float32).ravel() for key in self._OBS_KEYS]
        )
        if flat.size != self._obs_size:
            raise ValueError(
                f"Flattened observation has {flat.size} values, expected "
                f"{self._obs_size}; update the size computation in __init__."
            )
        return flat
# ActionMasker exposes QuickEnv.action_masks() to MaskablePPO; Monitor wraps
# the env for SB3's episode bookkeeping.
env = ActionMasker(QuickEnv(), lambda e: e.action_masks())
env = Monitor(env)
model = MaskablePPO(
    "MlpPolicy", env,
    learning_rate=3e-4, n_steps=128, batch_size=32, n_epochs=2,
    gamma=0.99, clip_range=0.2, ent_coef=0.01,
    verbose=0, device="cuda",
)
# SB3 silently falls back to CPU when CUDA is unavailable. Log the device
# actually in use so a GPU job that ended up on CPU is visible.
print(f" Policy device: {model.device}")
# learn() always collects whole rollouts of n_steps, so it overshoots
# total_timesteps (one 128-step rollout here). Report the real count.
model.learn(total_timesteps=100, progress_bar=False)
print(f" ✓ {model.num_timesteps} steps completed")
# 5. Save the smoke-test model and upload it to the agent's model repo. This
#    exercises write access with the job's Hub credentials.
print("\n[5/5] Pushing dummy checkpoint to Hub...")
from huggingface_hub import HfApi

ckpt_local_path = "/app/smoke_test_ckpt.zip"
model.save(ckpt_local_path)
hub = HfApi()
hub.upload_file(
    path_or_fileobj=ckpt_local_path,
    path_in_repo="smoke_test_ckpt.zip",
    repo_id="E-Rong/til-26-ae-agent",
    repo_type="model",
)
print(" ✓ Pushed to Hub")

rule = "=" * 60
print("\n" + rule)
print("SMOKE TEST PASSED — Ready for full training job")
print(rule)
|