Krooz committed · Commit 4de3fe2 · verified · 1 Parent(s): b2a6b60

Add model card

  1. README.md +113 -29
README.md CHANGED
@@ -11,55 +11,139 @@ license: mit
  # Pyre PPO Agent — `Krooz/pyre-ppo-agent`

  PPO-trained actor-critic agent for the [Pyre](https://huggingface.co/spaces/Krooz/pyre_env)
- fire-evacuation environment, part of the OpenEnv Hackathon (Apr 2026).

- ## Training summary

  | Metric | Value |
  |--------|-------|
- | Total episodes | ? |
- | Training time | ? min |
- | Final success rate (last 30 ep) | ? |
- | Final reward mean (last 30 ep) | ? |
- | Curriculum | `?` (patience-gated) |
- | Patience threshold | ? |

- ## Hyperparameters

  | Param | Value |
  |-------|-------|
- | Learning rate | `?` |
- | PPO clip ε | `?` |
- | Entropy coeff | `?` |
- | Gamma | `?` |
- | Frame stack | `?` |
- | Hidden sizes | `?` |
- | Device | `?` |

  ## Files in this repository

  | File | Description |
  |------|-------------|
- | `pyre_ppo.pt` | PyTorch checkpoint (`network_state`, `optimizer_state`, `config`) |
- | `pyre_ppo.png` | Training graph reward + success rate over episodes |
- | `pyre_ppo.csv` | Per-episode metrics |
- | `pyre_ppo_eval.csv` | Per-difficulty evaluation metrics |
- | `pyre_ppo_training.log` | Structured JSON-lines training log |

- ## Loading the checkpoint

  ```python
  import torch

- ckpt = torch.load("pyre_ppo.pt", map_location="cpu", weights_only=False)
- # ckpt keys: network_state, optimizer_state, scheduler_state, episode, config
- print(ckpt["config"]) # input_dim, action_dim, hidden_sizes, history_length, obs_mode
  ```

- ## Environment

- - **Space**: [Krooz/pyre_env](https://huggingface.co/spaces/Krooz/pyre_env)
- - **Training notebook**: [Google Colab](https://colab.research.google.com/drive/1ojC55qKXMVRXdjKeG5dUHiA5RBOBxA9V?usp=sharing)
- - **Source**: [pyre_env/training/ppo/train_torch_ppo.py](training/ppo/train_torch_ppo.py)

  # Pyre PPO Agent — `Krooz/pyre-ppo-agent`

  PPO-trained actor-critic agent for the [Pyre](https://huggingface.co/spaces/Krooz/pyre_env)
+ fire-evacuation environment (OpenEnv Hackathon, Apr 2026).

+ > ⚠️ This is a raw PyTorch checkpoint, **not** a `transformers` model.
+ > The Hugging Face hosted Inference API cannot run it directly.
+ > Use the inference code below to load and run it locally.
+
+ ## Training summary (artifact run: ``pyre_ppo_fixed``)
+
+ Values below are from ``artifacts/pyre_ppo_fixed.csv``, ``pyre_ppo_fixed_eval.csv``,
+ and ``artifacts/pyre_ppo_fixed_training.log`` (HTTP trainer, env server at ``http://localhost:8000``).

  | Metric | Value |
  |--------|-------|
+ | Total episodes | **200** |
+ | Wall-clock training time | **~48 s** (~4.2 eps/s on CPU) |
+ | Final success rate (rolling last 30 ep) | **80%** |
+ | Final reward mean (rolling last 30 ep) | **+8.446** |
+ | Curriculum | **Static** ``easy,medium`` (≈100 eps each; ``--patience-threshold 0``) |
+ | Eval cadence | Every **20** episodes, **3** deterministic rollouts |
+ | Eval difficulty | **medium** (per eval log / ``pyre_ppo_fixed_eval.csv``) |
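
As a rough cross-check of the summary rows above, the throughput and the rolling last-30 figures can be recomputed from per-episode records. This is only a minimal sketch: the helper name `rolling_tail_mean` and the toy success list are illustrative assumptions, not code or data taken from the trainer or its CSV files.

```python
def rolling_tail_mean(values, window=30):
    """Mean of the last `window` entries (or of all entries if fewer exist)."""
    tail = values[-window:]
    return sum(tail) / len(tail)


# Throughput implied by the table: 200 episodes in roughly 48 s.
episodes, seconds = 200, 48
eps_per_second = episodes / seconds  # about 4.17, consistent with "~4.2 eps/s"

# Hypothetical per-episode success flags (NOT the real run data):
# 24 successes among the final 30 episodes corresponds to an 80% rolling rate.
successes = [0] * 170 + [1] * 24 + [0] * 6
final_success_rate = rolling_tail_mean(successes, window=30)

print(f"{eps_per_second:.2f} eps/s, rolling success {final_success_rate:.0%}")
```

The same helper applied to a per-episode reward column would give the rolling reward mean; the actual column names in the metrics CSV are not shown in this diff.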
+
+ ## Network architecture (from training log)

+ | Property | Value |
+ |----------|-------|
+ | Total parameters | **12,065,650** |
+ | Input vector dim | **23,140** (encoder ``base_dim`` 5785 × **4** stacked frames) |
+ | Action dim | **41** (4 move + 4 look + 1 wait + 16 door open + 16 door close) |
+ | Hidden MLP | **512 → 256 → 128** |
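
The input and action dimensions in the table above follow from simple arithmetic. The sketch below just restates it; the constant and dictionary names are illustrative, and the grouping of actions is taken from the table, not from the trainer source.

```python
BASE_DIM = 5785        # encoder base_dim, per the table
HISTORY_LENGTH = 4     # number of stacked frames, per the table

# Action counts per group as listed in the table (names are illustrative).
ACTION_GROUPS = {
    "move": 4,
    "look": 4,
    "wait": 1,
    "door_open": 16,
    "door_close": 16,
}

input_dim = BASE_DIM * HISTORY_LENGTH      # 5785 * 4
action_dim = sum(ACTION_GROUPS.values())   # 4 + 4 + 1 + 16 + 16

print(input_dim, action_dim)  # 23140 41
```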
+
+ ## Hyperparameters (defaults matching this run)

  | Param | Value |
  |-------|-------|
+ | Learning rate | **3×10⁻⁴** |
+ | PPO clip ε | **0.2** |
+ | Entropy coeff | **0.03** |
+ | Value coeff | **0.5** |
+ | Gamma | **0.99** |
+ | GAE λ | **0.95** |
+ | PPO update every | **5** episodes |
+ | PPO epochs / minibatch | **4** / **256** |
+ | Max grad norm | **0.5** |
+ | Observation mode | **visible** (partial observability) |
+ | Device | **cpu** |
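
To make the roles of Gamma, GAE λ, and the PPO clip ε concrete, here is a textbook sketch of generalized advantage estimation and the clipped surrogate term using the values from the table. It is an assumption-laden illustration of standard PPO, not the code in `train_torch_ppo.py`; the function names are hypothetical.

```python
def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimation for one episode.

    `values` must hold len(rewards) + 1 entries; the last one is the
    bootstrap value (0.0 when the episode terminated).
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error, then the exponentially weighted running sum.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages


def clipped_surrogate(ratio, advantage, clip_eps=0.2):
    """PPO objective for a single sample (to be maximized)."""
    clipped_ratio = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    return min(ratio * advantage, clipped_ratio * advantage)


# Two-step toy episode with zero value estimates (hypothetical numbers).
advantages = compute_gae(rewards=[1.0, 1.0], values=[0.0, 0.0, 0.0])
```

With ε = 0.2, a probability ratio of 1.5 on a positive advantage is capped at 1.2, which is what limits how far a single update can move the policy.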

+ ### Evaluation checkpoints (from ``pyre_ppo_fixed_eval.csv``)

+ | Episode | Difficulty | Success rate | Reward mean | Steps mean |
+ |---------|------------|--------------|-------------|------------|
+ | 20 | medium | 100% | +15.698 | 7.0 |
+ | 40 | medium | 100% | +15.640 | 4.3 |
+ | 60 | medium | 100% | +16.887 | 9.0 |
+ | 80 | medium | 100% | +15.162 | 10.3 |
+ | 100 | medium | 67% | +6.008 | 57.0 |
+ | 120 | medium | 67% | +6.401 | 32.7 |
+ | 140 | medium | 100% | +16.283 | 6.3 |
+ | 160 | medium | 100% | +16.573 | 8.3 |
+ | 180 | medium | 100% | +16.397 | 8.0 |
+ | 200 | medium | 67% | +6.807 | 14.7 |
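
Because each evaluation uses 3 rollouts, the success rate can only take the values 0%, 33%, 67%, or 100%, so a 67% row means 2 of 3 rollouts succeeded. The sketch below shows how such aggregates could be computed; `summarize_eval` and the rollout records are hypothetical and do not reproduce the actual eval data.

```python
def summarize_eval(rollouts):
    """Aggregate per-rollout results into the columns used in the table."""
    n = len(rollouts)
    return {
        "success_rate": sum(1 for r in rollouts if r["success"]) / n,
        "reward_mean": sum(r["reward"] for r in rollouts) / n,
        "steps_mean": sum(r["steps"] for r in rollouts) / n,
    }


# Hypothetical rollouts: two successes and one failure.
rollouts = [
    {"success": True, "reward": 15.0, "steps": 8},
    {"success": True, "reward": 16.0, "steps": 9},
    {"success": False, "reward": -10.0, "steps": 27},
]

summary = summarize_eval(rollouts)
print(f"{summary['success_rate']:.0%}")  # 67%
```

A single failed rollout pulls the reward mean down sharply, which is one plausible reading of the lower reward means on the 67% rows.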

  ## Files in this repository

  | File | Description |
  |------|-------------|
+ | `model.pt` | PyTorch checkpoint (`network_state`, `optimizer_state`, `scheduler_state`, `args`, `episode`) |
+ | `training_graph.png` | Training curves (reward + success rate vs episode) |
+ | `episode_metrics.csv` | Per-episode training metrics |
+ | `eval_metrics.csv` | Periodic eval aggregates |
+ | `training.log` | Full console transcript of the HTTP training run |

+ ## Running inference locally

  ```python
+ import sys
  import torch
+ from huggingface_hub import hf_hub_download
+
+ # 1. Point Python at your local pyre_env checkout (or install the package)
+ sys.path.insert(0, "pyre_env")
+
+ from training.ppo.train_torch_ppo import (
+     ActorCritic,
+     ObservationEncoder,
+     action_index_to_env_action,
+     build_action_mask,
+ )
+
+ # 2. Download the checkpoint from this Hub repo
+ ckpt_path = hf_hub_download(repo_id="Krooz/pyre-ppo-agent", filename="model.pt")
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+
+ # 3. Rebuild the policy from saved training args
+ saved_args = ckpt["args"]
+ encoder = ObservationEncoder(mode=saved_args.get("observation_mode", "visible"))
+ hidden_sizes = tuple(int(x) for x in saved_args.get("hidden_sizes", "512,256,128").split(","))
+ history_length = saved_args.get("history_length", 4)
+ input_dim = encoder.base_dim * history_length
+ network = ActorCritic(input_dim, 41, hidden_sizes)
+ network.load_state_dict(ckpt["network_state"])
+ network.eval()
+ print(f"Loaded checkpoint from episode {ckpt.get('episode', '?')}")
+
+ # 4. Roll out one episode (in-process env — swap for HTTP client if you prefer)
+ from openenv_pyre import PyreEnvironment
+ from collections import deque
+ import numpy as np
+
+ env = PyreEnvironment()
+ obs = env.reset(difficulty="medium")
+ frames = deque([np.zeros(encoder.base_dim, dtype=np.float32)] * history_length, maxlen=history_length)
+ frames.append(encoder.encode(obs))
+
+ total_reward = 0.0
+ with torch.no_grad():
+     while True:
+         state_vec = np.concatenate(list(frames), dtype=np.float32)
+         obs_t = torch.tensor(state_vec, dtype=torch.float32).unsqueeze(0)
+         mask_t = torch.tensor(build_action_mask(obs, exclude_look=True), dtype=torch.float32).unsqueeze(0)
+         action_t, _, _ = network.act(obs_t, mask_t, deterministic=True)
+         obs = env.step(action_index_to_env_action(int(action_t.item())))
+         total_reward += float(obs.reward or 0.0)
+         frames.append(encoder.encode(obs))
+         if obs.done:
+             break

+ print(f"Episode finished — evacuated={obs.agent_evacuated} reward={total_reward:.3f}")
  ```
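
The frame-stacking step in the rollout above (a zero-padded `deque` with `maxlen`) can be illustrated without PyTorch or the environment. This is a simplified sketch using plain lists; `make_frame_stack` and `flatten` are hypothetical helpers, and the tiny `base_dim` is chosen only to keep the example readable.

```python
from collections import deque


def make_frame_stack(base_dim, history_length):
    """Start with `history_length` all-zero frames, oldest first."""
    return deque(
        ([0.0] * base_dim for _ in range(history_length)),
        maxlen=history_length,
    )


def flatten(frames):
    """Concatenate frames oldest-to-newest into one input vector."""
    return [x for frame in frames for x in frame]


frames = make_frame_stack(base_dim=2, history_length=4)
frames.append([1.0, 2.0])       # first real observation pushes out one zero frame
state_vec = flatten(frames)     # three zero frames followed by the new frame

frames.append([3.0, 4.0])       # the deque keeps only the 4 most recent frames
next_state_vec = flatten(frames)
```

Because `maxlen` discards the oldest frame on each append, the flattened vector always has `base_dim * history_length` entries, matching the network's fixed input size.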

+ ## Environment & training resources

+ - **HF Space (live env)**: [Krooz/pyre_env](https://huggingface.co/spaces/Krooz/pyre_env)
+ - **PPO training in Colab (HTTP to Space)**: [Pyre PPO training — Google Colab](https://colab.research.google.com/drive/1ojC55qKXMVRXdjKeG5dUHiA5RBOBxA9V?usp=sharing)
+ - **Local HTTP trainer**: ``training/ppo/train_torch_ppo_http.py``
+ - **Local in-process trainer**: ``training/ppo/train_torch_ppo.py``
+ - **Notebook source**: ``training/ppo/pyre_ppo_training.ipynb``