---
library_name: mouse-core
tags:
- mouse-core
- reinforcement-learning
---
# micahr234/mouse-example-model
This repository contains a MOUSE model checkpoint.
## Architecture
- Backbone: `qwen3`
- Hidden dimension: `1024`
- Heads: `action_value_layerwise`
- Action head: `action_value_layerwise`
### Encoder
`StepEmbedder` reads flat step-record dicts and projects each declared modality
into the shared `1024`-dimensional token space consumed by the backbone.
| Field | Type | Required | Tensor shape | Dtype | Notes |
|---|---|---:|---|---|---|
| `action` | `discrete` | yes | `[B, S]` | `torch.long` | integer ids in `[0, 3]` |
| `observation` | `discrete` | yes | `[B, S]` | `torch.long` | integer ids in `[0, 63]` |
| `reward` | `rff` | yes | `[B, S]` | `torch.float32` | scalar value |
| `done` | `discrete` | yes | `[B, S]` | `torch.long` | integer ids in `[0, 4]` |
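The table also works as an input contract. The sketch below checks one step-record dict against it before batching. `STEP_SCHEMA` and `validate_step` are hypothetical helpers written for illustration, not part of mouse-core. The schema is transcribed by hand from the table, so it has to be kept in sync if the encoder changes. The sketch also assumes that ids are plain Python `int`s and that `reward` may be any `int` or `float`.

```python
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical, hand-transcribed from the encoder table:
# field -> (kind, inclusive id range for discrete fields, else None).
STEP_SCHEMA: Dict[str, Tuple[str, Optional[Tuple[int, int]]]] = {
    "action": ("discrete", (0, 3)),
    "observation": ("discrete", (0, 63)),
    "reward": ("scalar", None),
    "done": ("discrete", (0, 4)),
}


def validate_step(step: Dict[str, Any]) -> List[str]:
    """Return human-readable problems with one step-record dict (empty list if none)."""
    problems = []
    for field, (kind, id_range) in STEP_SCHEMA.items():
        if field not in step:
            problems.append(f"missing required field {field!r}")
            continue
        value = step[field]
        if isinstance(value, bool):
            # bool is a subclass of int in Python, so reject it explicitly.
            problems.append(f"{field!r} must not be a bool")
        elif kind == "discrete" and not isinstance(value, int):
            problems.append(f"{field!r} must be an int id, got {type(value).__name__}")
        elif kind == "scalar" and not isinstance(value, (int, float)):
            problems.append(f"{field!r} must be a number, got {type(value).__name__}")
        elif id_range is not None and not id_range[0] <= value <= id_range[1]:
            problems.append(f"{field!r}={value} outside [{id_range[0]}, {id_range[1]}]")
    return problems
```

For example, `validate_step({"action": 4, "observation": 0, "reward": 0.0, "done": 0})` returns `["'action'=4 outside [0, 3]"]`, which is easier to act on than an out-of-range index error raised inside an embedding layer.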
## Install MouseCore
```bash
pip install mouse-core
```
## Load The Model
```python
import torch
from mouse_core import load_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("micahr234/mouse-example-model", map_location="cpu").eval().to(device)
```
## Run Inference
The model accepts a `list[list[dict]]` batch of shape `[B][S]`: B sequences,
each containing S step-record dicts whose flat keys match the modalities
declared in the encoder table above.
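To see how a `[B][S]` list of dicts lines up with the `[B, S]` tensor shapes in the encoder table, the sketch below regroups a batch into one nested list per field. `collate_fields` is a hypothetical helper for illustration, not mouse-core's extraction code. It assumes every sequence has the same length S and does no padding; whether mouse-core accepts ragged batches is not stated here.

```python
from typing import Any, Dict, List

Step = Dict[str, Any]


def collate_fields(batch: List[List[Step]]) -> Dict[str, List[List[Any]]]:
    """Regroup a [B][S] batch of step dicts into {field: [B][S] nested list}."""
    if not batch or not batch[0]:
        raise ValueError("batch must contain at least one sequence with one step")
    seq_len = len(batch[0])
    if any(len(seq) != seq_len for seq in batch):
        raise ValueError("all sequences must have the same length S (no padding here)")
    fields = list(batch[0][0].keys())
    # columns[field][b][s] == batch[b][s][field]
    return {field: [[step[field] for step in seq] for seq in batch] for field in fields}
```

Each nested list could then become a tensor with the dtype from the table, for example `torch.tensor(columns["action"], dtype=torch.long)` for a `[B, S]` long tensor.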
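```python
# Batch shape: [B=1][S=1], i.e. one sequence containing a single step.
batch = [[
    {
        "action": 0,       # discrete id in [0, 3]
        "observation": 0,  # discrete id in [0, 63]
        "reward": 0.0,     # scalar reward
        "done": 0,         # discrete id in [0, 4]
    }
]]

# Inference only: one forward pass without gradient tracking.
with torch.no_grad():
    predictions, objective_data, cache = model(batch)
    # Pick an action; a temperature of 0.0 conventionally means greedy selection.
    action = model.get_action(predictions, temperature=0.0)
```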
`model()` returns `(predictions, objective_data, cache)`. `objective_data` is a
`TensorDict[B, S]` of the modality tensors extracted by the encoder; pass it to
the objectives during training. For cached step-by-step rollouts, keep the
returned `cache` and pass it back on the next call with `use_cache=True`.
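The cached rollout pattern can be sketched in plain Python. This is a minimal sketch under stated assumptions, not mouse-core code: `rollout`, `policy_step`, and `env_step` are hypothetical names. `policy_step` stands in for a call to `model(...)` with `use_cache=True` followed by `model.get_action(...)`. The keyword used to hand `cache` back to `model` is not documented here, so the sketch hides that detail behind the callable. Each call receives a `[B=1][S=1]` batch holding only the newest step.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Step = Dict[str, Any]
Batch = List[List[Step]]


def rollout(
    policy_step: Callable[[Batch, Optional[Any]], Tuple[int, Any]],
    env_step: Callable[[int], Step],
    first_step: Step,
    num_steps: int,
) -> List[Step]:
    """Hypothetical one-step-at-a-time rollout that threads a cache between calls.

    policy_step(batch, cache) -> (action, new_cache) is assumed to wrap
    model(batch, ..., use_cache=True) plus model.get_action(...).
    env_step(action) -> next step-record dict is an assumed environment hook.
    """
    cache = None  # no cache exists before the first forward pass
    step = first_step
    history = [step]
    for _ in range(num_steps):
        # Only the newest step is sent; earlier context is assumed to live in the cache.
        action, cache = policy_step([[step]], cache)
        step = env_step(action)
        history.append(step)
    return history
```

Keeping the model behind a callable means the loop itself can be exercised with stand-in policies and environments before wiring in the real checkpoint.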