---
license: mit
tags:
- reinforcement-learning
- offline-rl
- decision-transformer
- unity-ml-agents
- robotics
- sim-to-real
datasets:
- DecisionTransformer-Unity-Sim/DTTrajectoryData.zip
---

# Decision Transformer for Dynamic 3D Environments via Strategic Data Curation

This repository contains the official implementation and pre-trained models for the paper "Data-Centric Offline Reinforcement Learning: Strategic Data Curation via Unity ML-Agents and Decision Transformer" (submitted to Scientific Reports).

We present a data-centric approach to Offline Reinforcement Learning (Offline RL) using **Unity ML-Agents** and a **Decision Transformer (DT)**. Our experiments show that **strategic data curation**, i.e., fine-tuning on a small subset of high-quality "virtual expert" trajectories, matters more for final performance than sheer data volume.
## 🚀 Key Features

* **Sim-to-Data-to-Model:** A complete pipeline that generates synthetic data via Unity ML-Agents and trains Transformer-based control agents on it.
* **Strategic Curation:** Fine-tuning with only the top **5-10%** of high-quality trajectories significantly outperforms training on massive mixed-quality datasets.
* **Robust Generalization:** The model maintains **96-100%** success rates even in zero-shot environments with increased complexity (e.g., 20 simultaneous targets).

## 📊 Model Zoo

| Model Name | Pre-training Data | Fine-tuning Data | Description |
| :--- | :--- | :--- | :--- |
| **DT_S_100** | 100% Mixed Data | None | Baseline model trained on the full dataset without curation. |
| **DT_C_5** | None | Top 5% Expert Data | Model trained *only* on a small, high-quality subset. |
| **DT_C_10** | None | Top 10% Expert Data | Model trained *only* on a larger high-quality subset. |
| **DT_SC_5** | 100% Mixed Data | Top 5% Expert Data | Pre-trained on mixed data, fine-tuned on the top 5% curated data. |
| **DT_SC_10** | 100% Mixed Data | Top 10% Expert Data | **(Best)** Pre-trained on mixed data, fine-tuned on the top 10% curated data. Achieves a 4x improvement in firing stability. |

## 🏗️ Methodology

1. **Data Generation:** We used **Unity ML-Agents** to train a PPO (Proximal Policy Optimization) agent as a "virtual expert."
2. **Data Collection:** We collected step-wise interaction data (state, action, reward, return-to-go) from the PPO agent in a 3D projectile-interception task, using the scripts in `UnityScript/`.
3. **Offline Training:** We trained a **Decision Transformer** (Chen et al., 2021) to predict the next optimal action from the history of states and target returns. Implemented in `model_dt.py`.
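
The return-to-go (RTG) values collected in step 2 are suffix sums of the per-step rewards. A minimal sketch of this preprocessing step (the repository's actual implementation may differ; the discount factor `gamma` is an assumption):

```python
import numpy as np

def returns_to_go(rewards, gamma=1.0):
    """Compute the return-to-go (RTG) at every timestep of one trajectory.

    RTG[t] = sum of (discounted) rewards from step t to the end of the
    episode; this is the conditioning signal the Decision Transformer uses.
    """
    rtg = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

print(returns_to_go(np.array([1.0, 0.0, 2.0])))  # [3. 2. 2.]
```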

## 📈 Performance

* **Control Stability:** Improved by **3.5x** in the `DT_SC` models compared to the baseline.
* **Firing Stability:** Improved by over **4x**.
* **Success Rate:** Maintained PPO-level performance (~98%) while operating strictly offline.
* **Metrics Visualization:** Use `chart_visualize.py` to reproduce the performance plots (win rate, average steps, smoothness).
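
For intuition, a smoothness-style metric can be defined as the mean change between consecutive actions, where lower values mean smoother control. This is only an illustrative stand-in; the exact stability metrics used in the paper are produced by `chart_visualize.py` and may be defined differently:

```python
import numpy as np

def action_smoothness(actions):
    """Illustrative smoothness score: mean L2 distance between consecutive
    actions over an episode (lower = smoother control)."""
    diffs = np.diff(np.asarray(actions, dtype=np.float64), axis=0)
    return float(np.mean(np.linalg.norm(diffs, axis=-1)))

# A jittery policy scores higher (worse) than a smooth one
smooth = action_smoothness([[0.0, 0.0], [0.1, 0.0], [0.2, 0.0]])
jittery = action_smoothness([[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]])
print(smooth < jittery)  # True
```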

## 💻 Usage

The following example shows how to load a pre-trained model and run inference:
```python
import torch
import numpy as np
from model_dt import DecisionTransformer

# Configuration (must match training config)
OBS_DIM = 9
ACT_DIM = 3
HIDDEN_SIZE = 256
MAX_LEN = 1024  # Sequence (context) length

# 1. Load the pre-trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DecisionTransformer(
    obs_dim=OBS_DIM,
    act_dim=ACT_DIM,
    hidden=HIDDEN_SIZE,
    max_len=MAX_LEN,
)

# Load weights (example: DT_SC_5.pth)
model_path = "DT_SC_5.pth"
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

print(f"Loaded model from {model_path}")

# 2. Inference loop (pseudo-code example)
# Note: requires a running environment `env`
def get_action(model, states, actions, rewards, target_return, timesteps):
    # Pad all inputs to the context length (MAX_LEN) if necessary
    # ... (Padding logic here) ...

    with torch.no_grad():
        # Predict the next action from the trajectory history
        action_preds = model(
            states.unsqueeze(0),
            actions.unsqueeze(0),
            rewards.unsqueeze(0),
            timesteps.unsqueeze(0),
        )
        action_pred = action_preds[0, -1]  # Take the last action prediction
    return action_pred

# Example usage within an episode
# state = env.reset()
# target_return = torch.tensor([1.0], device=device)  # Normalized expert return
# for t in range(max_steps):
#     action = get_action(model, state_history, action_history, reward_history, target_return, t)
#     next_state, reward, done, _ = env.step(action)
#     ...
```

## 📁 File Structure

- `model_dt.py`: Decision Transformer model definition.
- `train_sequential.py`: Main training script.
- `dataset_dt.py`: Dataset loader for trajectory data.
- `chart_visualize.py`: Visualization tool for benchmark metrics.