Initial repo structure with model card
- README.md +73 -0
- datasets/.gitkeep +0 -0
- v017_local_baseline/.gitkeep +0 -0
README.md
ADDED
@@ -0,0 +1,73 @@
---
license: mit
tags:
- reinforcement-learning
- minihack
- diffusion
- planning
- behavior-cloning
---

# ReMDM-MiniHack

Generative Planning Agent for MiniHack navigation using **Re-Masked Discrete Diffusion (ReMDM)**.

The agent uses Masked Discrete Diffusion to iteratively generate action sequences for dungeon navigation.
Instead of predicting the next action autoregressively, the model generates entire 64-step trajectories
by progressively unmasking action tokens.
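
To make "progressively unmasking" concrete, here is a minimal toy sketch of the sampling loop. It is not the project's sampler: `toy_denoiser` is a random stand-in for the trained network, and the step count, reveal schedule, and `remask_prob` are illustrative assumptions. Only the 64-step horizon and the 12 actions come from this README.

```python
import random

MASK = -1          # sentinel for a still-masked action token
HORIZON = 64       # trajectory length stated in the README
NUM_ACTIONS = 12   # matches action_dim=12 in the Quick Start


def toy_denoiser(seq, rng):
    """Hypothetical stand-in for the trained model: proposes an action per position."""
    return [rng.randrange(NUM_ACTIONS) for _ in seq]


def sample_plan(num_steps=8, remask_prob=0.1, seed=0):
    """Start fully masked, reveal tokens over num_steps, occasionally re-masking."""
    rng = random.Random(seed)
    seq = [MASK] * HORIZON
    for step in range(num_steps):
        proposal = toy_denoiser(seq, rng)
        masked = [i for i, a in enumerate(seq) if a == MASK]
        last = step == num_steps - 1
        # Reveal an even share of the masked positions; reveal all on the last step.
        share = max(1, len(masked) // (num_steps - step))
        k = len(masked) if last else min(len(masked), share)
        for i in rng.sample(masked, k):
            seq[i] = proposal[i]
        if not last:
            # Re-masking: let some committed tokens be revised in later steps.
            for i in range(HORIZON):
                if seq[i] != MASK and rng.random() < remask_prob:
                    seq[i] = MASK
    return seq


plan = sample_plan()
```

The whole plan is refined in parallel at every step, which is the contrast with next-action autoregressive prediction drawn above. A real sampler would typically choose which tokens to reveal or re-mask from model confidence rather than uniformly at random.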

## Code

GitHub: [piotrwilam/ReMDM-MiniHack-Project](https://github.com/piotrwilam/ReMDM-MiniHack-Project)

## Models

| Version | Model | Params | Training | Tag |
|---|---|---|---|---|
| v017_local_baseline | LocalDiffusionPlanner | 7M | Offline BC, 200 demos/env, 30 epochs | - |
| v017_local_baseline | LocalDiffusionPlanner | 7M | Offline BC, 500 demos/env, 60 epochs | `v0.17-local-baseline-gold` (pending) |

## Repo Structure

```
ReMDM-MiniHack/
├── README.md                      # This file
├── v017_local_baseline/
│   ├── inference_weights.pth      # EMA state dict (for evaluation)
│   ├── full_checkpoint.pth        # Full training state (for resuming)
│   ├── config.json                # Hyperparams + model args
│   └── eval_results.csv           # Per-environment results
└── datasets/
    └── oracle_demos_v017.pt       # Oracle demonstration dataset
```

## Quick Start

```python
import torch
from huggingface_hub import hf_hub_download

# `model` is not a pip package; it is presumably the module from the GitHub
# repo linked above, so that repo needs to be on your PYTHONPATH.
from model import LocalDiffusionPlanner

# Download the EMA inference weights from the Hub
path = hf_hub_download("piotrwilam/ReMDM-MiniHack", "v017_local_baseline/inference_weights.pth")

# weights_only=False unpickles arbitrary objects, so only use it on files you trust.
# If the file is a plain state dict, weights_only=True is the safer setting.
weights = torch.load(path, map_location="cpu", weights_only=False)

# Build the model and load the weights
model = LocalDiffusionPlanner(action_dim=12)
model.load_state_dict(weights)
model.eval()  # switch to inference mode
```
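
The Quick Start hard-codes `action_dim=12`, while the repo layout lists a `config.json` holding hyperparameters and model args. A small helper along these lines could build the model from that file instead. This is a sketch under assumptions: the schema of `config.json` is not documented here, so the `model_args` key and the helper name `build_from_config` are hypothetical.

```python
import inspect
import json


def build_from_config(model_cls, config_path, args_key="model_args"):
    """Instantiate model_cls from a JSON config, skipping unknown keys.

    `args_key` is an assumed layout; if the key is absent, the top-level
    JSON object is treated as the constructor arguments.
    """
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    args = config.get(args_key, config)
    # Keep only the keys the constructor actually accepts.
    accepted = inspect.signature(model_cls.__init__).parameters
    kwargs = {k: v for k, v in args.items() if k in accepted}
    return model_cls(**kwargs)
```

In use, the path would come from `hf_hub_download("piotrwilam/ReMDM-MiniHack", "v017_local_baseline/config.json")` and the result would replace the `LocalDiffusionPlanner(action_dim=12)` line. Filtering by the constructor signature keeps training-only hyperparameters in the same file from raising a `TypeError`.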

## Results: v017 Local Baseline (Offline BC, 200 demos/env, 30 epochs)

| Environment | Win% | Avg Steps |
|---|---|---|
| Room-Random-5x5 | 94% | 18.3 |
| Room-Random-15x15 | 54% | 130.4 |
| Room-Dark-5x5 | 90% | 25.5 |
| Room-Ultimate-5x5 | 84% | 20.8 |
| Room-Ultimate-15x15 | 30% | 72.1 |
| Corridor-R2 | 42% | 132.1 |
| Corridor-R3 | 0% | 200.0 |
| MazeWalk-9x9 | 48% | 119.0 |
| MazeWalk-15x15 | 22% | 162.3 |
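
For a quick per-family view of these numbers, the table can be averaged by environment family. The snippet below is only a sketch: the rows are copied by hand from the table rather than read from `eval_results.csv`, whose column names are not documented here, and grouping by the prefix before the first hyphen is an illustrative choice.

```python
import csv
import io
from collections import defaultdict

# Rows copied from the results table: environment, win %, average steps.
RESULTS_CSV = """\
Room-Random-5x5,94,18.3
Room-Random-15x15,54,130.4
Room-Dark-5x5,90,25.5
Room-Ultimate-5x5,84,20.8
Room-Ultimate-15x15,30,72.1
Corridor-R2,42,132.1
Corridor-R3,0,200.0
MazeWalk-9x9,48,119.0
MazeWalk-15x15,22,162.3
"""


def mean_win_by_family(csv_text):
    """Average win % per environment family (the name before the first hyphen)."""
    wins = defaultdict(list)
    for env, win_pct, _avg_steps in csv.reader(io.StringIO(csv_text)):
        wins[env.split("-")[0]].append(float(win_pct))
    return {family: sum(v) / len(v) for family, v in wins.items()}


summary = mean_win_by_family(RESULTS_CSV)
# Room ≈ 70.4, Corridor = 21.0, MazeWalk = 35.0
```

These are unweighted means over the listed environments, so they say nothing about how many episodes each environment was evaluated on.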
datasets/.gitkeep
ADDED
File without changes
v017_local_baseline/.gitkeep
ADDED
File without changes