code3939 committed · Commit 5ceea00 (verified) · Parent: 750207d

Update README.md

Files changed (1): README.md (+109 −3)
---
license: mit
tags:
- reinforcement-learning
- offline-rl
- decision-transformer
- unity-ml-agents
- robotics
- sim-to-real
datasets:
- DecisionTransformer-Unity-Sim/DTTrajectoryData.zip
---

# Decision Transformer for Dynamic 3D Environments via Strategic Data Curation

This repository contains the official implementation and pre-trained models for the paper "Data-Centric Offline Reinforcement Learning: Strategic Data Curation via Unity ML-Agents and Decision Transformer" (submitted to Scientific Reports).

We present a data-centric approach to offline reinforcement learning (offline RL) using **Unity ML-Agents** and the **Decision Transformer (DT)**. Our research demonstrates that **strategic data curation**—specifically, fine-tuning on a small subset of high-quality "virtual expert" trajectories—is more critical to performance than sheer data volume.

## 🚀 Key Features
* **Sim-to-Data-to-Model:** A complete pipeline that generates synthetic data via Unity ML-Agents and trains Transformer-based control agents on it.
* **Strategic Curation:** Fine-tuning on only **5-10%** of the data (the top-tier trajectories) significantly outperforms training on the massive mixed-quality dataset alone.
* **Robust Generalization:** The model maintains **96-100%** success rates even in zero-shot environments of increased complexity (e.g., 20 simultaneous targets).
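
The curation strategy amounts to ranking episodes by total return and keeping only the top percentile. A minimal sketch of that selection; the function name `select_top_trajectories` and the trajectory dict layout are illustrative assumptions, not the repository's actual API:

```python
import numpy as np

def select_top_trajectories(trajectories, fraction=0.05):
    """Keep the top `fraction` of trajectories ranked by episode return.

    `trajectories` is a list of dicts with a "rewards" sequence -- a
    hypothetical stand-in for the format produced by `dataset_dt.py`.
    """
    returns = np.array([float(np.sum(t["rewards"])) for t in trajectories])
    k = max(1, int(len(trajectories) * fraction))
    top_idx = np.argsort(returns)[-k:]  # indices of the k highest-return episodes
    return [trajectories[i] for i in sorted(top_idx)]

# Example: 100 toy episodes with strictly increasing returns
trajs = [{"rewards": [i, i]} for i in range(100)]
top5 = select_top_trajectories(trajs, fraction=0.05)
print(len(top5))  # 5 episodes kept
```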

## 📊 Model Zoo

| Model Name | Pre-training Data | Fine-tuning Data | Description |
| :--- | :--- | :--- | :--- |
| **DT_S_100** | 100% mixed data | None | Baseline model trained on the full dataset without curation. |
| **DT_C_5** | None | Top 5% expert data | Model trained *only* on a small, high-quality subset. |
| **DT_C_10** | None | Top 10% expert data | Model trained *only* on a larger high-quality subset. |
| **DT_SC_5** | 100% mixed data | Top 5% expert data | Pre-trained on mixed data, fine-tuned on the top 5% curated data. |
| **DT_SC_10** | 100% mixed data | Top 10% expert data | **(Best)** Pre-trained on mixed data, fine-tuned on the top 10% curated data; achieves a 4x stability improvement. |

## 🏗️ Methodology
1. **Data Generation:** We use **Unity ML-Agents** to train a PPO (Proximal Policy Optimization) agent as a "virtual expert."
2. **Data Collection:** We collect step-wise interaction data (state, action, reward, return-to-go) from the PPO agent in a 3D projectile-interception task, supported by the scripts in `UnityScript/`.
3. **Offline Training:** We train a **Decision Transformer** (Chen et al., 2021) to predict the next optimal action from the history of states and target returns; implemented in `model_dt.py`.
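
The return-to-go (RTG) signal collected in step 2 is the suffix sum of future rewards at each timestep. A minimal sketch of its computation; the function name is hypothetical, not taken from `dataset_dt.py`:

```python
import numpy as np

def returns_to_go(rewards, gamma=1.0):
    """rtg[t] = sum_{k >= t} gamma^(k - t) * rewards[k].

    With gamma = 1.0 this is the undiscounted return-to-go that
    conditions the Decision Transformer at each step.
    """
    rtg = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running  # accumulate from the end
        rtg[t] = running
    return rtg

print(returns_to_go([1.0, 0.0, 2.0]).tolist())  # [3.0, 2.0, 2.0]
```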

## 📈 Performance
* **Control Stability:** Improved by **3.5x** in the `DT_SC` models compared to the baseline.
* **Firing Stability:** Improved by over **4x**.
* **Success Rate:** Maintains PPO-level performance (~98%) while operating strictly offline.
* **Metrics Visualization:** Use `chart_visualize.py` to reproduce the performance plots (win rate, average steps, smoothness).
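
The exact smoothness metric is not defined here; one common proxy (an illustrative assumption, not necessarily what `chart_visualize.py` computes) is the mean squared successive difference of the action sequence, where lower values mean smoother control:

```python
import numpy as np

def action_smoothness(actions):
    """Mean squared successive difference of a (T, act_dim) action sequence.

    Illustrative proxy for control/firing stability; lower is smoother.
    """
    a = np.asarray(actions, dtype=np.float64)
    if len(a) < 2:
        return 0.0
    # Squared L2 norm of each step-to-step action change, averaged over steps
    return float(np.mean(np.sum(np.diff(a, axis=0) ** 2, axis=-1)))

jittery = [[0.0], [1.0], [0.0], [1.0]]  # oscillating control signal
smooth = [[0.0], [0.1], [0.2], [0.3]]   # gradual control signal
print(action_smoothness(jittery))  # 1.0
print(action_smoothness(smooth) < action_smoothness(jittery))  # True
```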

## 💻 Usage

The following example demonstrates how to load a pre-trained model and run inference:

```python
import torch
import numpy as np
from model_dt import DecisionTransformer

# Configuration (must match the training config)
OBS_DIM = 9
ACT_DIM = 3
HIDDEN_SIZE = 256
MAX_LEN = 1024  # Context (sequence) length

# 1. Load the pre-trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DecisionTransformer(
    obs_dim=OBS_DIM,
    act_dim=ACT_DIM,
    hidden=HIDDEN_SIZE,
    max_len=MAX_LEN
)

# Load weights (example: DT_SC_5.pth)
model_path = "DT_SC_5.pth"
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

print(f"Loaded model from {model_path}")

# 2. Inference loop (pseudo-code example)
# Note: requires a running environment 'env'
def get_action(model, states, actions, rewards, target_return, timesteps):
    # Pad all inputs to the context length (MAX_LEN) if necessary.
    # `target_return` seeds the returns-to-go sequence that conditions the
    # Decision Transformer; it is consumed by the elided step below.
    # ... (Padding logic here) ...

    with torch.no_grad():
        # Predict the action sequence
        action_preds = model(
            states.unsqueeze(0),
            actions.unsqueeze(0),
            rewards.unsqueeze(0),
            timesteps.unsqueeze(0)
        )
        action_pred = action_preds[0, -1]  # Take the last action prediction
    return action_pred

# Example usage within an episode
# state = env.reset()
# target_return = torch.tensor([1.0], device=device)  # Normalized expert return
# for t in range(max_steps):
#     action = get_action(model, state_history, action_history, reward_history, target_return, t)
#     next_state, reward, done, _ = env.step(action)
#     ...
```
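
The padding step elided in `get_action` typically left-pads each history tensor to the fixed context length, keeping only the most recent steps. A minimal sketch under that assumption; `left_pad` is a hypothetical helper, and a real implementation would also build an attention mask marking the padded positions:

```python
import torch

def left_pad(seq, max_len, pad_value=0.0):
    """Left-pad (or truncate) a (T, D) history tensor to (max_len, D).

    Hypothetical helper for the elided padding step: keeps the most
    recent `max_len` steps and fills older positions with `pad_value`.
    """
    seq = seq[-max_len:]  # keep only the most recent steps
    pad_rows = max_len - seq.shape[0]
    if pad_rows > 0:
        pad = torch.full((pad_rows, seq.shape[1]), pad_value, dtype=seq.dtype)
        seq = torch.cat([pad, seq], dim=0)  # padding goes on the left (oldest side)
    return seq

states = torch.ones(3, 9)       # 3 timesteps of a 9-dim observation
padded = left_pad(states, max_len=8)
print(padded.shape)             # torch.Size([8, 9])
print(padded[:5].sum().item())  # 0.0 -- the padded region
```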

## 📁 File Structure
- `model_dt.py`: Decision Transformer model definition.
- `train_sequential.py`: Main training script.
- `dataset_dt.py`: Dataset loader for trajectory data.
- `chart_visualize.py`: Visualization tool for benchmark metrics.
- `UnityScript/`: C# scripts for the Unity ML-Agents environment.