File size: 12,998 Bytes
7d18d2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
#!/usr/bin/env python3
"""Phase 3: Rule-based curriculum training - 1M steps with progressive opponents."""
import os, sys, subprocess, numpy as np, torch, gymnasium
from gymnasium.spaces import Box, Discrete

# ── 1. Download TIL env via snapshot_download ──
# Fetch the TIL environment source from the Hub, locate its Python package
# (the first directory, in os.walk's top-down order, holding a pyproject.toml),
# install it editable into THIS interpreter, and make it importable.
print("[1/5] Downloading TIL repo...")
from huggingface_hub import snapshot_download
snapshot_download(repo_id="e-rong/til-26-ae", repo_type="space", local_dir="/app/til-26-ae-repo")
PKG_ROOT = next(
    (root for root, _dirs, files in os.walk("/app/til-26-ae-repo") if "pyproject.toml" in files),
    None,
)
if PKG_ROOT is None:
    raise RuntimeError("pyproject.toml not found")
# `python -m pip` via sys.executable guarantees the install targets the running
# interpreter; a bare "pip" on PATH may belong to a different Python.
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], cwd=PKG_ROOT, check=True)
sys.path.insert(0, PKG_ROOT)

from til_environment.bomberman_env import Bomberman
from til_environment.config import default_config
from pettingzoo.utils.conversions import aec_to_parallel
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.monitor import Monitor
from huggingface_hub import HfApi, hf_hub_download

# Hub model repo that receives checkpoints / eval reports, and the local
# directory they are written to first (created eagerly at import time).
HUB_REPO = "E-Rong/til-26-ae-agent"
DATA_DIR = "/app/data"
os.makedirs(name=DATA_DIR, mode=0o777, exist_ok=True)


def hub_push(local_path, repo_path):
    """Upload *local_path* to *repo_path* in the HUB_REPO model repository.

    Best effort by design: any failure (including constructing the client) is
    printed instead of raised, so a flaky Hub connection cannot abort a run.
    """
    upload_kwargs = dict(
        path_or_fileobj=local_path,
        path_in_repo=repo_path,
        repo_id=HUB_REPO,
        repo_type="model",
    )
    try:
        HfApi().upload_file(**upload_kwargs)
        print(f"  -> pushed {repo_path}")
    except Exception as err:  # deliberate: uploads must never kill training
        print(f"  -> push failed: {err}")


# ── Opponent Policies ──
def static_opponent(obs):
    """Stand still forever: always return STAY (4); *obs* is ignored."""
    stay_action = 4
    return stay_action


def random_valid_opponent(obs):
    """Sample uniformly among the actions ``obs["action_mask"]`` allows.

    A missing mask means all six actions are allowed; if the mask allows
    nothing, STAY (4) is returned.
    """
    allowed = np.nonzero(np.array(obs.get("action_mask", [1]*6), dtype=bool))[0]
    if len(allowed) == 0:
        return 4
    return int(np.random.choice(allowed))


def simple_bomb_opponent(obs):
    """Bomb when an enemy is in view; otherwise wander, preferring to move.

    Reads ``obs["action_mask"]`` (all six actions allowed if absent) and
    ``obs["agent_viewcone"]`` (a zeros (7, 5, 25) array if absent).
    Channel 10 is treated as the ENEMY_AGENT channel and action 5 as
    PLACE_BOMB. NOTE(review): the channel index is assumed from the
    original comments; confirm against the env's layout.

    If the viewcone has >= 11 channels, any tile is non-zero in channel 10,
    and action 5 is allowed, the result is 5. Otherwise the result is a
    random allowed action from 0-3, then any allowed action, then 4.
    """
    allowed = np.array(obs.get("action_mask", [1]*6), dtype=bool)
    viewcone = np.array(obs.get("agent_viewcone", np.zeros((7,5,25))))
    sees_enemy = viewcone.shape[-1] >= 11 and np.any(viewcone[..., 10] > 0)
    if sees_enemy and allowed[5]:
        return 5
    candidates = np.nonzero(allowed)[0]
    movers = [a for a in candidates if a < 4]  # actions 0-3, preferred over STAY/bomb
    if movers:
        return int(np.random.choice(movers))
    if len(candidates) == 0:
        return 4
    return int(np.random.choice(candidates))


def evasive_opponent(obs):
    """Keep moving while an enemy bomb is visible; act randomly otherwise.

    Reads ``obs["action_mask"]`` (all six actions allowed if absent) and
    ``obs["agent_viewcone"]`` (a zeros (7, 5, 25) array if absent).
    Channel 18 is treated as the ENEMY_BOMB channel. NOTE(review): the
    channel index is assumed from the original comments; confirm against the
    env's layout.

    If the viewcone has >= 20 channels and any tile is non-zero in channel 18,
    a random allowed action from 0-3 is returned when one exists. This does
    NOT steer away from the bomb's position; it only avoids standing still.
    In every other case a random allowed action is returned, or STAY (4)
    when nothing is allowed.
    """
    mask = np.array(obs.get("action_mask", [1]*6), dtype=bool)
    view = np.array(obs.get("agent_viewcone", np.zeros((7,5,25))))
    valid = np.nonzero(mask)[0]
    bomb_visible = view.shape[-1] >= 20 and np.any(view[..., 18] > 0)
    if bomb_visible:
        move_actions = [a for a in valid if a < 4]
        if move_actions:
            return int(np.random.choice(move_actions))
        # No legal move: fall through to any legal action.
    return int(np.random.choice(valid)) if len(valid) > 0 else 4


# Policies sampled by mixed_opponent for the final curriculum stage.
_MIXED_POOL = (static_opponent, random_valid_opponent, simple_bomb_opponent, evasive_opponent)


def mixed_opponent(obs):
    """Delegate this decision to a uniformly chosen rule-based opponent.

    A new policy is drawn from ``_MIXED_POOL`` on every call. Within an episode
    an opponent therefore blends static, random, bomb-placing and evasive play
    step by step, rather than rotating policies per episode.
    """
    policy = _MIXED_POOL[np.random.randint(len(_MIXED_POOL))]
    return policy(obs)


# (stage name, opponent policy, step budget for the stage).
# The budgets sum to 1,000,000 timesteps.
CURRICULUM_STAGES = [
    ("static", static_opponent, 150000),
    ("random", random_valid_opponent, 200000),
    ("simple_bomb", simple_bomb_opponent, 250000),
    ("evasive", evasive_opponent, 200000),
    ("mixed", mixed_opponent, 200000),  # per-step mixture of all four policies above
]


class CurriculumEnv(gymnasium.Env):
    """Single-agent Gymnasium view of the multi-agent Bomberman env.

    ``agent_0`` is the learner. Every other agent in the most recent
    observation dict is driven by ``opponent_fn(obs) -> action``, which is
    looked up on each ``step`` and can therefore be reassigned between steps.
    Observations are flattened into one float32 vector (see ``_flatten``).
    ``agent_0``'s action mask is exposed through ``action_masks()`` for
    ActionMasker / MaskablePPO.

    Args:
        opponent_fn: Policy for non-learner agents; ``None`` falls back to
            ``random_valid_opponent``.
        cfg: Environment config, ``default_config()`` when ``None``. Its
            ``env.render_mode`` is set to ``None`` (the object is mutated).
    """
    def __init__(self, opponent_fn=None, cfg=None):
        super().__init__()
        self.cfg = cfg or default_config()
        self.cfg.env.render_mode = None  # no rendering
        raw = Bomberman(self.cfg)
        self._parallel_env = aec_to_parallel(raw)
        self.agent_id = "agent_0"
        self._episode_count = 0
        self.action_space = Discrete(6)
        self._last_action_mask = None
        self._obs_size = None
        self._last_obs_dict = None
        self.opponent_fn = opponent_fn or random_valid_opponent
        self._compute_obs_space()

    def _compute_obs_space(self):
        """Size ``observation_space`` to the length of the vector ``_flatten`` builds."""
        cfg = self.cfg
        vl = int(cfg.dynamics.vision.behind) + int(cfg.dynamics.vision.ahead) + 1
        vw = int(cfg.dynamics.vision.left) + int(cfg.dynamics.vision.right) + 1
        av = vl * vw * 25  # agent viewcone; assumes 25 channels per tile - TODO confirm vs env
        br = int(cfg.entities.base.vision_radius)
        bs = 2 * br + 1
        bv = bs * bs * 25  # square base viewcone of side 2*radius+1
        # NOTE(review): assumes the scalar fields appended in _flatten total 11 values.
        self._obs_size = av + bv + 11
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self._obs_size,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(flat_obs, {})``.

        An explicit *seed* is forwarded to the underlying env. Without one, the
        running episode counter is used, so unseeded runs stay reproducible.
        """
        super().reset(seed=seed)
        self._episode_count += 1
        env_seed = self._episode_count if seed is None else seed
        obs_dict, _info_dict = self._parallel_env.reset(seed=env_seed, options=options)
        self._last_obs_dict = obs_dict
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        return self._flatten(obs_dict[self.agent_id]), {}

    def step(self, action):
        """Apply the learner's *action* plus opponent actions; return a Gymnasium 5-tuple.

        Termination and truncation are both reported through the third value;
        the fourth (truncated) is always False. NOTE(review): PPO therefore
        treats time-limit ends as true terminals. Splitting them requires
        updating callers that read only the third value.
        """
        actions = {self.agent_id: action}
        for aid, obs in self._last_obs_dict.items():
            if aid != self.agent_id:
                actions[aid] = self.opponent_fn(obs)
        obs_dict, rewards, terminations, truncations, infos = self._parallel_env.step(actions)
        self._last_obs_dict = obs_dict
        reward = float(rewards.get(self.agent_id, 0.0))
        if self.agent_id not in obs_dict:
            # Learner no longer observed: end the episode, but keep whatever
            # reward the env still reported for it on this step.
            return np.zeros(self._obs_size, dtype=np.float32), reward, True, False, {}
        self._last_action_mask = obs_dict[self.agent_id]["action_mask"].astype(bool)
        obs = self._flatten(obs_dict[self.agent_id])
        done = terminations.get(self.agent_id, False) or truncations.get(self.agent_id, False)
        return obs, reward, done, False, infos.get(self.agent_id, {})

    def action_masks(self):
        """Return the learner's boolean action mask from the latest observation."""
        return self._last_action_mask

    def _flatten(self, od):
        """Concatenate one agent's observation dict into a flat float32 vector."""
        return np.concatenate([
            od["agent_viewcone"].flatten(), od["base_viewcone"].flatten(),
            np.array([od["direction"]], dtype=np.float32),
            od["location"].flatten().astype(np.float32),
            od["base_location"].flatten().astype(np.float32),
            od["health"].flatten().astype(np.float32),
            np.array([od["frozen_ticks"]], dtype=np.float32),
            od["base_health"].flatten().astype(np.float32),
            od["team_resources"].flatten().astype(np.float32),
            np.array([od["team_bombs"]], dtype=np.float32),
            np.array([od["step"]], dtype=np.float32),
        ], dtype=np.float32)

    def close(self):
        """Close the underlying parallel env."""
        self._parallel_env.close()


class CurriculumCallback(BaseCallback):
    """Checkpoint, evaluate and advance the opponent curriculum during training.

    * Every ``save_freq`` timesteps the model is saved to DATA_DIR and pushed
      to the Hub.
    * Every ``eval_freq`` timesteps the model plays ``eval_episodes``
      deterministic episodes against the current stage's opponent. The result
      is written to ``/app/phase3_eval_<step>.txt`` and pushed.
    * After an evaluation the stage advances in either of two cases. The win
      rate may exceed ``advance_threshold``, or the stage may have used up its
      step budget (third field of its CURRICULUM_STAGES entry; checked only at
      evaluation points). The new opponent is then installed in the training
      CurriculumEnv(s), so training actually continues against it.

    Both schedules test ``num_timesteps % freq == 0``. This is exact for a
    single env; with several vectorised envs some multiples may be skipped.

    Args:
        eval_freq: Timesteps between evaluations.
        save_freq: Timesteps between checkpoints.
        eval_episodes: Episodes per evaluation (must be >= 1).
        advance_threshold: Win rate that must be exceeded to advance early.
    """
    def __init__(self, eval_freq=50000, save_freq=50000, *, eval_episodes=100,
                 advance_threshold=0.55):
        super().__init__()
        if eval_episodes < 1:
            raise ValueError(f"eval_episodes must be >= 1, got {eval_episodes}")
        self.eval_freq = eval_freq
        self.save_freq = save_freq
        self.stage_idx = 0
        self.stage_steps = 0        # timesteps spent in the current stage
        self.wins_history = []      # (timestep, stage name, win rate, avg reward)
        self.eval_episodes = eval_episodes
        self.advance_threshold = advance_threshold
        self._stage_start_ts = 0    # model.num_timesteps when the current stage began

    def _on_training_start(self) -> None:
        # With reset_num_timesteps=False the counter continues from the loaded model.
        self._stage_start_ts = self.model.num_timesteps
        self.stage_steps = 0
        self._apply_opponent(CURRICULUM_STAGES[self.stage_idx][1])

    def _on_step(self) -> bool:
        self.stage_steps = self.num_timesteps - self._stage_start_ts
        if self.num_timesteps % self.save_freq == 0:
            path = os.path.join(DATA_DIR, f"phase3_ckpt_{self.num_timesteps}.zip")
            self.model.save(path)
            hub_push(path, f"phase3_ckpt_{self.num_timesteps}.zip")

        if self.num_timesteps % self.eval_freq == 0:
            self._evaluate_and_maybe_advance()
        return True

    def _apply_opponent(self, opp_fn):
        """Set ``opponent_fn`` on every CurriculumEnv behind the training VecEnv.

        ``None`` maps to ``random_valid_opponent``, mirroring CurriculumEnv's
        own default. Only in-process vectorised envs exposing ``.envs`` (such
        as DummyVecEnv) can be reached; otherwise a warning is printed.
        """
        fn = opp_fn or random_valid_opponent
        venv = self.training_env
        while venv is not None and hasattr(venv, "venv"):  # peel VecEnvWrapper layers
            venv = venv.venv
        envs = getattr(venv, "envs", None)
        if not envs:
            print("  !! could not reach training envs; opponent NOT updated")
            return
        for sub_env in envs:
            base = getattr(sub_env, "unwrapped", sub_env)  # strip gymnasium wrappers (Monitor, ActionMasker)
            if isinstance(base, CurriculumEnv):
                base.opponent_fn = fn

    def _run_eval(self, opp_fn):
        """Play ``eval_episodes`` deterministic episodes vs *opp_fn*; return (win_rate, avg_reward)."""
        env = ActionMasker(CurriculumEnv(opponent_fn=opp_fn, cfg=default_config()),
                           lambda e: e.action_masks())
        wins, total_r = 0, 0.0
        try:
            for ep in range(self.eval_episodes):
                obs, _ = env.reset(seed=ep + 100000 + self.num_timesteps)
                ep_r, done = 0.0, False
                while not done:
                    action, _ = self.model.predict(obs, action_masks=env.action_masks(), deterministic=True)
                    obs, r, terminated, truncated, _ = env.step(int(action))
                    done = terminated or truncated
                    ep_r += r
                total_r += ep_r
                # NOTE(review): a "win" is an episode return > 10; confirm against the env's reward scale.
                if ep_r > 10:
                    wins += 1
        finally:
            env.close()
        return wins / self.eval_episodes, total_r / self.eval_episodes

    def _evaluate_and_maybe_advance(self):
        """Evaluate the current stage, record and push the result, then maybe advance."""
        stage_name, opp_fn, stage_limit = CURRICULUM_STAGES[self.stage_idx]
        print(f"\n--- Evaluating at stage {stage_name} (step {self.num_timesteps}) ---")

        win_rate, avg_r = self._run_eval(opp_fn)
        print(f"  Win rate: {win_rate:.1%}, Avg reward: {avg_r:.1f}")
        self.wins_history.append((self.num_timesteps, stage_name, win_rate, avg_r))

        eval_file = f"/app/phase3_eval_{self.num_timesteps}.txt"
        with open(eval_file, "w") as f:
            f.write(f"Stage: {stage_name}\nStep: {self.num_timesteps}\nWinRate: {win_rate:.1%}\nAvgReward: {avg_r:.1f}\n")
        hub_push(eval_file, f"phase3_eval_{self.num_timesteps}.txt")

        if self.stage_idx >= len(CURRICULUM_STAGES) - 1:
            return  # already on the last stage
        mastered = win_rate > self.advance_threshold
        budget_spent = stage_limit is not None and self.stage_steps >= stage_limit
        if not (mastered or budget_spent):
            return
        if not mastered:
            print(f"  stage budget of {stage_limit} steps used up ({self.stage_steps} steps)")
        self.stage_idx += 1
        new_stage, new_fn, _ = CURRICULUM_STAGES[self.stage_idx]
        self._stage_start_ts = self.num_timesteps
        self.stage_steps = 0
        self._apply_opponent(new_fn)  # make training itself face the new opponent
        print(f"  >>> ADVANCING to stage: {new_stage} <<<")


# Checkpoints tried in order when resuming; the first one that downloads wins.
_RESUME_CANDIDATES = ("phase2_final.zip", "phase2_ckpt_600352.zip", "phase1_final.zip")
# Episodes played in the post-training evaluation.
_FINAL_EVAL_EPISODES = 200


def _download_resume_checkpoint():
    """Return the local path of the first _RESUME_CANDIDATES checkpoint downloadable from HUB_REPO.

    Raises:
        RuntimeError: if none of the candidates can be downloaded.
    """
    for ckpt in _RESUME_CANDIDATES:
        try:
            path = hf_hub_download(repo_id=HUB_REPO, filename=ckpt, repo_type="model", local_dir=DATA_DIR)
        except Exception as e:  # try the next candidate, but say why this one failed
            print(f"  (checkpoint {ckpt} unavailable: {e})")
            continue
        print(f"Downloaded checkpoint: {ckpt}")
        return path
    raise RuntimeError("No checkpoint found!")


def _flatten_raw_obs(ao):
    """Flatten one agent's raw observation dict into the model's float32 input vector."""
    return np.concatenate([
        np.array(ao["agent_viewcone"], np.float32).flatten(),
        np.array(ao["base_viewcone"], np.float32).flatten(),
        np.array([ao["direction"]], np.float32),
        np.array(ao["location"], np.float32).flatten(),
        np.array(ao["base_location"], np.float32).flatten(),
        np.array(ao["health"], np.float32).flatten(),
        np.array([ao["frozen_ticks"]], np.float32),
        np.array(ao["base_health"], np.float32).flatten(),
        np.array(ao["team_resources"], np.float32).flatten(),
        np.array([ao["team_bombs"]], np.float32),
        np.array([ao["step"]], np.float32),
    ], dtype=np.float32)


def _final_evaluation(model, episodes=_FINAL_EVAL_EPISODES):
    """Play *episodes* deterministic games vs random_valid_opponent on the raw env.

    Returns:
        (wins, total_reward), where a win is an episode return > 10 for agent_0.
    """
    env = aec_to_parallel(Bomberman(default_config()))
    wins, total_r = 0, 0
    try:
        for ep in range(episodes):
            obs, _ = env.reset(seed=ep + 200000)
            ep_r, done = 0, False
            while not done:
                if "agent_0" not in obs:
                    break
                ao = obs["agent_0"]
                mask = np.array(ao.get("action_mask", [1]*6), dtype=bool)
                action, _ = model.predict(_flatten_raw_obs(ao), action_masks=mask, deterministic=True)
                acts = {"agent_0": int(action)}
                acts.update({aid: random_valid_opponent(o) for aid, o in obs.items() if aid != "agent_0"})
                obs, rewards, terminations, truncations, _ = env.step(acts)
                ep_r += rewards.get("agent_0", 0)
                done = (terminations.get("agent_0", False)
                        or truncations.get("agent_0", False)
                        or "agent_0" not in obs)
            total_r += ep_r
            if ep_r > 10:
                wins += 1
    finally:
        env.close()
    return wins, total_r


def main():
    """Resume from the best available checkpoint, train 1M steps with the curriculum, then evaluate."""
    print("=" * 60)
    print("PHASE 3: Rule-Based Curriculum")
    print("=" * 60)

    latest = _download_resume_checkpoint()

    # Start with first curriculum stage
    stage_name, opp_fn, _ = CURRICULUM_STAGES[0]
    print(f"Starting curriculum stage: {stage_name}")

    cfg = default_config()
    cfg.env.render_mode = None
    base = CurriculumEnv(opponent_fn=opp_fn, cfg=cfg)
    env = ActionMasker(base, lambda e: e.action_masks())
    env = Monitor(env)

    model = MaskablePPO.load(latest, env=env)
    start_ts = model.num_timesteps
    print(f"Loaded model at timestep {start_ts}")

    cb = CurriculumCallback(eval_freq=50000, save_freq=50000)
    # reset_num_timesteps=False: the step counter (and checkpoint names) continue from the loaded model.
    model.learn(total_timesteps=1000000, callback=cb, progress_bar=False, reset_num_timesteps=False)

    final = os.path.join(DATA_DIR, "phase3_final.zip")
    model.save(final)
    hub_push(final, "phase3_final.zip")

    print("\n=== FINAL EVALUATION ===")
    n_eval = _FINAL_EVAL_EPISODES
    wins, total_r = _final_evaluation(model, n_eval)

    results = (
        f"=== Phase 3 Final Evaluation ===\n"
        f"Episodes: {n_eval}\n"
        f"Win Rate: {wins/n_eval:.1%}\n"
        f"Avg Reward: {total_r/n_eval:.1f}\n"
    )
    print(results)
    with open("/app/phase3_final_eval.txt", "w") as f:
        f.write(results)
    hub_push("/app/phase3_final_eval.txt", "phase3_final_eval.txt")
    print("\n✅ PHASE 3 COMPLETE!")


if __name__ == "__main__":
    main()