---
license: mit
tags:
  - reinforcement-learning
  - minihack
  - diffusion
  - planning
  - behavior-cloning
---

# ReMDM-MiniHack

A generative planning agent for MiniHack navigation based on **Re-Masked Discrete Diffusion (ReMDM)**.

The agent uses Masked Discrete Diffusion to iteratively generate action sequences for dungeon navigation.
Instead of predicting the next action autoregressively, the model generates entire 64-step trajectories
by progressively unmasking action tokens.
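
The unmasking loop described above can be illustrated with a small, dependency-free sketch. It is not the repository's sampler: `toy_denoiser`, the linear unmasking schedule, the `remask_prob` parameter, and the `MASK` sentinel are all assumptions made for illustration, and the random denoiser only stands in for the trained model.

```python
import random

MASK = -1          # sentinel for a masked position (assumption)
HORIZON = 64       # trajectory length, as stated above
NUM_ACTIONS = 12   # matches action_dim=12 in the Quick Start


def toy_denoiser(tokens, rng):
    """Hypothetical stand-in for the trained model.

    Returns one (action, confidence) pair per position.
    """
    return [(rng.randrange(NUM_ACTIONS), rng.random()) for _ in tokens]


def sample_plan(num_steps=8, remask_prob=0.1, seed=0):
    rng = random.Random(seed)
    tokens = [MASK] * HORIZON  # start from a fully masked trajectory
    for step in range(1, num_steps + 1):
        last = step == num_steps
        if not last:
            # Remasking: each committed token may be reset so it can be revised.
            tokens = [
                MASK if t != MASK and rng.random() < remask_prob else t
                for t in tokens
            ]
        preds = toy_denoiser(tokens, rng)
        masked = [i for i, t in enumerate(tokens) if t == MASK]
        # Linear schedule: how many tokens should be committed after this step.
        target = HORIZON * step // num_steps
        n_unmask = max(0, target - (HORIZON - len(masked)))
        # Commit the most confident proposals among the masked positions.
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        for i in masked[:n_unmask]:
            tokens[i] = preds[i][0]
    return tokens


plan = sample_plan()
print(len(plan), plan[:8])
```

On the final step the schedule target equals the full horizon and remasking is skipped, so the returned plan contains no masked positions.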

## Code

GitHub: [piotrwilam/ReMDM-MiniHack-Project](https://github.com/piotrwilam/ReMDM-MiniHack-Project)

## Models

| Version | Model | Params | Training | Tag |
|---|---|---|---|---|
| v017_local_baseline | LocalDiffusionPlanner | 7M | Offline BC, 200 demos/env, 30 epochs | (none) |
| v017_local_baseline | LocalDiffusionPlanner | 7M | Offline BC, 500 demos/env, 60 epochs | `v0.17-local-baseline-gold` (pending) |

## Repo Structure

```
ReMDM-MiniHack/
├── README.md                              # This file
├── v017_local_baseline/
│   ├── inference_weights.pth              # EMA state dict (for evaluation)
│   ├── full_checkpoint.pth                # Full training state (for resuming)
│   ├── config.json                        # Hyperparams + model args
│   └── eval_results.csv                   # Per-environment results
└── datasets/
    └── oracle_demos_v017.pt               # Oracle demonstration dataset
```

## Quick Start

```python
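import torch
from huggingface_hub import hf_hub_download

# LocalDiffusionPlanner is defined in the project's GitHub repository (see Code above);
# this import assumes that repository's `model` module is on your Python path.
from model import LocalDiffusionPlanner

# Download the EMA inference weights from the Hub
path = hf_hub_download(
    repo_id="piotrwilam/ReMDM-MiniHack",
    filename="v017_local_baseline/inference_weights.pth",
)
# weights_only=False unpickles arbitrary Python objects; only load checkpoints you trust
weights = torch.load(path, map_location="cpu", weights_only=False)

# Build the model and load the weights
model = LocalDiffusionPlanner(action_dim=12)
model.load_state_dict(weights)
model.eval()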
```

## Results: v017 Local Baseline (Offline BC, 200 demos/env, 30 epochs)

| Environment | Win% | Avg Steps |
|---|---|---|
| Room-Random-5x5 | 94% | 18.3 |
| Room-Random-15x15 | 54% | 130.4 |
| Room-Dark-5x5 | 90% | 25.5 |
| Room-Ultimate-5x5 | 84% | 20.8 |
| Room-Ultimate-15x15 | 30% | 72.1 |
| Corridor-R2 | 42% | 132.1 |
| Corridor-R3 | 0% | 200.0 |
| MazeWalk-9x9 | 48% | 119.0 |
| MazeWalk-15x15 | 22% | 162.3 |
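
For a quick aggregate view, the table above can be summarized in a few lines of Python. The values are copied from the table; the unweighted mean and the 50% threshold are illustrative choices, not metrics reported with the model.

```python
# (win %, average steps) per environment, copied from the results table
results = {
    "Room-Random-5x5": (94, 18.3),
    "Room-Random-15x15": (54, 130.4),
    "Room-Dark-5x5": (90, 25.5),
    "Room-Ultimate-5x5": (84, 20.8),
    "Room-Ultimate-15x15": (30, 72.1),
    "Corridor-R2": (42, 132.1),
    "Corridor-R3": (0, 200.0),
    "MazeWalk-9x9": (48, 119.0),
    "MazeWalk-15x15": (22, 162.3),
}

# Unweighted mean win rate across the nine environments
mean_win = sum(win for win, _ in results.values()) / len(results)

# Environments where the agent wins at least half of its episodes
mostly_solved = [name for name, (win, _) in results.items() if win >= 50]

print(f"Mean win rate: {mean_win:.1f}%")
print("Win rate >= 50%:", ", ".join(mostly_solved))
```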