---
license: other
license_name: polyform-noncommercial-1.0.0
license_link: https://polyformproject.org/licenses/noncommercial/1.0.0/
library_name: safetensors
tags:
- reinforcement-learning
- offline-rl
- mujoco
- gpt
- llama
- autoregressive
- causal-gpt-rl
---

# Causal GPT-RL

**The first decoder-only transformer (GPT) to reach expert level on Humanoid in offline RL while learning from medium-quality data (no expert demonstrations), beyond what Behavior Cloning and Decision Transformer can achieve.**

GPT-style transformers (GPT-2, Llama) run as RL policies in continuous-control MuJoCo environments. They generate actions autoregressively, the same way a language model generates tokens:

```text
action → next state → next action    (RL rollouts)
token  → next token → next token     (LLM generation)
```

The policy stays stable under its own self-generated rollouts. This allows long-horizon control without the compounding drift that has historically kept transformers from being usable as RL agents.
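The rollout analogy can be made concrete with a toy loop. This is a minimal sketch and not the `causal_gpt_rl` implementation. `ToyEnv`, `toy_policy`, and `rollout` are hypothetical. The policy is assumed to see only the most recent `context_length` states, much as an LLM sees a bounded token window. Every action it emits changes the next state it conditions on, and that feedback is where drift would compound.

```python
from collections import deque
from typing import Callable, Deque, List, Sequence, Tuple

State = Tuple[float, ...]
Action = Tuple[float, ...]


class ToyEnv:
    """Hypothetical 1-D point mass: the state is (position,), the action nudges it."""

    def __init__(self, horizon: int = 10) -> None:
        self.horizon = horizon
        self.t = 0
        self.pos = 0.0

    def reset(self) -> State:
        self.t, self.pos = 0, 1.0
        return (self.pos,)

    def step(self, action: Action) -> Tuple[State, float, bool]:
        self.pos += action[0]
        self.t += 1
        reward = -abs(self.pos)  # reward for staying near the origin
        done = self.t >= self.horizon
        return (self.pos,), reward, done


def toy_policy(states: Sequence[State]) -> Action:
    """Hypothetical causal policy: it reads past and current states only."""
    return (-0.5 * states[-1][0],)


def rollout(env: ToyEnv, policy: Callable[[Sequence[State]], Action],
            context_length: int) -> Tuple[float, List[Action]]:
    # Keep only the most recent `context_length` states, like a fixed LLM context window.
    context: Deque[State] = deque(maxlen=context_length)
    state = env.reset()
    total, actions, done = 0.0, [], False
    while not done:
        context.append(state)                   # state -> appended to the context
        action = policy(list(context))          # context -> next action ("next token")
        state, reward, done = env.step(action)  # action -> next state, fed back in
        actions.append(action)
        total += reward
    return total, actions
```

For example, `rollout(ToyEnv(horizon=4), toy_policy, context_length=16)` returns a total reward of `-0.9375` after four actions. Once the episode runs longer than the window, older states simply fall out of the deque.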

## Bundles in this repository

| Environment | Subfolder | Context length | Return (mean ± std) |
|---|---|---|---|
| Ant-v5 | `ant-v5` | 16 | 2614 ± 1515 |
| HalfCheetah-v5 | `halfcheetah-v5` | 32 | 3251 ± 1916 |
| Walker2d-v5 | `walker2d-v5` | 24 | 2345 ± 879 |
| Humanoid-v5 | `humanoid-v5` | 32 | 2371 ± 2850 |

Returns are computed over 5 episodes with `seed=0` (50 episodes for HalfCheetah-v5) and evaluated on CPU with `run_episodes`.
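The `Return (mean ± std)` column summarizes per-episode returns. The following is a minimal sketch of that aggregation. It assumes a population standard deviation (`ddof=0`), because this card does not say whether `run_episodes` uses the population or the sample estimator. `summarize_returns` and `format_row` are hypothetical helpers, and the example returns are illustrative rather than real results.

```python
import statistics
from typing import Dict, Sequence


def summarize_returns(episode_returns: Sequence[float]) -> Dict[str, float]:
    """Hypothetical helper producing keys shaped like `run_episodes` output."""
    return {
        "return_mean": statistics.fmean(episode_returns),
        # pstdev = population standard deviation (ddof=0); this is an assumption.
        "return_std": statistics.pstdev(episode_returns),
    }


def format_row(stats: Dict[str, float]) -> str:
    # Round to whole reward units, matching the table's "2614 ± 1515" style.
    return f"{stats['return_mean']:.0f} ± {stats['return_std']:.0f}"
```

For example, `format_row(summarize_returns([100.0, 200.0, 300.0]))` gives `"200 ± 82"`.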

## Quick Start

```bash
pip install "causal-gpt-rl[hub,mujoco]"
```

```python
import gymnasium as gym

from causal_gpt_rl.inference import load_runner_from_hub, run_episodes

# The environment must match the bundle's subfolder (Ant-v5 -> "ant-v5").
env = gym.make("Ant-v5")
# Download the bundle (weights, config, state normalizer) and build a runner.
runner = load_runner_from_hub(
    repo_id="ccnets/causal-gpt-rl",
    subfolder="ant-v5",
)
# Evaluate with the same settings as the bundles table (5 episodes, seed=0).
stats = run_episodes(env, runner, num_episodes=5, seed=0)
print(stats["return_mean"], stats["return_std"])
```
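The four bundles differ only in the Gymnasium environment ID and the `subfolder` argument. A small lookup built from the bundles table keeps the two in sync. `Bundle`, `BUNDLES`, and `bundle_for` are hypothetical names and are not part of `causal_gpt_rl`.

```python
from typing import Dict, NamedTuple


class Bundle(NamedTuple):
    subfolder: str       # value for `load_runner_from_hub(..., subfolder=...)`
    context_length: int  # context length listed in the bundles table


# Values copied from the "Bundles in this repository" table.
BUNDLES: Dict[str, Bundle] = {
    "Ant-v5": Bundle("ant-v5", 16),
    "HalfCheetah-v5": Bundle("halfcheetah-v5", 32),
    "Walker2d-v5": Bundle("walker2d-v5", 24),
    "Humanoid-v5": Bundle("humanoid-v5", 32),
}


def bundle_for(env_id: str) -> Bundle:
    """Return the bundle for a Gymnasium ID, failing loudly for unsupported envs."""
    try:
        return BUNDLES[env_id]
    except KeyError:
        supported = ", ".join(sorted(BUNDLES))
        raise ValueError(f"No bundle for {env_id!r}; supported: {supported}") from None
```

With the quick-start imports, `load_runner_from_hub(repo_id="ccnets/causal-gpt-rl", subfolder=bundle_for("Walker2d-v5").subfolder)` would load the Walker2d bundle. The context length is carried here for reference only, since each bundle's `config.json` already stores it.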

## Bundle contents

Each subfolder contains:

- `model.safetensors`: model state dict (weights) for inference
- `config.json`: model config, observation specs, action specs, and context length
- `state_normalizer.safetensors`: state normalization statistics
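Both `.safetensors` files use the standard safetensors layout:

- an 8-byte little-endian header length,
- a JSON header mapping tensor names to `dtype`, `shape`, and `data_offsets`,
- the raw tensor bytes.

The sketch below reads such a file with the standard library only and applies z-score normalization `(obs - mean) / (std + eps)`. Several details are assumptions about this bundle, not documented: the tensor names `mean` and `std`, the `eps` value, and the z-score form itself. Inspect the real file's header to confirm them. In practice, `safetensors.numpy.load_file` or `safetensors.torch.load_file` handles the reading.

```python
import json
import os
import struct
import tempfile
from typing import Dict, List, Sequence

# safetensors dtype tag -> (struct code, bytes per element); only floats handled here.
_DTYPES = {"F32": ("f", 4), "F64": ("d", 8)}


def read_safetensors(path: str) -> Dict[str, List[float]]:
    """Read float tensors from a .safetensors file as flat lists (stdlib only)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(header_len))
        buffer = f.read()  # data_offsets are relative to the start of this buffer
    tensors: Dict[str, List[float]] = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional string-to-string metadata entry
            continue
        code, size = _DTYPES[info["dtype"]]
        begin, end = info["data_offsets"]
        count = (end - begin) // size
        tensors[name] = list(struct.unpack(f"<{count}{code}", buffer[begin:end]))
    return tensors


def write_safetensors(path: str, tensors: Dict[str, Sequence[float]]) -> None:
    """Write 1-D float32 tensors in the safetensors layout (used for the demo)."""
    header, chunks, offset = {}, [], 0
    for name, values in tensors.items():
        raw = struct.pack(f"<{len(values)}f", *values)
        header[name] = {"dtype": "F32", "shape": [len(values)],
                        "data_offsets": [offset, offset + len(raw)]}
        chunks.append(raw)
        offset += len(raw)
    header_bytes = json.dumps(header).encode("utf-8")
    header_bytes += b" " * (-len(header_bytes) % 8)  # pad header to 8-byte alignment
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(b"".join(chunks))


def normalize(obs: Sequence[float], mean: Sequence[float], std: Sequence[float],
              eps: float = 1e-8) -> List[float]:
    """Z-score normalization; the exact form used by the bundles is an assumption."""
    if not len(obs) == len(mean) == len(std):
        raise ValueError("obs, mean and std must have the same length")
    return [(o - m) / (s + eps) for o, m, s in zip(obs, mean, std)]


if __name__ == "__main__":
    # Hypothetical tensor names "mean" and "std"; the real file may use others.
    path = os.path.join(tempfile.mkdtemp(), "state_normalizer.safetensors")
    write_safetensors(path, {"mean": [1.0, 2.0], "std": [2.0, 4.0]})
    stats = read_safetensors(path)
    print(normalize([3.0, 2.0], stats["mean"], stats["std"], eps=0.0))  # [1.0, 0.0]
```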

## Model details

Llama-style transformer decoder with 4 layers and 8 attention heads. The hidden size is 192 for Ant, HalfCheetah, and Walker2d, and 256 for Humanoid.
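As a quick check of these dimensions, each head gets `hidden_size / n_heads` dimensions: 24 for the 192-wide models and 32 for Humanoid. The sketch also counts attention projection weights. It assumes Llama-style bias-free Q/K/V/O projections, with as many key/value heads as query heads. MLP width, embeddings, and norms are not stated on this card and are left out, so the result is not a total parameter count. `DecoderShape` is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecoderShape:
    n_layers: int
    n_heads: int
    hidden_size: int

    @property
    def head_dim(self) -> int:
        # Heads split the hidden dimension evenly.
        if self.hidden_size % self.n_heads:
            raise ValueError("hidden_size must be divisible by n_heads")
        return self.hidden_size // self.n_heads

    def attention_weights(self) -> int:
        # Q, K, V and output projections, each hidden_size x hidden_size, no biases (assumed).
        return self.n_layers * 4 * self.hidden_size ** 2


small = DecoderShape(n_layers=4, n_heads=8, hidden_size=192)     # Ant / HalfCheetah / Walker2d
humanoid = DecoderShape(n_layers=4, n_heads=8, hidden_size=256)  # Humanoid
```

Under these assumptions, the attention projections hold 589,824 weights in the 192-wide models and 1,048,576 in the Humanoid model.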

## Training data

Each environment's model is trained on the [Minari](https://minari.farama.org/) `mujoco/{env}/simple-v0` and `mujoco/{env}/medium-v0` datasets combined. The expert split is not used.
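Concretely, `{env}` expands to a per-environment Minari name. The sketch assumes Minari's lowercase MuJoCo names (`ant`, `halfcheetah`, `walker2d`, `humanoid`), giving IDs such as `mujoco/ant/medium-v0`. `MINARI_ENV` and `training_dataset_ids` are hypothetical, and the loading step is left to Minari itself.

```python
from typing import List

# Assumed mapping from Gymnasium IDs to Minari's MuJoCo dataset names.
MINARI_ENV = {
    "Ant-v5": "ant",
    "HalfCheetah-v5": "halfcheetah",
    "Walker2d-v5": "walker2d",
    "Humanoid-v5": "humanoid",
}

# Only these two splits are used; the expert split is deliberately excluded.
SPLITS = ("simple-v0", "medium-v0")


def training_dataset_ids(env_id: str) -> List[str]:
    """Return the Minari dataset IDs combined for one environment."""
    env = MINARI_ENV[env_id]
    return [f"mujoco/{env}/{split}" for split in SPLITS]
```

In recent Minari releases, each ID can then be passed to `minari.load_dataset(dataset_id, download=True)`.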

## Links

- **Code:** [github.com/ccnets-team/causal-gpt-rl](https://github.com/ccnets-team/causal-gpt-rl)
- **Training logs (W&B):** [wandb.ai/junhopark/Causal GPT-RL](https://wandb.ai/junhopark/Causal%20GPT-RL)
- **Website:** [ccnets.org](https://ccnets.org)

## License

PolyForm Noncommercial License 1.0.0. For commercial use, contact the authors via [ccnets.org](https://ccnets.org).