File size: 3,179 Bytes
4105e6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""AE Manager - loads trained MaskablePPO and returns actions for Bomberman."""

import os
import sys
import numpy as np
from sb3_contrib import MaskablePPO

# Make the til_environment package importable (for default_config / obs shape
# if needed). Known checkout locations are probed in order; the first one that
# really contains the environment module is pushed to the front of sys.path.
_env_marker = os.path.join("til_environment", "bomberman_env.py")
_env_roots = (
    os.path.join(os.path.dirname(__file__), "..", "til-26-ae"),
    "/app/til-26-ae-repo/til-26-ae",
    "til-26-ae",
)
# next() over a generator keeps the probing lazy: later locations are not
# touched on disk once an earlier one has matched.
_env_root = next(
    (
        root
        for root in _env_roots
        if os.path.isdir(root) and os.path.isfile(os.path.join(root, _env_marker))
    ),
    None,
)
if _env_root is not None:
    sys.path.insert(0, _env_root)


class AEManager:
    """Load a trained MaskablePPO model and serve inference for Bomberman.

    On construction the first checkpoint that loads successfully from a fixed
    candidate list becomes ``self.model``. If none loads, ``self.model`` stays
    ``None`` and :meth:`ae` answers with random valid actions instead.
    """

    # Mask length assumed when an observation carries no "action_mask".
    _NUM_ACTIONS = 6
    # Action returned when the mask allows nothing at all.
    # NOTE(review): presumably the no-op / "stay" action -- confirm against
    # the environment's action enum.
    _FALLBACK_ACTION = 4

    # Observation keys in the order they are concatenated by _flatten_obs.
    # The order must match the flattening used during training.
    _OBS_FIELDS = (
        "agent_viewcone",
        "base_viewcone",
        "direction",
        "location",
        "base_location",
        "health",
        "frozen_ticks",
        "base_health",
        "team_resources",
        "team_bombs",
        "step",
    )

    def __init__(self):
        """Try each candidate checkpoint path in order and keep the first that loads."""
        self.model = None
        self._obs_size = None
        # Try loading from several locations; an empty MODEL_PATH is skipped
        # by the truthiness check in the loop below.
        # NOTE(review): the local phase1 checkpoint is tried before the local
        # phase3 one, while the /app/data entries go phase3 -> phase2 ->
        # phase1. Order kept as-is; confirm it is intentional.
        candidates = [
            os.environ.get("MODEL_PATH", ""),
            os.path.join(os.path.dirname(__file__), "..", "phase1_final.zip"),
            os.path.join(os.path.dirname(__file__), "..", "phase3_final.zip"),
            "/app/data/phase3_final.zip",
            "/app/data/phase2_final.zip",
            "/app/data/phase1_final.zip",
        ]
        for path in candidates:
            if path and os.path.isfile(path):
                # Best-effort: a checkpoint that fails to load is reported and
                # the next candidate is tried instead of aborting start-up.
                try:
                    self.model = MaskablePPO.load(path)
                    print(f"[AE Manager] Loaded model from {path}")
                    break
                except Exception as e:
                    print(f"[AE Manager] Failed to load {path}: {e}")
        if self.model is None:
            print("[AE Manager] No trained model found -- will return random valid actions.")

    @staticmethod
    def _flatten_obs(obs_dict):
        """Flatten an observation dict into the 1-D vector used during training.

        Each field listed in ``_OBS_FIELDS`` is converted to ``float32`` and
        flattened, then all fields are concatenated in that order. Scalars,
        nested lists and arrays of any rank are accepted for every field, so a
        scalar delivered as a length-1 sequence no longer breaks concatenation.

        Args:
            obs_dict: Mapping that contains every key in ``_OBS_FIELDS``.

        Returns:
            A 1-D ``numpy.float32`` array.

        Raises:
            KeyError: If a required field is missing.
        """
        return np.concatenate(
            [
                np.asarray(obs_dict[key], dtype=np.float32).ravel()
                for key in AEManager._OBS_FIELDS
            ]
        )

    def ae(self, observation: dict) -> int:
        """Return the action index for one observation dict.

        The ``"action_mask"`` entry (all actions allowed when absent or
        ``None``) restricts the choice. If the mask allows no action the
        fallback action is returned, whether or not a model is loaded. Without
        a model a uniformly random allowed action is returned; with a model the
        masked, deterministic prediction is returned.

        Args:
            observation: Observation dict; must contain the ``_OBS_FIELDS``
                keys when a model is loaded.

        Returns:
            The chosen action as a plain ``int``.
        """
        raw_mask = observation.get("action_mask")
        if raw_mask is None:
            raw_mask = [1] * self._NUM_ACTIONS
        action_mask = np.asarray(raw_mask, dtype=bool)

        valid = np.flatnonzero(action_mask)
        if valid.size == 0:
            # Nothing is allowed: do not ask the policy to pick among
            # fully-masked actions, answer like the random fallback does.
            return self._FALLBACK_ACTION

        if self.model is None:
            # Fallback: random valid action
            return int(np.random.choice(valid))

        obs_vec = self._flatten_obs(observation)
        action, _ = self.model.predict(
            obs_vec,
            action_masks=action_mask,
            deterministic=True,
        )
        # .item() handles both a 0-d and a size-1 array result.
        return int(np.asarray(action).item())