yxc20098 committed · Commit 9063f2b · 1 Parent(s): aa68464

Vendor the necessary openra_rl_training + openra_env modules

Makes the bench self-contained for Python code — no external
OpenRA-RL-Training / OpenRA-RL checkouts needed (addresses the other
half of PR #12's concern, the correct way: faithful frozen copies, not
stubs).

Vendored verbatim (sha-verified identical to source) — only the
modules the bench actually uses:

openra_rl_training/
  scenario.py — ScenarioDefinition, VALID_ACTOR_TYPES
  training/reward_funcs.py — DEFAULT_REWARD_WEIGHTS
  training/rust_env_pool.py — RustEnvPool
  training/minimap_renderer.py — render_minimap (terrain minimap)
openra_env/
  game_data.py — RA_UNITS / RA_BUILDINGS

openra_env/__init__.py deliberately does NOT pull client/models — the
only reference to those is example code inside a docstring; the bench
never imports the legacy gRPC client at runtime.

requirements.txt: add the now-needed deps (pydantic, pyyaml, pillow,
numpy, matplotlib). VENDOR.md documents provenance + the re-vendor
procedure. The Rust engine (openra_train) still builds from OpenRA-Rust
via maturin — a compiled extension can't be vendored as source.

Verified: the vendored packages resolve to the in-repo copies (not the
sibling checkouts) and DEFAULT_REWARD_WEIGHTS.outcome == 0.5 (the real
value; PR #12's stub had a wrong 0.2).
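
The resolution check described above can be sketched as follows. This is a minimal illustration, not the check actually run for this commit: `module_is_in_repo` is a hypothetical helper name, and it assumes Python 3.9+ for `Path.is_relative_to`.

```python
import importlib
from pathlib import Path


def module_is_in_repo(module_name: str, repo_root: Path) -> bool:
    """Return True if `module_name` is imported from a file under `repo_root`.

    Hypothetical helper: it only inspects where the import machinery
    found the module, not the module's contents.
    """
    module = importlib.import_module(module_name)
    module_file = getattr(module, "__file__", None)
    if module_file is None:
        # Built-in modules and some namespace packages have no file.
        return False
    return Path(module_file).resolve().is_relative_to(repo_root.resolve())
```

Run from the bench root, a call such as `module_is_in_repo("openra_rl_training", Path.cwd())` would be expected to return True once the vendored copy shadows any sibling checkout. The same session could then compare `DEFAULT_REWARD_WEIGHTS.outcome` against 0.5, the value quoted in the message above.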

VENDOR.md ADDED
@@ -0,0 +1,25 @@
+ # Vendored dependencies
+
+ OpenRA-Bench vendors a small, faithful subset of two sibling repos so the
+ benchmark runs without external source checkouts, and so the evaluation
+ stack is frozen for reproducibility.
+
+ ## `openra_rl_training/` — from OpenRA-RL-Training
+ - `scenario.py` — `ScenarioDefinition`, `VALID_ACTOR_TYPES`
+ - `training/reward_funcs.py` — `DEFAULT_REWARD_WEIGHTS` (composite scorer)
+ - `training/rust_env_pool.py` — `RustEnvPool` (wraps the engine)
+ - `training/minimap_renderer.py` — `render_minimap` (terrain minimap)
+
+ ## `openra_env/` — from OpenRA-RL
+ - `game_data.py` — `RA_UNITS` / `RA_BUILDINGS` (consumed by `scenario.py`)
+
+ These are **verbatim copies** — do not hand-edit. To update: re-copy from
+ the source repos and re-run the full suite (`pytest tests/`).
+
+ ## NOT vendored
+ - **`openra_train`** — the Rust engine, a compiled extension. Build it
+   from the OpenRA-Rust repo: `maturin develop --release`. It cannot be
+   vendored as source.
+ - **`openra_env.client` / `openra_env.models`** — the legacy gRPC client.
+   The bench never imports them at runtime (the only reference is example
+   code inside a docstring), so they are intentionally left out.
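
A sketch of how the "verbatim copy" claim in VENDOR.md could be re-checked after a re-vendor, by comparing SHA-256 digests of each source file against its vendored counterpart. The names `sha256_of` and `find_drift` are hypothetical and not part of this commit, and the directory layout passed in is an assumption.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path) -> str:
    """Hex SHA-256 digest of a file's raw bytes."""
    return hashlib.sha256(path.read_bytes()).hexdigest()


def find_drift(source_root: Path, vendored_root: Path,
               relative_paths: list[str]) -> list[str]:
    """Return the relative paths whose vendored copy is missing or differs.

    Hypothetical helper: an empty result means every listed file is
    byte-identical between the two trees.
    """
    drifted = []
    for rel in relative_paths:
        src = source_root / rel
        dst = vendored_root / rel
        if not dst.exists() or sha256_of(src) != sha256_of(dst):
            drifted.append(rel)
    return drifted
```

Under these assumptions, a re-vendor would re-copy the listed modules, confirm `find_drift(...)` returns an empty list, and then run `pytest tests/` as VENDOR.md prescribes.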
openra_env/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """Vendored subset of OpenRA-RL's `openra_env` — see VENDOR.md.
+
+ Only `game_data` (RA_UNITS / RA_BUILDINGS, consumed by the vendored
+ `openra_rl_training.scenario`) is vendored. The gRPC `client` / `models`
+ are deliberately NOT vendored — the bench never imports them at runtime
+ (the one reference is example code inside a docstring).
+ """
openra_env/game_data.py ADDED
@@ -0,0 +1,984 @@
+ """Static Red Alert mod data for game knowledge tools.
+
+ Provides unit stats, building stats, tech tree, and faction information
+ extracted from OpenRA Red Alert mod rules. This gives an LLM agent the same
+ reference knowledge a human player would have from experience.
+ """
+
+ from typing import Optional
+
+
+ # ─── Unit Data ────────────────────────────────────────────────────────────────
+
+ RA_UNITS: dict[str, dict] = {
+     # Infantry
+     "e1": {
+         "name": "Rifle Infantry",
+         "category": "infantry",
+         "cost": 100,
+         "hp": 5000,
+         "speed": 56,
+         "armor": "none",
+         "side": "both",
+         "prerequisites": ["barr|tent"],
+         "description": "Basic infantry unit. Cheap and fast to produce.",
+     },
+     "e2": {
+         "name": "Grenadier",
+         "category": "infantry",
+         "cost": 150,
+         "hp": 5000,
+         "speed": 56,
+         "armor": "none",
+         "side": "both",
+         "prerequisites": ["barr|tent"],
+         "description": "Anti-structure infantry. Grenades deal area damage.",
+     },
+     "e3": {
+         "name": "Rocket Soldier",
+         "category": "infantry",
+         "cost": 300,
+         "hp": 4500,
+         "speed": 56,
+         "armor": "none",
+         "side": "both",
+         "prerequisites": ["barr|tent"],
+         "description": "Anti-armor and anti-air infantry.",
+     },
+     "e4": {
+         "name": "Flamethrower",
+         "category": "infantry",
+         "cost": 300,
+         "hp": 4000,
+         "speed": 56,
+         "armor": "none",
+         "side": "soviet",
+         "prerequisites": ["barr", "ftur"],
+         "description": "Short-range anti-infantry/structure. Soviet only.",
+     },
+     "e6": {
+         "name": "Engineer",
+         "category": "infantry",
+         "cost": 400,
+         "hp": 4000,
+         "speed": 56,
+         "armor": "none",
+         "side": "both",
+         "prerequisites": ["barr|tent"],
+         "description": "Captures enemy buildings. Cannot attack.",
+     },
+     "e7": {
+         "name": "Tanya",
+         "category": "infantry",
+         "cost": 1800,
+         "hp": 10000,
+         "speed": 68,
+         "armor": "none",
+         "side": "allied",
+         "prerequisites": ["tent", "atek"],
+         "build_limit": 1,
+         "description": "Elite commando. Destroys buildings with C4, kills infantry instantly. Allied only.",
+     },
+     "medi": {
+         "name": "Medic",
+         "category": "infantry",
+         "cost": 200,
+         "hp": 6000,
+         "speed": 49,
+         "armor": "none",
+         "side": "allied",
+         "prerequisites": ["tent"],
+         "description": "Heals nearby infantry. Cannot attack.",
+     },
+     "mech": {
+         "name": "Mechanic",
+         "category": "infantry",
+         "cost": 500,
+         "hp": 8000,
+         "speed": 49,
+         "armor": "none",
+         "side": "allied",
+         "prerequisites": ["tent", "fix"],
+         "description": "Repairs nearby vehicles. Cannot attack.",
+     },
+     "spy": {
+         "name": "Spy",
+         "category": "infantry",
+         "cost": 500,
+         "hp": 2500,
+         "speed": 56,
+         "armor": "none",
+         "side": "allied",
+         "prerequisites": ["tent", "dome"],
+         "description": "Disguises as enemy infantry. Infiltrates buildings for bonuses.",
+     },
+     "thf": {
+         "name": "Thief",
+         "category": "infantry",
+         "cost": 500,
+         "hp": 5000,
+         "speed": 68,
+         "armor": "none",
+         "side": "allied",
+         "prerequisites": ["tent", "dome"],
+         "description": "Steals credits from enemy refineries.",
+     },
+     "shok": {
+         "name": "Shock Trooper",
+         "category": "infantry",
+         "cost": 350,
+         "hp": 5000,
+         "speed": 49,
+         "armor": "none",
+         "side": "soviet",
+         "prerequisites": ["barr", "stek", "tsla"],
+         "description": "Tesla infantry. High damage vs all targets. Soviet only.",
+     },
+     "dog": {
+         "name": "Attack Dog",
+         "category": "infantry",
+         "cost": 200,
+         "hp": 2000,
+         "speed": 99,
+         "armor": "none",
+         "side": "soviet",
+         "prerequisites": ["kenn"],
+         "description": "Fast anti-infantry unit. Kills spies. Soviet only.",
+     },
+
+     # Vehicles
+     "1tnk": {
+         "name": "Light Tank",
+         "category": "vehicle",
+         "cost": 700,
+         "hp": 23000,
+         "speed": 113,
+         "armor": "heavy",
+         "side": "allied",
+         "prerequisites": ["weap"],
+         "description": "Fast medium tank. Good all-around. Allied only.",
+     },
+     "2tnk": {
+         "name": "Medium Tank",
+         "category": "vehicle",
+         "cost": 850,
+         "hp": 30000,
+         "speed": 72,
+         "armor": "heavy",
+         "side": "allied",
+         "prerequisites": ["weap", "fix"],
+         "description": "Main battle tank. Balanced stats. Allied only. Requires Repair Facility.",
+     },
+     "3tnk": {
+         "name": "Heavy Tank",
+         "category": "vehicle",
+         "cost": 1150,
+         "hp": 46000,
+         "speed": 64,
+         "armor": "heavy",
+         "side": "soviet",
+         "prerequisites": ["weap", "fix"],
+         "description": "Powerful main battle tank. Dual cannons. Soviet only. Requires Repair Facility.",
+     },
+     "4tnk": {
+         "name": "Mammoth Tank",
+         "category": "vehicle",
+         "cost": 2000,
+         "hp": 60000,
+         "speed": 43,
+         "armor": "heavy",
+         "side": "soviet",
+         "prerequisites": ["weap", "fix", "stek"],
+         "description": "Heaviest tank. Dual cannons + missiles. Self-healing. Soviet only.",
+     },
+     "v2rl": {
+         "name": "V2 Rocket Launcher",
+         "category": "vehicle",
+         "cost": 900,
+         "hp": 15000,
+         "speed": 72,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["weap", "dome"],
+         "description": "Long-range artillery. High damage, inaccurate. Soviet only.",
+     },
+     "jeep": {
+         "name": "Ranger",
+         "category": "vehicle",
+         "cost": 500,
+         "hp": 15000,
+         "speed": 164,
+         "armor": "light",
+         "side": "allied",
+         "prerequisites": ["weap"],
+         "description": "Fast scout vehicle with machine gun. Allied only.",
+     },
+     "apc": {
+         "name": "APC",
+         "category": "vehicle",
+         "cost": 850,
+         "hp": 20000,
+         "speed": 128,
+         "armor": "heavy",
+         "side": "soviet",
+         "prerequisites": ["weap"],
+         "description": "Armored troop transport. Carries 5 infantry. Soviet only.",
+     },
+     "arty": {
+         "name": "Artillery",
+         "category": "vehicle",
+         "cost": 850,
+         "hp": 7500,
+         "speed": 54,
+         "armor": "light",
+         "side": "allied",
+         "prerequisites": ["weap", "dome"],
+         "description": "Long-range siege weapon. Allied only.",
+     },
+     "harv": {
+         "name": "Ore Truck",
+         "category": "vehicle",
+         "cost": 1100,
+         "hp": 60000,
+         "speed": 72,
+         "armor": "heavy",
+         "side": "both",
+         "prerequisites": ["proc"],
+         "description": "Harvests ore and delivers to refinery. Free with refinery.",
+     },
+     "mcv": {
+         "name": "MCV",
+         "category": "vehicle",
+         "cost": 2000,
+         "hp": 60000,
+         "speed": 60,
+         "armor": "light",
+         "side": "both",
+         "prerequisites": ["weap", "fix"],
+         "description": "Deploys into Construction Yard. Mobile base.",
+     },
+     "ftrk": {
+         "name": "Flak Truck",
+         "category": "vehicle",
+         "cost": 600,
+         "hp": 15000,
+         "speed": 113,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["weap"],
+         "description": "Mobile anti-air unit. Soviet only.",
+     },
+     "mnly": {
+         "name": "Minelayer",
+         "category": "vehicle",
+         "cost": 800,
+         "hp": 30000,
+         "speed": 113,
+         "armor": "heavy",
+         "side": "both",
+         "prerequisites": ["weap", "fix"],
+         "description": "Lays anti-tank mines.",
+     },
+     "ttnk": {
+         "name": "Tesla Tank",
+         "category": "vehicle",
+         "cost": 1350,
+         "hp": 30000,
+         "speed": 92,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["weap", "stek", "tsla"],
+         "description": "Tesla weapon on tracks. Effective vs all targets. Soviet only.",
+     },
+     "ctnk": {
+         "name": "Chrono Tank",
+         "category": "vehicle",
+         "cost": 1350,
+         "hp": 20000,
+         "speed": 86,
+         "armor": "light",
+         "side": "allied",
+         "prerequisites": ["weap", "atek"],
+         "description": "Teleporting tank. Hit and run tactics. Allied only.",
+     },
+     "stnk": {
+         "name": "Phase Transport",
+         "category": "vehicle",
+         "cost": 1000,
+         "hp": 11000,
+         "speed": 128,
+         "armor": "light",
+         "side": "allied",
+         "prerequisites": ["weap", "atek"],
+         "description": "Cloaked APC. Invisible when not firing. Allied only.",
+     },
+     "qtnk": {
+         "name": "MAD Tank",
+         "category": "vehicle",
+         "cost": 2000,
+         "hp": 22000,
+         "speed": 46,
+         "armor": "heavy",
+         "side": "soviet",
+         "prerequisites": ["weap", "stek"],
+         "description": "Deploys seismic charge, destroying self and nearby vehicles. Soviet only.",
+     },
+     "dtrk": {
+         "name": "Demolition Truck",
+         "category": "vehicle",
+         "cost": 2500,
+         "hp": 11000,
+         "speed": 113,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["weap", "stek"],
+         "description": "Suicide vehicle. Massive area nuclear explosion on death. Soviet only.",
+     },
+     "mgg": {
+         "name": "Mobile Gap Generator",
+         "category": "vehicle",
+         "cost": 1000,
+         "hp": 11000,
+         "speed": 72,
+         "armor": "heavy",
+         "side": "allied",
+         "prerequisites": ["weap", "atek"],
+         "description": "Creates mobile shroud area. Allied only.",
+     },
+     "mrj": {
+         "name": "Mobile Radar Jammer",
+         "category": "vehicle",
+         "cost": 1000,
+         "hp": 11000,
+         "speed": 68,
+         "armor": "heavy",
+         "side": "allied",
+         "prerequisites": ["weap", "atek"],
+         "description": "Jams enemy radar in area. Allied only.",
+     },
+     "truk": {
+         "name": "Supply Truck",
+         "category": "vehicle",
+         "cost": 500,
+         "hp": 11000,
+         "speed": 113,
+         "armor": "light",
+         "side": "both",
+         "prerequisites": ["weap"],
+         "description": "Delivers cash when reaching allied structures.",
+     },
+
+     # Aircraft
+     "heli": {
+         "name": "Longbow",
+         "category": "aircraft",
+         "cost": 2000,
+         "hp": 12000,
+         "speed": 149,
+         "armor": "light",
+         "side": "allied",
+         "prerequisites": ["hpad"],
+         "description": "Anti-armor helicopter with missiles. Allied only.",
+     },
+     "hind": {
+         "name": "Hind",
+         "category": "aircraft",
+         "cost": 1500,
+         "hp": 12000,
+         "speed": 112,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["afld"],
+         "description": "Anti-ground attack helicopter. Soviet only.",
+     },
+     "mh60": {
+         "name": "Black Hawk",
+         "category": "aircraft",
+         "cost": 1500,
+         "hp": 12000,
+         "speed": 112,
+         "armor": "light",
+         "side": "allied",
+         "prerequisites": ["hpad"],
+         "description": "Transport/attack helicopter. Allied only.",
+     },
+     "tran": {
+         "name": "Chinook",
+         "category": "aircraft",
+         "cost": 900,
+         "hp": 14000,
+         "speed": 128,
+         "armor": "light",
+         "side": "both",
+         "prerequisites": ["hpad|afld"],
+         "description": "Transport helicopter. Carries 5 infantry.",
+     },
+     "yak": {
+         "name": "Yak",
+         "category": "aircraft",
+         "cost": 1350,
+         "hp": 6000,
+         "speed": 178,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["afld"],
+         "description": "Fast anti-infantry attack plane. Soviet only.",
+     },
+     "mig": {
+         "name": "MiG",
+         "category": "aircraft",
+         "cost": 2000,
+         "hp": 8000,
+         "speed": 223,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["afld", "stek"],
+         "description": "Anti-structure/armor attack plane with missiles. Soviet only.",
+     },
+
+     # Ships
+     "ss": {
+         "name": "Submarine",
+         "category": "ship",
+         "cost": 950,
+         "hp": 25000,
+         "speed": 78,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["spen"],
+         "description": "Invisible anti-ship unit. Soviet only.",
+     },
+     "dd": {
+         "name": "Destroyer",
+         "category": "ship",
+         "cost": 1000,
+         "hp": 40000,
+         "speed": 92,
+         "armor": "heavy",
+         "side": "allied",
+         "prerequisites": ["syrd", "dome"],
+         "description": "Multi-role warship. Anti-sub, anti-air, anti-surface. Allied only.",
+     },
+     "ca": {
+         "name": "Cruiser",
+         "category": "ship",
+         "cost": 2400,
+         "hp": 80000,
+         "speed": 44,
+         "armor": "heavy",
+         "side": "allied",
+         "prerequisites": ["syrd", "atek"],
+         "description": "Heavy bombardment ship. Long range. Allied only.",
+     },
+     "pt": {
+         "name": "Gunboat",
+         "category": "ship",
+         "cost": 500,
+         "hp": 20000,
+         "speed": 142,
+         "armor": "heavy",
+         "side": "both",
+         "prerequisites": ["syrd|spen"],
+         "description": "Fast patrol boat.",
+     },
+     "lst": {
+         "name": "Transport",
+         "category": "ship",
+         "cost": 500,
+         "hp": 40000,
+         "speed": 115,
+         "armor": "heavy",
+         "side": "both",
+         "prerequisites": ["syrd|spen"],
+         "description": "Naval transport. Carries vehicles and infantry.",
+     },
+     "msub": {
+         "name": "Missile Submarine",
+         "category": "ship",
+         "cost": 2000,
+         "hp": 40000,
+         "speed": 44,
+         "armor": "light",
+         "side": "soviet",
+         "prerequisites": ["spen", "stek"],
+         "description": "Long-range missile submarine. Soviet only.",
+     },
+ }
+
+
+ # ─── Building Data ────────────────────────────────────────────────────────────
+
+ RA_BUILDINGS: dict[str, dict] = {
+     "fact": {
+         "name": "Construction Yard",
+         "cost": 2000,
+         "hp": 150000,
+         "power": 0,
+         "side": "both",
+         "prerequisites": [],
+         "produces": ["Building", "Defense"],
+         "description": "Primary base structure. Required to build other structures.",
+     },
+     "powr": {
+         "name": "Power Plant",
+         "cost": 300,
+         "hp": 40000,
+         "power": 100,
+         "side": "both",
+         "prerequisites": [],
+         "produces": [],
+         "description": "Basic power supply. Most structures need power to function.",
+     },
+     "apwr": {
+         "name": "Advanced Power Plant",
+         "cost": 500,
+         "hp": 70000,
+         "power": 200,
+         "side": "both",
+         "prerequisites": ["dome"],
+         "produces": [],
+         "description": "Double power output. Requires radar dome tech.",
+     },
+     "barr": {
+         "name": "Soviet Barracks",
+         "cost": 500,
+         "hp": 60000,
+         "power": -20,
+         "side": "soviet",
+         "prerequisites": ["powr"],
+         "produces": ["Infantry"],
+         "description": "Soviet infantry production. Required for all Soviet infantry.",
+     },
+     "tent": {
+         "name": "Allied Barracks",
+         "cost": 500,
+         "hp": 60000,
+         "power": -20,
+         "side": "allied",
+         "prerequisites": ["powr"],
+         "produces": ["Infantry"],
+         "description": "Allied infantry production. Required for all Allied infantry.",
+     },
+     "proc": {
+         "name": "Ore Refinery",
+         "cost": 1400,
+         "hp": 90000,
+         "power": -30,
+         "side": "both",
+         "prerequisites": ["powr"],
+         "produces": [],
+         "description": "Processes ore into credits. Comes with a free Ore Truck.",
+     },
+     "weap": {
+         "name": "War Factory",
+         "cost": 2000,
+         "hp": 150000,
+         "power": -30,
+         "side": "both",
+         "prerequisites": ["proc"],
+         "produces": ["Vehicle"],
+         "description": "Vehicle production facility. Required for all vehicles.",
+     },
+     "dome": {
+         "name": "Radar Dome",
+         "cost": 1500,
+         "hp": 100000,
+         "power": -40,
+         "side": "both",
+         "prerequisites": ["proc"],
+         "produces": [],
+         "description": "Provides minimap radar. Unlocks advanced tech.",
+     },
+     "fix": {
+         "name": "Service Depot",
+         "cost": 1200,
+         "hp": 80000,
+         "power": -30,
+         "side": "both",
+         "prerequisites": ["weap"],
+         "produces": [],
+         "description": "Repairs vehicles. Unlocks MCV and Minelayer.",
+     },
+     "atek": {
+         "name": "Allied Tech Center",
+         "cost": 1500,
+         "hp": 60000,
+         "power": -200,
+         "side": "allied",
+         "prerequisites": ["dome", "weap"],
+         "produces": [],
+         "description": "Unlocks advanced Allied units. GPS satellite.",
+     },
+     "stek": {
+         "name": "Soviet Tech Center",
+         "cost": 1500,
+         "hp": 80000,
+         "power": -100,
+         "side": "soviet",
+         "prerequisites": ["dome", "weap"],
+         "produces": [],
+         "description": "Unlocks advanced Soviet units.",
+     },
+     "hpad": {
+         "name": "Helipad",
+         "cost": 500,
+         "hp": 80000,
+         "power": -10,
+         "side": "allied",
+         "prerequisites": ["dome"],
+         "produces": ["Aircraft"],
+         "description": "Allied aircraft production. Rearming pad.",
+     },
+     "afld": {
+         "name": "Airfield",
+         "cost": 500,
+         "hp": 100000,
+         "power": -20,
+         "side": "soviet",
+         "prerequisites": ["dome"],
+         "produces": ["Aircraft"],
+         "description": "Soviet aircraft production. Rearming strip.",
+     },
+     "spen": {
+         "name": "Sub Pen",
+         "cost": 800,
+         "hp": 100000,
+         "power": -20,
+         "side": "soviet",
+         "prerequisites": ["powr"],
+         "produces": ["Ship"],
+         "terrain": "water",
+         "description": "Soviet naval production. Repairs ships. REQUIRES WATER — cannot build on land maps.",
+     },
+     "syrd": {
+         "name": "Naval Yard",
+         "cost": 1000,
+         "hp": 100000,
+         "power": -20,
+         "side": "allied",
+         "prerequisites": ["powr"],
+         "produces": ["Ship"],
+         "terrain": "water",
+         "description": "Allied naval production. Repairs ships. REQUIRES WATER — cannot build on land maps.",
+     },
+     "silo": {
+         "name": "Ore Silo",
+         "cost": 150,
+         "hp": 30000,
+         "power": -10,
+         "side": "both",
+         "prerequisites": ["proc"],
+         "produces": [],
+         "description": "Additional ore storage capacity.",
+     },
+     "kenn": {
+         "name": "Kennel",
+         "cost": 200,
+         "hp": 30000,
+         "power": -10,
+         "side": "soviet",
+         "prerequisites": ["powr"],
+         "produces": ["Infantry"],
+         "description": "Produces attack dogs. Soviet only.",
+     },
+
+     # Defenses
+     "pbox": {
+         "name": "Pillbox",
+         "cost": 600,
+         "hp": 40000,
+         "power": 0,
+         "side": "allied",
+         "prerequisites": ["tent"],
+         "produces": [],
+         "description": "Anti-infantry defense turret. Allied only.",
+     },
+     "hbox": {
+         "name": "Camo Pillbox",
+         "cost": 750,
+         "hp": 40000,
+         "power": 0,
+         "side": "allied",
+         "prerequisites": ["tent"],
+         "produces": [],
+         "description": "Hidden anti-infantry defense. Allied only.",
+     },
+     "gun": {
+         "name": "Turret",
+         "cost": 800,
+         "hp": 40000,
+         "power": -20,
+         "side": "allied",
+         "prerequisites": ["weap"],
+         "produces": [],
+         "description": "Anti-armor defense turret. Allied only.",
+     },
+     "ftur": {
+         "name": "Flame Tower",
+         "cost": 600,
+         "hp": 40000,
+         "power": -20,
+         "side": "soviet",
+         "prerequisites": ["barr"],
+         "produces": [],
+         "description": "Short-range anti-infantry defense. Soviet only.",
+     },
+     "tsla": {
+         "name": "Tesla Coil",
+         "cost": 1200,
+         "hp": 40000,
+         "power": -75,
+         "side": "soviet",
+         "prerequisites": ["weap"],
+         "produces": [],
+         "description": "Powerful anti-ground defense. High power cost. Soviet only.",
+     },
+     "agun": {
+         "name": "AA Gun",
+         "cost": 800,
+         "hp": 40000,
+         "power": -50,
+         "side": "allied",
+         "prerequisites": ["dome"],
+         "produces": [],
+         "description": "Anti-air defense turret. Allied only.",
+     },
+     "sam": {
+         "name": "SAM Site",
+         "cost": 700,
+         "hp": 40000,
+         "power": -20,
+         "side": "soviet",
+         "prerequisites": ["dome"],
+         "produces": [],
+         "description": "Anti-air missile defense. Soviet only.",
+     },
+     "gap": {
+         "name": "Gap Generator",
+         "cost": 800,
+         "hp": 50000,
+         "power": -60,
+         "side": "allied",
+         "prerequisites": ["atek"],
+         "produces": [],
+         "description": "Creates shroud area over your base. Allied only.",
+     },
+
+     # Superweapons
+     "iron": {
+         "name": "Iron Curtain",
+         "cost": 2000,
+         "hp": 100000,
+         "power": -200,
+         "side": "soviet",
+         "prerequisites": ["stek"],
+         "produces": [],
+         "build_limit": 1,
+         "description": "Superweapon: Makes one unit/building invulnerable temporarily.",
+     },
+     "pdox": {
+         "name": "Chronosphere",
+         "cost": 1500,
+         "hp": 100000,
+         "power": -200,
+         "side": "allied",
+         "prerequisites": ["atek"],
+         "produces": [],
+         "build_limit": 1,
+         "description": "Superweapon: Teleports units across the map.",
+     },
+     "mslo": {
+         "name": "Missile Silo",
+         "cost": 2500,
+         "hp": 100000,
+         "power": -150,
+         "side": "soviet",
+         "prerequisites": ["stek"],
+         "produces": [],
+         "build_limit": 1,
+         "description": "Superweapon: Launches nuclear missile at target location.",
+     },
+ }
+
+
+ # ─── Tech Tree ────────────────────────────────────────────────────────────────
+
+ RA_TECH_TREE: dict[str, list[str]] = {
+     "soviet": [
+         "powr",  # Power Plant (base)
+         "barr",  # Barracks → infantry (requires powr)
+         "kenn",  # Kennel → dogs (requires powr)
+         "proc",  # Ore Refinery (requires powr)
+         "weap",  # War Factory (requires proc)
+         "spen",  # Sub Pen (requires powr, needs water)
+         "dome",  # Radar Dome (requires proc)
+         "fix",  # Service Depot (requires weap)
+         "afld",  # Airfield (requires dome)
+         "stek",  # Tech Center (requires dome + weap)
+         "tsla",  # Tesla Coil (requires weap)
+         "sam",  # SAM Site (requires dome)
+         "ftur",  # Flame Tower (requires barr)
+         "iron",  # Iron Curtain (requires stek)
+         "mslo",  # Missile Silo (requires stek)
+     ],
+     "allied": [
+         "powr",  # Power Plant (base)
+         "tent",  # Barracks → infantry (requires powr)
+         "proc",  # Ore Refinery (requires powr)
+         "weap",  # War Factory (requires proc)
+         "syrd",  # Naval Yard (requires powr, needs water)
+         "dome",  # Radar Dome (requires proc)
+         "fix",  # Service Depot (requires weap)
+         "hpad",  # Helipad (requires dome)
+         "atek",  # Tech Center (requires dome + weap)
+         "gun",  # Turret (requires weap)
+         "pbox",  # Pillbox (requires tent)
+         "agun",  # AA Gun (requires dome)
+         "gap",  # Gap Generator (requires atek)
+         "pdox",  # Chronosphere (requires atek)
+     ],
+ }
+
+
+ # ─── Faction Data ─────────────────────────────────────────────────────────────
+
+ RA_FACTIONS: dict[str, dict] = {
+     "england": {
+         "side": "allied",
+         "display_name": "England",
+         "unique_units": [],
+         "description": "Standard Allied faction.",
+     },
+     "france": {
+         "side": "allied",
+         "display_name": "France",
+         "unique_units": ["stnk"],
+         "description": "Allied faction with Phase Transport (cloaked APC).",
+     },
+     "germany": {
+         "side": "allied",
+         "display_name": "Germany",
+         "unique_units": ["ctnk"],
+         "description": "Allied faction with Chrono Tank (teleporting tank).",
+     },
+     "russia": {
+         "side": "soviet",
+         "display_name": "Russia",
+         "unique_units": ["ttnk"],
+         "description": "Soviet faction with Tesla Tank.",
+     },
+     "ukraine": {
+         "side": "soviet",
+         "display_name": "Ukraine",
+         "unique_units": ["dtrk"],
+         "description": "Soviet faction with Demolition Truck (nuclear suicide vehicle).",
+     },
+ }
+
+
879
+ # ─── Query Functions ──────────────────────────────────────────────────────────
880
+
881
+
882
+ def get_unit_stats(unit_type: str) -> Optional[dict]:
883
+ """Get stats for a unit type. Returns None if not found."""
884
+ return RA_UNITS.get(unit_type.lower())
885
+
886
+
887
+ def get_building_stats(building_type: str) -> Optional[dict]:
888
+ """Get stats for a building type. Returns None if not found."""
889
+ return RA_BUILDINGS.get(building_type.lower())
890
+
891
+
892
+ def get_tech_tree(faction: Optional[str] = None) -> dict:
893
+ """Get the tech tree build order.
894
+
895
+ Args:
896
+ faction: Faction name (e.g., 'russia') or side ('allied', 'soviet').
897
+ If None, returns both sides.
898
+ """
899
+ if faction is None:
900
+ return RA_TECH_TREE
901
+
902
+ # Map faction to side
903
+ side = faction.lower()
904
+ if side in RA_FACTIONS:
905
+ side = RA_FACTIONS[side]["side"]
906
+
907
+ if side in RA_TECH_TREE:
908
+ return {side: RA_TECH_TREE[side]}
909
+
910
+ return {}
911
+
912
+
913
+ def get_faction_info(faction: str) -> Optional[dict]:
914
+ """Get faction info including available units and buildings."""
915
+ faction = faction.lower()
916
+ info = RA_FACTIONS.get(faction)
917
+ if info is None:
918
+ return None
919
+
920
+ side = info["side"]
921
+
922
+ # Collect units available to this faction
923
+ available_units = []
924
+ for unit_type, data in RA_UNITS.items():
925
+ unit_side = data.get("side", "")
926
+ if unit_side == "both" or unit_side == side:
927
+ available_units.append(unit_type)
928
+
929
+ # Add faction-unique units
930
+ for u in info.get("unique_units", []):
931
+ if u not in available_units and u in RA_UNITS:
932
+ available_units.append(u)
933
+
934
+ # Collect buildings
935
+ available_buildings = []
936
+ for bldg_type, data in RA_BUILDINGS.items():
937
+ bldg_side = data.get("side", "")
938
+ if bldg_side == "both" or bldg_side == side:
939
+ available_buildings.append(bldg_type)
940
+
941
+ return {
942
+ **info,
943
+ "faction": faction,
944
+ "available_units": sorted(available_units),
945
+ "available_buildings": sorted(available_buildings),
946
+ }
947
+
948
+
949
+ def get_all_unit_types() -> list[str]:
950
+ """Get all available unit type names."""
951
+ return sorted(RA_UNITS.keys())
952
+
953
+
954
+ def get_all_building_types() -> list[str]:
955
+ """Get all available building type names."""
956
+ return sorted(RA_BUILDINGS.keys())
957
+
958
+
959
+ def get_all_units_for_side(side: str) -> dict[str, dict]:
960
+ """Get all units available to a side ('allied' or 'soviet') with full stats.
961
+
962
+ Returns dict keyed by unit type name, each value is the full stats dict.
963
+ Includes units with side='both' plus units specific to the given side.
964
+ """
965
+ side = side.lower()
966
+ return {
967
+ utype: dict(data)
968
+ for utype, data in RA_UNITS.items()
969
+ if data.get("side") in (side, "both")
970
+ }
971
+
972
+
973
+ def get_all_buildings_for_side(side: str) -> dict[str, dict]:
974
+ """Get all buildings available to a side ('allied' or 'soviet') with full stats.
975
+
976
+ Returns dict keyed by building type name, each value is the full stats dict.
977
+ Includes buildings with side='both' plus buildings specific to the given side.
978
+ """
979
+ side = side.lower()
980
+ return {
981
+ btype: dict(data)
982
+ for btype, data in RA_BUILDINGS.items()
983
+ if data.get("side") in (side, "both")
984
+ }
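The side filter used by `get_all_units_for_side` and `get_all_buildings_for_side` can be sketched in isolation. This is a minimal illustration, not the vendored module: `TOY_UNITS` and `units_for_side` are hypothetical names, and the stats in the toy table are placeholders rather than real game data.

```python
# Toy stand-in for RA_UNITS; entries and costs are illustrative placeholders.
TOY_UNITS = {
    "e1": {"side": "both", "cost": 100},
    "1tnk": {"side": "allied", "cost": 700},
    "3tnk": {"side": "soviet", "cost": 1150},
}


def units_for_side(units: dict[str, dict], side: str) -> dict[str, dict]:
    """Keep entries whose side matches `side` or is 'both' (case-insensitive)."""
    side = side.lower()
    return {
        name: dict(data)  # shallow copy, so callers cannot mutate the table
        for name, data in units.items()
        if data.get("side") in (side, "both")
    }


allied = units_for_side(TOY_UNITS, "Allied")
print(sorted(allied))
```

Because each value is copied with `dict(data)`, editing a returned entry leaves the source table untouched, which matters for a module that is meant to stay frozen.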
openra_rl_training/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """Vendored subset of OpenRA-RL-Training — see VENDOR.md.
2
+
3
+ Faithful, frozen copies of exactly the modules OpenRA-Bench needs from
4
+ the `openra_rl_training` package, so the bench runs without an external
5
+ OpenRA-RL-Training checkout. Do NOT hand-edit the vendored modules —
6
+ re-vendor from source (and re-run the full suite) if they must change.
7
+ """
8
+
9
+ __version__ = "0.1.0-vendored"
openra_rl_training/scenario.py ADDED
@@ -0,0 +1,401 @@
1
+ """Pydantic models for scenario and curriculum YAML definitions.
2
+
3
+ Scenarios define custom starting conditions for RL training episodes:
4
+ units, positions, stances, factions, and termination conditions.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Literal, Optional, Union
12
+
13
+ import yaml
14
+ from openra_env.game_data import RA_BUILDINGS, RA_UNITS
15
+ from pydantic import BaseModel, Field, field_validator, model_validator
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # All valid actor types that can be placed on maps
20
+ VALID_ACTOR_TYPES = set(RA_UNITS.keys()) | set(RA_BUILDINGS.keys())
21
+
22
+ # Unit stances matching OpenRA's UnitStance enum
23
+ STANCE_HOLD_FIRE = 0
24
+ STANCE_RETURN_FIRE = 1
25
+ STANCE_DEFEND = 2
26
+ STANCE_ATTACK_ANYTHING = 3
27
+
28
+ STANCE_NAMES = {
29
+ STANCE_HOLD_FIRE: "HoldFire",
30
+ STANCE_RETURN_FIRE: "ReturnFire",
31
+ STANCE_DEFEND: "Defend",
32
+ STANCE_ATTACK_ANYTHING: "AttackAnything",
33
+ }
34
+
35
+
36
+ # ── Randomization models ─────────────────────────────────────────────────────
37
+
38
+
39
+ class TypeFilter(BaseModel):
40
+ """Filter-based type randomization: pick a random unit matching criteria."""
41
+
42
+ category: str = Field(..., description="Unit category: infantry, vehicle, aircraft, ship")
43
+ side: str = Field(default="both", description="Faction filter: allied, soviet, both")
44
+ max_cost: Optional[int] = Field(default=None, description="Maximum unit cost")
45
+ min_cost: Optional[int] = Field(default=None, description="Minimum unit cost")
46
+ armor: Optional[str] = Field(default=None, description="Armor type: none, light, heavy")
47
+
48
+ @field_validator("category")
49
+ @classmethod
50
+ def validate_category(cls, v: str) -> str:
51
+ v = v.lower()
52
+ valid = {"infantry", "vehicle", "aircraft", "ship"}
53
+ if v not in valid:
54
+ raise ValueError(f"category must be one of {sorted(valid)}, got '{v}'")
55
+ return v
56
+
57
+ @field_validator("side")
58
+ @classmethod
59
+ def validate_side(cls, v: str) -> str:
60
+ v = v.lower()
61
+ valid = {"allied", "soviet", "both"}
62
+ if v not in valid:
63
+ raise ValueError(f"side must be one of {sorted(valid)}, got '{v}'")
64
+ return v
65
+
66
+ @field_validator("armor")
67
+ @classmethod
68
+ def validate_armor(cls, v: Optional[str]) -> Optional[str]:
69
+ if v is not None:
70
+ v = v.lower()
71
+ valid = {"none", "light", "heavy"}
72
+ if v not in valid:
73
+ raise ValueError(f"armor must be one of {sorted(valid)}, got '{v}'")
74
+ return v
75
+
76
+ def matching_types(self) -> list[str]:
77
+ """Return all RA_UNITS keys matching this filter."""
78
+ results = []
79
+ for utype, data in RA_UNITS.items():
80
+ if data.get("category") != self.category:
81
+ continue
82
+ unit_side = data.get("side", "both")
83
+ if self.side != "both" and unit_side not in (self.side, "both"):
84
+ continue
85
+ cost = data.get("cost", 0)
86
+ if self.max_cost is not None and cost > self.max_cost:
87
+ continue
88
+ if self.min_cost is not None and cost < self.min_cost:
89
+ continue
90
+ if self.armor is not None and data.get("armor") != self.armor:
91
+ continue
92
+ results.append(utype)
93
+ return sorted(results)
94
+
95
+
96
+ class PositionOffset(BaseModel):
97
+ """Offset-based position randomization: random within ±offset of base."""
98
+
99
+ base: tuple[int, int] = Field(..., description="Base position [x, y]")
100
+ offset: int = Field(..., description="Max offset in cells (applies to both x and y)")
101
+
102
+ @field_validator("offset")
103
+ @classmethod
104
+ def validate_offset(cls, v: int) -> int:
105
+ if v < 1 or v > 50:
106
+ raise ValueError(f"offset must be 1-50, got {v}")
107
+ return v
108
+
109
+
110
+ class HealthRange(BaseModel):
111
+ """Range-based health randomization."""
112
+
113
+ min: int = Field(default=1, description="Minimum health percentage")
114
+ max: int = Field(default=100, description="Maximum health percentage")
115
+
116
+ @model_validator(mode="after")
117
+ def validate_range(self) -> "HealthRange":
118
+ if self.min < 1 or self.max > 100:
119
+ raise ValueError(f"Health range must be 1-100, got {self.min}-{self.max}")
120
+ if self.min > self.max:
121
+ raise ValueError(f"min ({self.min}) must be <= max ({self.max})")
122
+ return self
123
+
124
+
125
+ class ActorRandomization(BaseModel):
126
+ """Per-field randomization options for an actor placement."""
127
+
128
+ type: Optional[Union[list[str], TypeFilter]] = Field(
129
+ default=None, description="Type alternatives: list of names or category filter"
130
+ )
131
+ position: Optional[Union[list[tuple[int, int]], PositionOffset]] = Field(
132
+ default=None, description="Position alternatives: preset list or offset from base"
133
+ )
134
+ stance: Optional[list[int]] = Field(default=None, description="Stance alternatives (0-3)")
135
+ health: Optional[HealthRange] = Field(default=None, description="Health range {min, max}")
136
+ facing: Optional[list[int]] = Field(default=None, description="Facing alternatives (0-1023)")
137
+
138
+ @field_validator("type")
139
+ @classmethod
140
+ def validate_type_alternatives(cls, v: Optional[Union[list[str], TypeFilter]]):
141
+ if isinstance(v, list):
142
+ if not v:
143
+ raise ValueError("type list must not be empty")
144
+ for t in v:
145
+ if t.lower() not in VALID_ACTOR_TYPES:
146
+ raise ValueError(f"Unknown actor type in randomize.type: '{t}'")
147
+ return v
148
+
149
+ @field_validator("position")
150
+ @classmethod
151
+ def validate_position_alternatives(
152
+ cls, v: Optional[Union[list[tuple[int, int]], PositionOffset]]
153
+ ):
154
+ if isinstance(v, list) and not v:
155
+ raise ValueError("position list must not be empty")
156
+ return v
157
+
158
+ @field_validator("stance")
159
+ @classmethod
160
+ def validate_stance_alternatives(cls, v: Optional[list[int]]):
161
+ if v is not None:
162
+ if not v:
163
+ raise ValueError("stance list must not be empty")
164
+ for s in v:
165
+ if s < 0 or s > 3:
166
+ raise ValueError(f"Stance must be 0-3, got {s}")
167
+ return v
168
+
169
+ @field_validator("facing")
170
+ @classmethod
171
+ def validate_facing_alternatives(cls, v: Optional[list[int]]):
172
+ if v is not None:
173
+ if not v:
174
+ raise ValueError("facing list must not be empty")
175
+ for f in v:
176
+ if f < 0 or f > 1023:
177
+ raise ValueError(f"Facing must be 0-1023, got {f}")
178
+ return v
179
+
180
+
181
+ # ── Core scenario models ─────────────────────────────────────────────────────
182
+
183
+
184
+ class ActorPlacement(BaseModel):
185
+ """A unit or building to spawn at game start."""
186
+
187
+ type: str = Field(..., description="Actor type (e.g., '2tnk', 'e1', 'fact')")
188
+ owner: Literal["agent", "enemy", "neutral"] = Field(
189
+ default="agent", description="Which player owns this actor"
190
+ )
191
+ position: tuple[int, int] = Field(..., description="Cell coordinates [x, y]")
192
+ stance: int = Field(
193
+ default=STANCE_ATTACK_ANYTHING,
194
+ description="0=HoldFire, 1=ReturnFire, 2=Defend, 3=AttackAnything",
195
+ )
196
+ health: int = Field(default=100, description="HP percentage 1-100")
197
+ facing: int = Field(default=-1, description="-1=auto, 0-1023 WAngle")
198
+ count: int = Field(default=1, description="Spawn N copies with auto-offset positions")
199
+ spawn_point: Optional[int] = Field(
200
+ default=None,
201
+ description="Spawn point group (0-N). If set, only included when this spawn point is selected. "
202
+ "None = always included (enemies, neutral).",
203
+ )
204
+ randomize: Optional[ActorRandomization] = Field(
205
+ default=None,
206
+ description="Per-field randomization options (resolved before map generation)",
207
+ )
208
+
209
+ @field_validator("type")
210
+ @classmethod
211
+ def validate_type(cls, v: str) -> str:
212
+ v = v.lower()
213
+ if v not in VALID_ACTOR_TYPES:
214
+ raise ValueError(
215
+ f"Unknown actor type '{v}'. "
216
+ f"Valid units: {sorted(RA_UNITS.keys())[:10]}... "
217
+ f"Valid buildings: {sorted(RA_BUILDINGS.keys())[:10]}..."
218
+ )
219
+ return v
220
+
221
+ @field_validator("stance")
222
+ @classmethod
223
+ def validate_stance(cls, v: int) -> int:
224
+ if v < 0 or v > 3:
225
+ raise ValueError(f"Stance must be 0-3, got {v}")
226
+ return v
227
+
228
+ @field_validator("health")
229
+ @classmethod
230
+ def validate_health(cls, v: int) -> int:
231
+ if v < 1 or v > 100:
232
+ raise ValueError(f"Health must be 1-100, got {v}")
233
+ return v
234
+
235
+ @field_validator("facing")
236
+ @classmethod
237
+ def validate_facing(cls, v: int) -> int:
238
+ if v != -1 and (v < 0 or v > 1023):
239
+ raise ValueError(f"Facing must be -1 (auto) or 0-1023, got {v}")
240
+ return v
241
+
242
+ @field_validator("count")
243
+ @classmethod
244
+ def validate_count(cls, v: int) -> int:
245
+ if v < 1 or v > 50:
246
+ raise ValueError(f"Count must be 1-50, got {v}")
247
+ return v
248
+
249
+ @property
250
+ def is_building(self) -> bool:
251
+ return self.type in RA_BUILDINGS
252
+
253
+
254
+ class PlayerSetup(BaseModel):
255
+ """Configuration for the agent player."""
256
+
257
+ faction: Literal["allies", "soviet", "random"] = Field(
258
+ default="random", description="Player faction"
259
+ )
260
+ cash: int = Field(default=0, description="Starting cash override")
261
+
262
+ @field_validator("cash")
263
+ @classmethod
264
+ def validate_cash(cls, v: int) -> int:
265
+ if v < 0:
266
+ raise ValueError(f"Cash must be non-negative, got {v}")
267
+ return v
268
+
269
+
270
+ class EnemySetup(PlayerSetup):
271
+ """Configuration for the enemy player."""
272
+
273
+ bot_type: str = Field(
274
+ default="", description="AI bot type (empty = no AI, stance-only behavior)"
275
+ )
276
+
277
+
278
+ class TerminationConfig(BaseModel):
279
+ """When to end a scenario episode."""
280
+
281
+ max_ticks: int = Field(default=5000, description="Tick limit (0 = unlimited)")
282
+ max_time: Optional[float] = Field(
283
+ default=None,
284
+ description="Time limit in seconds (overrides max_ticks). 25 ticks = 1 second.",
285
+ )
286
+ agent_units_killed: bool = Field(
287
+ default=True, description="End as 'lose' when all agent units destroyed"
288
+ )
289
+ enemy_units_killed: bool = Field(
290
+ default=True, description="End as 'win' when all enemy units/buildings destroyed"
291
+ )
292
+
293
+ @field_validator("max_ticks")
294
+ @classmethod
295
+ def validate_max_ticks(cls, v: int) -> int:
296
+ if v < 0:
297
+ raise ValueError(f"max_ticks must be non-negative, got {v}")
298
+ return v
299
+
300
+ @model_validator(mode="after")
301
+ def resolve_max_time(self) -> "TerminationConfig":
302
+ """Convert max_time (seconds) to max_ticks if specified."""
303
+ if self.max_time is not None:
304
+ self.max_ticks = int(self.max_time * 25)
305
+ return self
306
+
307
+
308
+ class ScenarioDefinition(BaseModel):
309
+ """Complete scenario definition loaded from YAML."""
310
+
311
+ name: str = Field(..., description="Scenario display name")
312
+ description: str = Field(default="", description="Human-readable description")
313
+ base_map: str = Field(default="singles.oramap", description="Base map filename for terrain")
314
+ agent: PlayerSetup = Field(default_factory=PlayerSetup)
315
+ enemy: EnemySetup = Field(default_factory=EnemySetup)
316
+ actors: list[ActorPlacement] = Field(..., description="Units/buildings to spawn")
317
+ termination: TerminationConfig = Field(default_factory=TerminationConfig)
318
+ reward: dict[str, float] = Field(default_factory=dict, description="Override reward weights")
319
+ reward_calibration: dict[str, float] = Field(
320
+ default_factory=dict,
321
+ description="Manual overrides for reward calibration constants (auto-computed if empty)",
322
+ )
323
+ tools: list[str] = Field(default_factory=list, description="Allowed tool names (empty = all)")
324
+ interrupts: dict[str, bool] = Field(
325
+ default_factory=dict,
326
+ description="Override interrupt signals: signal_name → enabled/disabled. All enabled by default.",
327
+ )
328
+ planning: bool = Field(default=False, description="Enable pre-game planning phase")
329
+ difficulty: int = Field(default=1, description="Difficulty level for ordering")
330
+ tags: list[str] = Field(default_factory=list, description="Tags for filtering")
331
+
332
+ @field_validator("tools")
333
+ @classmethod
334
+ def strip_internal_tools(cls, v: list[str]) -> list[str]:
335
+ """Remove internal-only tools that the LLM should never call directly."""
336
+ _INTERNAL_TOOLS = {"get_game_state", "surrender"}
337
+ return [t for t in v if t not in _INTERNAL_TOOLS]
338
+
339
+ @field_validator("actors")
340
+ @classmethod
341
+ def validate_actors_not_empty(cls, v: list[ActorPlacement]) -> list[ActorPlacement]:
342
+ if not v:
343
+ raise ValueError("Scenario must have at least one actor")
344
+ return v
345
+
346
+ @model_validator(mode="after")
347
+ def validate_has_agent_actor(self) -> "ScenarioDefinition":
348
+ agent_actors = [a for a in self.actors if a.owner == "agent"]
349
+ if not agent_actors:
350
+ raise ValueError("Scenario must have at least one agent-owned actor")
351
+ return self
352
+
353
+ @property
354
+ def agent_actors(self) -> list[ActorPlacement]:
355
+ return [a for a in self.actors if a.owner == "agent"]
356
+
357
+ @property
358
+ def enemy_actors(self) -> list[ActorPlacement]:
359
+ return [a for a in self.actors if a.owner == "enemy"]
360
+
361
+ @property
362
+ def neutral_actors(self) -> list[ActorPlacement]:
363
+ return [a for a in self.actors if a.owner == "neutral"]
364
+
365
+
366
+ def load_scenario(path: str | Path) -> ScenarioDefinition:
367
+ """Load a scenario definition from a YAML file.
368
+
369
+ Args:
370
+ path: Path to the scenario YAML file.
371
+
372
+ Returns:
373
+ Parsed and validated ScenarioDefinition.
374
+ """
375
+ path = Path(path)
376
+ if not path.exists():
377
+ raise FileNotFoundError(f"Scenario file not found: {path}")
378
+
379
+ with open(path) as f:
380
+ data = yaml.safe_load(f)
381
+
382
+ if data is None:
383
+ raise ValueError(f"Empty scenario file: {path}")
384
+
385
+ logger.info("Loading scenario '%s' from %s", data.get("name", "?"), path)
386
+ return ScenarioDefinition.model_validate(data)
387
+
388
+
389
+ def load_scenario_from_string(yaml_string: str) -> ScenarioDefinition:
390
+ """Load a scenario definition from a YAML string.
391
+
392
+ Args:
393
+ yaml_string: YAML content.
394
+
395
+ Returns:
396
+ Parsed and validated ScenarioDefinition.
397
+ """
398
+ data = yaml.safe_load(yaml_string)
399
+ if data is None:
400
+ raise ValueError("Empty scenario YAML")
401
+ return ScenarioDefinition.model_validate(data)
openra_rl_training/training/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Vendored subset of openra_rl_training.training — see VENDOR.md."""
openra_rl_training/training/minimap_renderer.py ADDED
@@ -0,0 +1,333 @@
1
+ """Render a vision minimap for the planning phase.
2
+
3
+ Produces a small PNG image (~448x222, ~96 vision tokens) showing:
4
+ - Actual terrain from the base map (map.png)
5
+ - Visibility layers: visible (bright), fog of war (dimmed), unexplored (dark)
6
+ - Own units (cyan circles), enemy units (red circles), enemy buildings (red squares)
7
+ - Coordinate grid and compact legend
8
+
9
+ The image is returned as a base64-encoded PNG for injection into the
10
+ OpenAI-compatible vision API (SGLang/vLLM).
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import base64
16
+ import io
17
+ import logging
18
+
19
+ import matplotlib
20
+ import numpy as np
21
+ from PIL import Image
22
+
23
+ matplotlib.use("Agg")
24
+ # Use Figure() OO API instead of pyplot — pyplot's global figure manager is
25
+ # NOT thread-safe, which prevents off-loading rendering from the event loop.
26
+ from matplotlib.figure import Figure
27
+ from matplotlib.lines import Line2D
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # Buildings from OpenRA game data
32
+ BUILDINGS = {
33
+ "fact", "powr", "apwr", "tent", "barr", "proc", "weap", "dome",
34
+ "fix", "hpad", "afld", "spen", "syrd", "pbox", "hbox", "gun",
35
+ "ftur", "tsla", "agun", "sam", "gap", "iron", "mslo", "atek", "stek",
36
+ "kenn", "silo",
37
+ }
38
+
39
+ # Visibility brightness multipliers
40
+ VIS_BRIGHT = 1.0 # Currently visible (unit line of sight)
41
+ VIS_FOG = 0.40 # Previously explored, no current vision
42
+ VIS_UNEXPLORED = 0.08 # Never seen
43
+
44
+ # Unit vision radius in cells
45
+ VISION_RADIUS = 10
46
+
47
+ # Supersampling: render at 2x, downsample with LANCZOS
48
+ RENDER_SCALE = 2
49
+ TARGET_WIDTH = 448
50
+
51
+
52
+ def _blur_2d(arr: np.ndarray, sigma: float = 1.5, size: int = 7) -> np.ndarray:
53
+ """Simple separable gaussian blur without scipy dependency."""
54
+ x = np.arange(size) - size // 2
55
+ k = np.exp(-x**2 / (2 * sigma**2))
56
+ k /= k.sum()
57
+ r = np.apply_along_axis(lambda row: np.convolve(row, k, mode="same"), 1, arr)
58
+ return np.apply_along_axis(lambda col: np.convolve(col, k, mode="same"), 0, r)
59
+
60
+
61
+ def _parse_ascii_minimap(
62
+ ascii_minimap: str, map_width: int, map_height: int
63
+ ) -> np.ndarray:
64
+ """Parse ASCII minimap to get explored mask at full map resolution.
65
+
66
+ Characters: # = unexplored, everything else = explored.
67
+ The ASCII grid is downsampled by scale = ceil(map_width / 28).
68
+
69
+ Returns:
70
+ Boolean array (map_height, map_width) — True = explored.
71
+ """
72
+ lines = [l for l in ascii_minimap.strip().split("\n") if l.strip()]
73
+ # Skip header lines (e.g. "Map (28x14, 1cell=4x4):")
74
+ grid_lines = []
75
+ for line in lines:
76
+ stripped = line.strip()
77
+ if stripped and all(c in "#.@!X~$B " for c in stripped):
78
+ grid_lines.append(stripped)
79
+
80
+ if not grid_lines:
81
+ return np.zeros((map_height, map_width), dtype=bool)
82
+
83
+ grid_h = len(grid_lines)
84
+ grid_w = max(len(l) for l in grid_lines)
85
+ scale_x = max(1, map_width // grid_w) if grid_w > 0 else 1
86
+ scale_y = max(1, map_height // grid_h) if grid_h > 0 else 1
87
+
88
+ explored = np.zeros((map_height, map_width), dtype=bool)
89
+ for gy, line in enumerate(grid_lines):
90
+ for gx, ch in enumerate(line):
91
+ if ch != "#":
92
+ # Mark the corresponding map cells as explored
93
+ y0 = gy * scale_y
94
+ x0 = gx * scale_x
95
+ y1 = min(y0 + scale_y, map_height)
96
+ x1 = min(x0 + scale_x, map_width)
97
+ explored[y0:y1, x0:x1] = True
98
+
99
+ return explored
100
+
101
+
102
+ def _compute_visible_mask(
103
+ own_units: list[dict], map_width: int, map_height: int
104
+ ) -> np.ndarray:
105
+ """Compute currently visible cells from own unit positions."""
106
+ visible = np.zeros((map_height, map_width), dtype=bool)
107
+ r = VISION_RADIUS
108
+ for u in own_units:
109
+ cx = u.get("cell_x", 0)
110
+ cy = u.get("cell_y", 0)
111
+ y_lo = max(0, cy - r)
112
+ y_hi = min(map_height, cy + r + 1)
113
+ x_lo = max(0, cx - r)
114
+ x_hi = min(map_width, cx + r + 1)
115
+ for y in range(y_lo, y_hi):
116
+ for x in range(x_lo, x_hi):
117
+ if (x - cx) ** 2 + (y - cy) ** 2 <= r * r:
118
+ visible[y, x] = True
119
+ return visible
120
+
121
+
122
+ def render_minimap(
123
+ terrain_png: bytes,
124
+ map_width: int,
125
+ map_height: int,
126
+ bounds_x: int,
127
+ bounds_y: int,
128
+ own_units: list[dict],
129
+ enemy_units: list[dict],
130
+ ascii_minimap: str,
131
+ output_width: int = TARGET_WIDTH,
132
+ ) -> str | None:
133
+ """Render a vision minimap and return base64-encoded PNG.
134
+
135
+ Args:
136
+ terrain_png: Raw bytes of map.png from the .oramap file.
137
+ map_width: Full map width in cells.
138
+ map_height: Full map height in cells.
139
+ bounds_x: Playable area X offset.
140
+ bounds_y: Playable area Y offset.
141
+ own_units: List of own unit dicts with cell_x, cell_y, type.
142
+ enemy_units: List of visible enemy unit dicts with cell_x, cell_y, type.
143
+ ascii_minimap: ASCII minimap string from game state.
144
+ output_width: Target image width in pixels.
145
+
146
+ Returns:
147
+ Base64-encoded PNG string, or None on failure.
148
+ """
149
+ try:
150
+ # Load terrain
151
+ terrain_img = Image.open(io.BytesIO(terrain_png)).convert("RGB")
152
+ pw, ph = terrain_img.size # terrain image pixel dimensions
153
+ terrain_arr = np.array(terrain_img).astype(float) / 255.0
154
+
155
+ # Compute visibility masks in cell coordinates
156
+ explored = _parse_ascii_minimap(ascii_minimap, map_width, map_height)
157
+ visible = _compute_visible_mask(own_units, map_width, map_height)
158
+ explored |= visible
159
+
160
+ # Use full map (including borders) so terrain boundaries are visible
161
+ playable_w = min(map_width - bounds_x, map_width)
162
+ playable_h = min(map_height - bounds_y, map_height)
163
+ explored_full = explored[:ph, :pw] if explored.shape[0] >= ph and explored.shape[1] >= pw else explored
164
+ visible_full = visible[:ph, :pw] if visible.shape[0] >= ph and visible.shape[1] >= pw else visible
165
+
166
+ # Resize visibility masks to match terrain image pixel dimensions
167
+ if explored_full.shape != (ph, pw):
168
+ explored_full = np.array(Image.fromarray(explored_full).resize((pw, ph), Image.NEAREST))
169
+ visible_full = np.array(Image.fromarray(visible_full).resize((pw, ph), Image.NEAREST))
170
+
171
+ # Smooth edges
172
+ explored_s = np.clip(_blur_2d(explored_full.astype(float), sigma=1.5, size=7), 0, 1)
173
+ visible_s = np.clip(_blur_2d(visible_full.astype(float), sigma=1.5, size=7), 0, 1)
174
+
175
+ # Composite terrain with visibility (vectorized)
176
+ brightness = VIS_UNEXPLORED * (1 - explored_s) + VIS_FOG * explored_s
177
+ brightness = brightness * (1 - visible_s) + VIS_BRIGHT * visible_s
178
+ # Ensure terrain borders (water/cliffs) are always visible — detect by
179
+ # checking if the terrain pixel is distinctly different from grass.
180
+ # Water/cliff pixels are blue-ish (high B, low G), grass is green-ish.
181
+ _is_water = terrain_arr[..., 2] > terrain_arr[..., 1] # blue > green
182
+ brightness = np.where(_is_water, np.maximum(brightness, VIS_FOG), brightness)
183
+ composite = terrain_arr * brightness[..., np.newaxis]
184
+
185
+ # Render with matplotlib at 2x for supersampling
186
+ render_dpi = 192 * RENDER_SCALE
187
+ fig_w = 3.5
188
+ fig_h = fig_w * ph / pw # maintain aspect ratio
189
+ # OO API (thread-safe, no global figure manager)
190
+ fig = Figure(figsize=(fig_w, fig_h), dpi=render_dpi)
191
+ ax = fig.add_subplot(1, 1, 1)
192
+ bg = "#0a0a0f"
193
+ fig.patch.set_facecolor(bg)
194
+ ax.set_facecolor(bg)
195
+
196
+ ax.imshow(
197
+ composite,
198
+ extent=[0, pw, ph, 0],
199
+ interpolation="bilinear",
200
+ aspect="auto",
201
+ )
202
+
203
+ # Grid
204
+ for x in range(0, map_width + 1, 20):
205
+ ax.axvline(x, color="white", alpha=0.15, linewidth=0.4)
206
+ for y in range(0, map_height + 1, 10):
207
+ ax.axhline(y, color="white", alpha=0.15, linewidth=0.4)
208
+
209
+ # Plot own units — cyan circles with glow.
210
+ # Halted-unreachable units (unit halted on a bad target — pathfinding
211
+ # failed repeatedly) get a YELLOW X overlay so the model can spot
212
+ # them clearly on the minimap and understand they need a new target.
213
+ for u in own_units:
214
+ ux, uy = u.get("cell_x", 0), u.get("cell_y", 0)
215
+ is_halted = bool(u.get("halted_unreachable"))
216
+ ax.plot(ux, uy, "o", color="#00b8d4", markersize=8, alpha=0.3, zorder=9)
217
+ ax.plot(
218
+ ux, uy, "o", color="#00e5ff", markersize=5,
219
+ markeredgecolor="white", markeredgewidth=0.6, zorder=10,
220
+ )
221
+ if is_halted:
222
+ # Yellow X overlay marking the unit as halted/unreachable.
223
+ ax.plot(
224
+ ux, uy, marker="x", color="#ffe600", markersize=8,
225
+ markeredgewidth=1.5, zorder=11,
226
+ )
227
+
228
+ # Plot enemy units — red circles/squares with glow
229
+ for u in enemy_units:
230
+ ux, uy = u.get("cell_x", 0), u.get("cell_y", 0)
231
+ utype = u.get("type", "").lower()
232
+ is_bldg = utype in BUILDINGS
233
+ marker = "s" if is_bldg else "o"
234
+ ms = 6 if is_bldg else 5
235
+ ax.plot(
236
+ ux, uy, marker, color="#ff1744", markersize=ms + 3,
237
+ alpha=0.25, zorder=9,
238
+ )
239
+ ax.plot(
240
+ ux, uy, marker, color="#ff1744", markersize=ms,
241
+ markeredgecolor="white", markeredgewidth=0.5, zorder=10,
242
+ )
243
+
244
+ # Show the FULL map including water/cliff borders — not just playable area.
245
+ # This lets the model see terrain boundaries clearly.
246
+ _x_max = map_width
247
+ _y_max = map_height
248
+ ax.set_xlim(0, _x_max)
249
+ ax.set_ylim(_y_max, 0)
250
+ # Ticks: evenly spaced within playable area + boundary values
251
+ _xticks = [x for x in range(0, _x_max + 1, 20) if x <= _x_max]
252
+ if _xticks[-1] != _x_max:
253
+ _xticks.append(_x_max)
254
+ _yticks = [y for y in range(0, _y_max + 1, 10) if y <= _y_max]
255
+ if _yticks[-1] != _y_max:
256
+ _yticks.append(_y_max)
257
+ ax.set_xticks(_xticks)
258
+ ax.set_yticks(_yticks)
259
+ ax.tick_params(
260
+ axis="both", colors="#8899aa", labelsize=6,
261
+ length=2, width=0.4, pad=1,
262
+ )
263
+ for spine in ax.spines.values():
264
+ spine.set_color("#2a3a50")
265
+ spine.set_linewidth(0.5)
266
+
267
+ # Compact legend — units + terrain
268
+ legend_elements = [
269
+ Line2D(
270
+ [0], [0], marker="o", color="w", markerfacecolor="#00e5ff",
271
+ markersize=5, label="Own", linestyle="None",
272
+ ),
273
+ Line2D(
274
+ [0], [0], marker="o", color="w", markerfacecolor="#ff1744",
275
+ markersize=5, label="Enemy", linestyle="None",
276
+ ),
277
+ Line2D(
278
+ [0], [0], marker="x", color="#ffe600", markersize=5,
279
+ label="Halted", linestyle="None", markeredgewidth=1.5,
280
+ ),
281
+ Line2D(
282
+ [0], [0], marker="s", color="w", markerfacecolor="#50a03c",
283
+ markersize=5, label="Land", linestyle="None",
284
+ ),
285
+ Line2D(
286
+ [0], [0], marker="s", color="w", markerfacecolor="#1e3c78",
287
+ markersize=5, label="Water", linestyle="None",
288
+ ),
289
+ Line2D(
290
+ [0], [0], marker="s", color="w", markerfacecolor="#6b5b3a",
291
+ markersize=5, label="Cliff", linestyle="None",
292
+ ),
293
+ ]
294
+ ax.legend(
295
+ handles=legend_elements, loc="upper right", fontsize=5,
296
+ framealpha=0.85, facecolor="#0a0a0f", edgecolor="#2a3a50",
297
+ labelcolor="#ccddee", handletextpad=0.3, borderpad=0.3,
298
+ columnspacing=0.6, ncol=6,
299
+ )
300
+
301
+ fig.tight_layout(pad=0.3)
302
+
303
+ # Render to buffer
304
+ buf = io.BytesIO()
305
+ fig.savefig(
306
+ buf, dpi=render_dpi, bbox_inches="tight",
307
+ facecolor=bg, pad_inches=0.03, format="png",
308
+ )
309
+ # No plt.close needed — Figure is local, no global state to release
310
+
311
+ # LANCZOS downsample to target size
312
+ buf.seek(0)
313
+ hi_res = Image.open(buf)
314
+ scale = output_width / hi_res.width
315
+ target_h = int(hi_res.height * scale)
316
+ final = hi_res.resize((output_width, target_h), Image.LANCZOS)
317
+
318
+ # Encode as base64
319
+ out_buf = io.BytesIO()
320
+ final.save(out_buf, format="PNG", optimize=True)
321
+ b64 = base64.b64encode(out_buf.getvalue()).decode("ascii")
322
+
323
+ logger.info(
324
+ "Rendered minimap: %dx%d, %d bytes, ~%d vision tokens",
325
+ final.width, final.height,
326
+ len(out_buf.getvalue()),
327
+ (final.width * final.height) // (32 * 32),
328
+ )
329
+ return b64
330
+
331
+ except Exception as e:
332
+ logger.warning("Minimap render failed: %s", e)
333
+ return None
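The two-step brightness blend in `render_minimap` can be checked per pixel with scalars. This is a sketch under the assumption that `explored` and `visible` are the blurred mask values in [0, 1]. The function name `blend_brightness` is hypothetical; the three constants are copied from the diff.

```python
import math

VIS_BRIGHT = 1.0       # currently visible
VIS_FOG = 0.40         # explored, no current vision
VIS_UNEXPLORED = 0.08  # never seen


def blend_brightness(explored: float, visible: float) -> float:
    """Scalar version of the vectorized compositing step.

    First interpolate unexplored -> fog by `explored`, then
    interpolate that result -> bright by `visible`.
    """
    base = VIS_UNEXPLORED * (1 - explored) + VIS_FOG * explored
    return base * (1 - visible) + VIS_BRIGHT * visible


for e, v in [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.5, 0.0)]:
    print(e, v, round(blend_brightness(e, v), 3))
```

At the mask extremes the blend reproduces the three constants exactly, and intermediate mask values (produced by the gaussian blur) give the soft edges between the layers.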
openra_rl_training/training/reward_funcs.py ADDED
@@ -0,0 +1,264 @@
1
+ """GRPO reward functions for OpenRA training.
2
+
3
+ Each function receives completions (list[str]) and extra kwargs from rollout_func.
4
+ Returns list[float] rewards, one per completion.
5
+
6
+ Per-scenario weighting: When ``scenario_weights`` is present in kwargs
7
+ (a list of dicts, one per completion), each function multiplies its base
8
+ reward by the scenario-specific weight for its signal. This lets combat-
9
+ focused scenarios boost combat reward while economy scenarios boost economy.
10
+ """
11
+
12
+ from collections import defaultdict
13
+
14
+ DEFAULT_REWARD_WEIGHTS: dict[str, float] = {
15
+ "outcome": 0.50,
16
+ "combat": 0.15,
17
+ "economy": 0.10,
18
+ "tempo": 0.10,
19
+ "density": 0.00,
20
+ "format": 0.05,
21
+ "survival": 0.10,
22
+ "discovery": 0.00,
23
+ "disruption": 0.00,
24
+ "exploration": 0.00,
25
+ }
26
+
27
+
28
+ def _apply_weights(
29
+ base: list[float], key: str, scenario_weights: list[dict] | None,
30
+ ) -> list[float]:
31
+ """Multiply base rewards by per-scenario weight for *key*."""
32
+ default_w = DEFAULT_REWARD_WEIGHTS[key]
33
+ if not scenario_weights:
34
+ return [r * default_w for r in base]
35
+ return [
36
+ r * (scenario_weights[i].get(key, default_w) if i < len(scenario_weights) else default_w)
37
+ for i, r in enumerate(base)
38
+ ]
39
+
40
+
41
+ def _normalize_within_group(rewards: list[float], spawn_groups: list[int]) -> list[float]:
42
+ """Center rewards within each spawn group (Fix A).
43
+
44
+ Within each group (same map), subtract the group mean so that GRPO
45
+ advantages measure behavioral differences, not spawn luck.
46
+ Groups with only 1 episode are left unchanged.
47
+ """
48
+ if not spawn_groups or len(spawn_groups) != len(rewards):
49
+ return rewards
50
+ groups: dict[int, list[tuple[int, float]]] = defaultdict(list)
51
+ for i, g in enumerate(spawn_groups):
52
+ groups[g].append((i, rewards[i]))
53
+ result = list(rewards)
54
+ for g, entries in groups.items():
55
+ if len(entries) < 2:
56
+ continue
57
+ vals = [v for _, v in entries]
58
+ gmean = sum(vals) / len(vals)
59
+ for idx, _ in entries:
60
+ result[idx] -= gmean
61
+ return result
62
+
63
+
64
+ def _rank_normalize(values: list[float]) -> list[float]:
65
+ """Map values to [-1, +1] via rank normalization with tie handling.
66
+
67
+ Robust to outliers — one amazing episode doesn't compress the rest.
68
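+ Without ties, ranks are equally spaced: the best episode gets +1.0 and
+ the worst -1.0. Tied values share their average rank, so a tie at the
+ top or bottom scores short of +/-1.0, and an all-equal batch maps to 0.0.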
69
+ """
70
+ n = len(values)
71
+ if n < 2:
72
+ return [0.0] * n
73
+ sorted_indices = sorted(range(n), key=lambda i: values[i])
74
+ ranks = [0.0] * n
75
+ i = 0
76
+ while i < n:
77
+ j = i
78
+ while j < n - 1 and values[sorted_indices[j + 1]] == values[sorted_indices[j]]:
79
+ j += 1
80
+ avg_rank = (i + j) / 2.0 + 1.0 # 1-based
81
+ for k in range(i, j + 1):
82
+ ranks[sorted_indices[k]] = avg_rank
83
+ i = j + 1
84
+ return [2.0 * (r - 1.0) / (n - 1.0) - 1.0 for r in ranks]
85
+
86
+
87
+ def _zscore_batch(values: list[float]) -> list[float]:
88
+ """Rank-normalize within batch.
89
+
90
+ Replaces z-score: rank normalization is robust to outliers and
91
+ guarantees even advantage spacing regardless of score distribution.
92
+ """
93
+ return _rank_normalize(values)
94
+
95
+
96
+ def _zscore_per_group(values: list[float], spawn_groups: list[int] | None) -> list[float]:
97
+ """Rank-normalize within each spawn group.
98
+
99
+ Each spawn group = different map layout = different "prompt".
100
+ Rank-normalizing per group ensures advantages reflect behavioral
101
+ differences within the SAME conditions, not map difficulty.
102
+
103
+ Falls back to global rank normalization if no spawn groups provided.
104
+ """
105
+ if not values or len(values) < 2:
106
+ return values
107
+
108
+ if not spawn_groups:
109
+ return _rank_normalize(values)
110
+
111
+ groups: dict[int, list[int]] = {}
112
+ for i, g in enumerate(spawn_groups):
113
+ groups.setdefault(g, []).append(i)
114
+
115
+ result = list(values)
116
+ for indices in groups.values():
117
+ if len(indices) < 2:
118
+ result[indices[0]] = 0.0
119
+ continue
120
+ group_vals = [values[i] for i in indices]
121
+ ranked = _rank_normalize(group_vals)
122
+ for idx, rank_val in zip(indices, ranked):
123
+ result[idx] = rank_val
124
+ return result
125
+
126
+
127
+ def _neutralize_infra(rewards: list[float], kwargs: dict) -> list[float]:
128
+ """Replace infra-failure and tool-call-failure episode rewards with valid-episode mean.
129
+
130
+ DAPO-style dynamic sampling (arXiv:2503.14476 Section 3.1): episodes
131
+ that failed due to infrastructure issues (game server crash, vLLM 500
132
+ errors) or tool call degeneration (model produced gibberish instead of
133
+ tool calls) get their reward set to the batch mean of valid episodes.
134
+ After GRPO normalization: advantage = (mean - mean) / std = 0,
135
+ so these episodes contribute zero gradient.
136
+ """
137
+ infra = kwargs.get("infra_failure", [])
138
+ tool_fail = kwargs.get("tool_call_failure", [])
139
+ n = len(rewards)
140
+ # Build combined failure mask
141
+ failed = [False] * n
142
+ for i in range(n):
143
+ if (i < len(infra) and infra[i]) or (i < len(tool_fail) and tool_fail[i]):
144
+ failed[i] = True
145
+ if not any(failed):
146
+ return rewards
147
+ valid = [r for r, f in zip(rewards, failed) if not f]
148
+ if not valid:
149
+ return rewards # all failed — nothing to anchor on
150
+ vmean = sum(valid) / len(valid)
151
+ return [vmean if failed[i] else r for i, r in enumerate(rewards)]
152
+
153
+
154
+ def reward_outcome(completions: list[str], **kwargs) -> list[float]:
155
+ """Terminal game outcome: +1.0 win, -1.0 lose, 0.0 draw/incomplete."""
156
+ outcomes = kwargs.get("outcome", [])
157
+ if not outcomes:
158
+ base = [0.0] * len(completions)
159
+ else:
160
+ mapping = {"win": 1.0, "lose": -1.0, "draw": 0.0}
161
+ base = [mapping.get(o, 0.0) for o in outcomes]
162
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
163
+ weighted = _apply_weights(normalized, "outcome", kwargs.get("scenario_weights"))
164
+ return _neutralize_infra(weighted, kwargs)
165
+
166
+
167
+ def reward_combat(completions: list[str], **kwargs) -> list[float]:
168
+ """Combat efficiency from the 8-dim reward vector."""
169
+ scores = kwargs.get("combat_score", [])
170
+ base = [float(s) for s in scores] if scores else [0.0] * len(completions)
171
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
172
+ weighted = _apply_weights(normalized, "combat", kwargs.get("scenario_weights"))
173
+ return _neutralize_infra(weighted, kwargs)
174
+
175
+
176
+ def reward_economy(completions: list[str], **kwargs) -> list[float]:
177
+ """Economic performance from the 8-dim reward vector."""
178
+ scores = kwargs.get("economy_score", [])
179
+ base = [float(s) for s in scores] if scores else [0.0] * len(completions)
180
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
181
+ weighted = _apply_weights(normalized, "economy", kwargs.get("scenario_weights"))
182
+ return _neutralize_infra(weighted, kwargs)
183
+
184
+
185
+ def reward_tempo(completions: list[str], **kwargs) -> list[float]:
186
+ """Action efficiency — fewer redundant actions = higher reward.
187
+
188
+ Tempo IS spawn-correlated (r=0.74 with discovery in Sprint scenario):
189
+ closer spawns → less travel time → better tempo. Apply spawn-group
190
+ normalization to isolate the behavioral component.
191
+ """
192
+ scores = kwargs.get("tempo_score", [])
193
+ base = [float(s) for s in scores] if scores else [0.0] * len(completions)
194
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
195
+ weighted = _apply_weights(normalized, "tempo", kwargs.get("scenario_weights"))
196
+ return _neutralize_infra(weighted, kwargs)
197
+
198
+
199
+ def reward_density(completions: list[str], **kwargs) -> list[float]:
200
+ """Action density — parallel utilization of controllable resources.
201
+
202
+ Measures how many distinct objectives are pursued per turn relative
203
+ to available units. Independent of tempo (which measures activity/idle).
204
+ 3 units with 3 separate commands to 3 places → high density.
205
+ 3 units with 1 blob command → low density.
206
+ """
207
+ scores = kwargs.get("density_score", [])
208
+ base = [float(s) for s in scores] if scores else [0.0] * len(completions)
209
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
210
+ weighted = _apply_weights(normalized, "density", kwargs.get("scenario_weights"))
211
+ return _neutralize_infra(weighted, kwargs)
212
+
213
+
214
+ def reward_format(completions: list[str], **kwargs) -> list[float]:
215
+ """Format compliance — fraction of turns with valid structured action syntax."""
216
+ scores = kwargs.get("format_score", [])
217
+ base = [float(s) for s in scores] if scores else [0.0] * len(completions)
218
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
219
+ weighted = _apply_weights(normalized, "format", kwargs.get("scenario_weights"))
220
+ return _neutralize_infra(weighted, kwargs)
221
+
222
+
223
+ def reward_survival(completions: list[str], **kwargs) -> list[float]:
224
+ """Unit HP preservation — discourages suicide attacks."""
225
+ scores = kwargs.get("survival_score", [])
226
+ base = [float(s) for s in scores] if scores else [0.0] * len(completions)
227
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
228
+ weighted = _apply_weights(normalized, "survival", kwargs.get("scenario_weights"))
229
+ return _neutralize_infra(weighted, kwargs)
230
+
231
+
232
+ def reward_discovery(completions: list[str], **kwargs) -> list[float]:
233
+ """Discovery reward — accumulated intelligence score from scouting.
234
+
235
+ The game engine awards 0.05 per new enemy unit sighting + bonuses for
236
+ buildings (0.2 production, 0.5 base). Values are accumulated across all
237
+ ticks and clamped to [0, 1].
238
+ """
239
+ scores = kwargs.get("discovery_score", [])
240
+ if not scores:
241
+ base = [0.0] * len(completions)
242
+ else:
243
+ base = [min(max(float(s), 0.0), 1.0) for s in scores]
244
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
245
+ weighted = _apply_weights(normalized, "discovery", kwargs.get("scenario_weights"))
246
+ return _neutralize_infra(weighted, kwargs)
247
+
248
+
249
+ def reward_disruption(completions: list[str], **kwargs) -> list[float]:
250
+ """Strategic sabotage — destroying enemy power, production, tech."""
251
+ scores = kwargs.get("disruption_score", [])
252
+ base = [float(s) for s in scores] if scores else [0.0] * len(completions)
253
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
254
+ weighted = _apply_weights(normalized, "disruption", kwargs.get("scenario_weights"))
255
+ return _neutralize_infra(weighted, kwargs)
256
+
257
+
258
+ def reward_exploration(completions: list[str], **kwargs) -> list[float]:
259
+ """Map exploration percentage — rewards fog-of-war clearing."""
260
+ scores = kwargs.get("exploration_score", [])
261
+ base = [float(s) for s in scores] if scores else [0.0] * len(completions)
262
+ normalized = _zscore_per_group(base, kwargs.get("spawn_group"))
263
+ weighted = _apply_weights(normalized, "exploration", kwargs.get("scenario_weights"))
264
+ return _neutralize_infra(weighted, kwargs)
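
The tie handling in `_rank_normalize` above is easy to misread, so here is a minimal standalone sketch of the same averaged-rank scheme. It is restated under the hypothetical name `rank_normalize` so it runs without the vendored package; the sample inputs are illustrative assumptions, not values from the benchmark.

```python
def rank_normalize(values: list[float]) -> list[float]:
    """Map values to [-1, +1] by rank; tied values share their average rank."""
    n = len(values)
    if n < 2:
        return [0.0] * n
    # Indices of `values`, ordered from smallest to largest value.
    order = sorted(range(n), key=lambda i: values[i])
    ranks = [0.0] * n
    i = 0
    while i < n:
        # Extend j to cover the run of equal values starting at position i.
        j = i
        while j < n - 1 and values[order[j + 1]] == values[order[j]]:
            j += 1
        avg_rank = (i + j) / 2.0 + 1.0  # 1-based average rank of the run
        for k in range(i, j + 1):
            ranks[order[k]] = avg_rank
        i = j + 1
    # Rank 1 maps to -1.0 and rank n maps to +1.0.
    return [2.0 * (r - 1.0) / (n - 1.0) - 1.0 for r in ranks]


# Illustrative inputs (assumed, not taken from real episodes).
scores_no_ties = rank_normalize([0.2, 0.9, 0.5])
scores_tied_top = rank_normalize([1.0, 3.0, 3.0])
scores_all_equal = rank_normalize([5.0, 5.0])
```

With distinct scores the extremes land exactly on -1.0 and +1.0. When the two best episodes tie, both receive 0.5 rather than +1.0, and a batch with identical scores collapses to 0.0 for everyone, which yields no advantage signal for that signal's reward.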
openra_rl_training/training/rust_env_pool.py ADDED
@@ -0,0 +1,173 @@
1
+ """Pool of Rust-backed OpenRA environments for fast in-process rollout.
2
+
3
+ Mirrors the surface of `env_pool.EnvPool` but swaps the gRPC-backed
4
+ `OpenRAEnvironment` for the native `openra_train.OpenRAEnv` (a Rust
5
+ deterministic simulator built via maturin/PyO3).
6
+
7
+ Key differences from the gRPC pool:
8
+ * No game server / port allocation. Each `OpenRAEnv` is a
9
+ self-contained Rust object — instantiation is microseconds.
10
+ * Episodes are deterministic given (scenario_path, seed). The pool
+ draws one seed per env from `seed_generator` (defaults to a
+ monotonic counter) at construction; callers can reseed later with
+ `RustEnvHandle.reset(seed=...)`.
13
+ * `step` accepts a list of `openra_train.Command` objects (build
14
+ them with `Command.move_units(...)`, `Command.attack_unit(...)`,
15
+ `Command.observe()`).
16
+
17
+ The pool is process-local; for true parallelism, fan out via
18
+ `concurrent.futures.ProcessPoolExecutor` and have each worker own its
19
+ own `RustEnvPool` (or just instantiate `OpenRAEnv` directly).
20
+
21
+ Drop-in for the existing `env_pool.EnvPool`:
22
+ * `acquire(timeout=...) -> env`
23
+ * `release(env)`
24
+ * `update_scenario(path)`: updates the stored default scenario path;
+ existing envs and the seed counter are left as they are.
+ * `shutdown()`: drops references; the Rust worlds are freed once the
+ Python handles are released.
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import itertools
32
+ import logging
33
+ import queue
34
+ import threading
35
+ from typing import Any, Callable, Iterator
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ def _default_seed_generator(start: int = 0) -> Iterator[int]:
41
+ return itertools.count(start)
42
+
43
+
44
+ class RustEnvHandle:
45
+ """Thin wrapper to give the Rust env a uniform `reset / step / close`
46
+ interface that mirrors the gRPC env without leaking PyO3 types
47
+ upward."""
48
+
49
+ def __init__(self, scenario_path: str, seed: int):
50
+ # Lazy import so import-time failures don't break the rest of
51
+ # the training package on machines without the wheel built.
52
+ import openra_train
53
+
54
+ self._cls_command = openra_train.Command
55
+ self._env = openra_train.OpenRAEnv(scenario_path, int(seed))
56
+ self.scenario_path = scenario_path
57
+ self.seed = int(seed)
58
+
59
+ @property
60
+ def Command(self):
61
+ """Expose `openra_train.Command` for callers that want to
62
+ construct Move/Attack/Observe payloads without re-importing."""
63
+ return self._cls_command
64
+
65
+ def reset(self, seed: int | None = None) -> dict[str, Any]:
66
+ if seed is not None and int(seed) != self.seed:
67
+ # Re-instantiate to pick up the new seed (the underlying
68
+ # Rust env owns the world; reset() re-uses the original
69
+ # seed). Cheap — Rust instantiation is sub-millisecond.
70
+ import openra_train
71
+ self._env = openra_train.OpenRAEnv(self.scenario_path, int(seed))
72
+ self.seed = int(seed)
73
+ return self._env.reset()
74
+
75
+ def step(self, commands: list[Any]) -> tuple[dict[str, Any], float, bool, dict[str, Any]]:
76
+ """Apply a list of `openra_train.Command` objects, returns
77
+ (obs, reward, done, info)."""
78
+ return self._env.step(commands)
79
+
80
+ def close(self) -> None:
81
+ # No external resources to release; the Rust world is freed
82
+ # when this handle is dropped.
83
+ self._env = None
84
+
85
+
86
+ class RustEnvPool:
87
+ """Thread-safe pool of Rust-backed environments.
88
+
89
+ Args:
90
+ size: Number of environment instances.
91
+ scenario_path: Path to the rush-hour-style scenario YAML.
92
+ seed_generator: Iterator yielding seeds for each new env. If
93
+ None, defaults to `itertools.count(0)`.
94
+ env_factory: Optional override; receives `(scenario_path, seed)`
95
+ and returns an env-like object exposing `reset(...)` /
96
+ `step(...)`. Useful for testing.
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ size: int = 4,
102
+ scenario_path: str = "",
103
+ seed_generator: Iterator[int] | None = None,
104
+ env_factory: Callable[[str, int], Any] | None = None,
105
+ ):
106
+ if size < 1:
107
+ raise ValueError(f"RustEnvPool size must be >=1, got {size}")
108
+ if not scenario_path:
109
+ raise ValueError("RustEnvPool requires a non-empty scenario_path")
110
+
111
+ self._size = size
112
+ self._scenario_path = scenario_path
113
+ self._seed_gen = seed_generator or _default_seed_generator()
114
+ self._factory = env_factory or (lambda path, seed: RustEnvHandle(path, seed))
115
+ self._pool: queue.Queue = queue.Queue()
116
+ self._envs: list = []
117
+ self._lock = threading.Lock()
118
+
119
+ for _ in range(size):
120
+ seed = next(self._seed_gen)
121
+ env = self._factory(scenario_path, seed)
122
+ self._envs.append(env)
123
+ self._pool.put(env)
124
+
125
+ def acquire(self, timeout: float = 30.0):
126
+ """Get an available environment (blocks if all busy).
127
+
128
+ The Rust env is in-process and deterministic, so the timeout
+ only applies if all envs are checked out by other threads. If none
+ is released within `timeout` seconds, `queue.Empty` is raised.
130
+ """
131
+ return self._pool.get(timeout=timeout)
132
+
133
+ def release(self, env) -> None:
134
+ """Return an environment to the pool."""
135
+ self._pool.put(env)
136
+
137
+ def update_scenario(self, scenario_path: str) -> None:
138
+ """Replace the scenario used for newly-instantiated envs.
139
+
140
+ Only the stored path changes: the pool creates its envs in `__init__`
+ and never rebuilds them. Existing pooled envs keep their scenario
+ (`acquire` takes no reset flag), so callers must replace them explicitly.
143
+ """
144
+ with self._lock:
145
+ self._scenario_path = scenario_path
146
+
147
+ @property
148
+ def scenario_path(self) -> str:
149
+ return self._scenario_path
150
+
151
+ @property
152
+ def size(self) -> int:
153
+ return self._size
154
+
155
+ @property
156
+ def available(self) -> int:
157
+ return self._pool.qsize()
158
+
159
+ def shutdown(self) -> None:
160
+ """Drop all env references and drain the pool."""
161
+ with self._lock:
162
+ for env in self._envs:
163
+ try:
164
+ if hasattr(env, "close"):
165
+ env.close()
166
+ except Exception:
167
+ logger.exception("Error closing Rust env")
168
+ self._envs.clear()
169
+ while not self._pool.empty():
170
+ try:
171
+ self._pool.get_nowait()
172
+ except queue.Empty:
173
+ break
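
The acquire/release surface described in this file pairs naturally with a `try`/`finally` (or a context manager) so that an env is always returned to the pool, even if a rollout raises. The sketch below illustrates that pattern only; `FakeEnv`, `TinyPool`, and `borrowed` are hypothetical stand-ins written for this example so it runs without the `openra_train` wheel, and they are not part of the vendored code.

```python
import queue
from contextlib import contextmanager


class FakeEnv:
    """Hypothetical stand-in for RustEnvHandle (no Rust extension needed)."""

    def __init__(self, scenario_path: str, seed: int):
        self.scenario_path = scenario_path
        self.seed = seed

    def reset(self) -> dict:
        # A real env would return an observation; this one echoes its seed.
        return {"seed": self.seed}


class TinyPool:
    """Hypothetical stand-in with the same acquire/release shape as RustEnvPool."""

    def __init__(self, size: int, scenario_path: str):
        self._pool: queue.Queue = queue.Queue()
        for seed in range(size):  # monotonic seeds, one per env
            self._pool.put(FakeEnv(scenario_path, seed))

    def acquire(self, timeout: float = 30.0) -> FakeEnv:
        # Raises queue.Empty if no env is returned within `timeout` seconds.
        return self._pool.get(timeout=timeout)

    def release(self, env: FakeEnv) -> None:
        self._pool.put(env)

    @property
    def available(self) -> int:
        return self._pool.qsize()


@contextmanager
def borrowed(pool: TinyPool, timeout: float = 30.0):
    """Check an env out of the pool and always hand it back."""
    env = pool.acquire(timeout=timeout)
    try:
        yield env
    finally:
        pool.release(env)


pool = TinyPool(size=2, scenario_path="scenario.yaml")  # path is a placeholder
with borrowed(pool) as env:
    obs = env.reset()
    in_use_available = pool.available  # one env is checked out here
after_available = pool.available  # the env is back in the pool
```

Passing a factory such as `FakeEnv` through the real pool's `env_factory` argument should give the same effect in tests, since that argument receives `(scenario_path, seed)` per the class docstring.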
requirements.txt CHANGED
@@ -3,3 +3,10 @@ pandas>=2.0.0
3
  httpx>=0.24.0
4
  huggingface_hub>=0.20.0
5
  openra-rl-util>=0.1.0
6
+ # Used by the bench + the vendored openra_rl_training / openra_env
7
+ # modules (see VENDOR.md).
8
+ pydantic>=2.0
9
+ pyyaml>=6.0
10
+ pillow>=10.0
11
+ numpy>=1.24
12
+ matplotlib>=3.7