Upload 11 files
Browse files- .gitattributes +1 -0
- BPD.py +124 -0
- README.md +132 -13
- SGVLB.py +27 -0
- gif_for_readme.gif +3 -0
- layer.py +56 -0
- logger.py +74 -0
- main.py +90 -0
- network.py +106 -0
- replay_memory.py +39 -0
- teacher_buffer/tmp +0 -0
- utils.py +142 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
gif_for_readme.gif filter=lfs diff=lfs merge=lfs -text
|
BPD.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from SGVLB import SGVLB
|
| 5 |
+
from network import Net, Critic
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BPDAgent(object):
    """TD3+BC-style agent that distills a teacher dataset into a sparse
    Bayesian student policy (Bayesian Policy Distillation).

    The actor is a sparse-variational-dropout network (``Net``); its loss
    combines a TD3+BC policy-improvement term with the SGVLB distillation
    loss, whose KL weight is annealed externally via :meth:`set_kl_weight`.
    """

    def __init__(
        self,
        env,
        args,
        env_info,
        thresholds,
        datasize,
        device,
        discount,
        tau,
        noise_clip,
        policy_freq,
        h,
        num_teacher_param,
    ):
        self.args = args
        self.env = env
        self.env_info = env_info

        # Student (Bayesian) actor, its Polyak-averaged target copy, and the
        # SGVLB loss that sums the actor's per-layer KL regularizers.
        self.actor = Net(env_info['state_dim'], env_info['action_dim'], env_info['action_bound'],
                         args.student_hidden_dims, thresholds['ALPHA_THRESHOLD'], thresholds['THETA_THRESHOLD'],
                         device=device).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.sgvlb = SGVLB(self.actor, datasize, loss_type='l2', device=device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        # Twin critics (TD3) and their target copy.
        self.critic = Critic(env_info['state_dim'], env_info['action_dim']).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.discount = discount
        self.tau = tau
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.datasize = datasize
        self.h = h

        self.total_it = 0
        self.kl_weight = 0

    def set_kl_weight(self, kl_weight):
        """Set the annealed weight applied to the SGVLB KL term."""
        self.kl_weight = kl_weight
        return

    def test(self):
        """Roll out the pruned (eval-mode) actor for evaluation.

        Returns:
            Tuple ``(avg_return, max_return, min_return)`` over
            ``args.num_test_epi`` evaluation episodes.
        """
        self.actor.eval()
        with torch.no_grad():
            return_list = []
            # BUGFIX: the original used range(1, num_test_epi), which ran one
            # episode fewer than requested; run exactly num_test_epi episodes.
            for epi_cnt in range(self.args.num_test_epi):
                episode_return = 0
                done = False
                state, _ = self.env.reset()
                while not done:
                    action = self.actor(state)
                    action = action.cpu().numpy()[0]
                    next_state, reward, terminated, truncated, _ = self.env.step(action)
                    done = terminated or truncated
                    episode_return += reward
                    state = next_state
                return_list.append(episode_return)

            avg_return = sum(return_list) / len(return_list)
            max_return = max(return_list)
            min_return = min(return_list)

        return avg_return, max_return, min_return

    def train(self, transition):
        """Perform one gradient step on a mini-batch of teacher transitions.

        Args:
            transition: tuple ``(states, actions, rewards, next_states, dones)``
                of tensors already placed on the training device.
        """
        self.actor.train()

        self.total_it += 1

        states, actions, rewards, next_states, dones = transition

        with torch.no_grad():
            next_actions = (
                self.actor_target(next_states)
            ).clamp(self.env_info['action_bound'][0], self.env_info['action_bound'][1])

            # Clipped double-Q target (TD3).
            target_Q1, target_Q2 = self.critic_target(next_states, next_actions)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = rewards + (1 - dones) * self.discount * target_Q

        current_Q1, current_Q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy update (TD3): update actor and targets every
        # `policy_freq` critic steps.
        if self.total_it % self.policy_freq == 0:
            pi = self.actor(states)
            Q = self.critic.Q1(states, pi)
            lmbda = (self.h * self.datasize) / Q.abs().mean().detach()

            actor_loss = -lmbda * Q.mean() + self.sgvlb(pi, actions, self.kl_weight)  # lambda = h*|D|/avg(|Q|)

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def __del__(self):
        del self.actor
        del self.actor_target
        del self.critic
        del self.critic_target
        return
|
README.md
CHANGED
|
@@ -1,14 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h1>Bayesian Policy Distillation</h1>
|
| 3 |
+
<h3>Towards Lightweight and Fast Neural Policy Networks</h3>
|
| 4 |
+
|
| 5 |
+
<a href="https://www.python.org/">
|
| 6 |
+
<img src="https://img.shields.io/badge/Python-3.7+-blue?logo=python&style=flat-square" alt="Python Badge"/>
|
| 7 |
+
</a>
|
| 8 |
+
|
| 9 |
+
<a href="https://pytorch.org/">
|
| 10 |
+
<img src="https://img.shields.io/badge/PyTorch-1.8+-EE4C2C?logo=pytorch&style=flat-square" alt="PyTorch Badge"/>
|
| 11 |
+
</a>
|
| 12 |
+
|
| 13 |
+
<a href="https://doi.org/10.1016/j.engappai.2025.113539">
|
| 14 |
+
<img src="https://img.shields.io/badge/EAAI%202026-Published-success?style=flat-square" alt="EAAI Badge"/>
|
| 15 |
+
</a>
|
| 16 |
+
|
| 17 |
+
<a href="https://www.elsevier.com/">
|
| 18 |
+
<img src="https://img.shields.io/badge/Elsevier-Journal-orange?style=flat-square" alt="Elsevier Badge"/>
|
| 19 |
+
</a>
|
| 20 |
+
<br/><br/>
|
| 21 |
+
<img src="./gif_for_readme.gif" width="550px"/>
|
| 22 |
+
|
| 23 |
+
</div>
|
| 24 |
+
|
| 25 |
---
|
| 26 |
+
|
| 27 |
+
## Engineering Applications of Artificial Intelligence (EAAI 2026)
|
| 28 |
+
### PyTorch Implementation
|
| 29 |
+
|
| 30 |
+
This repository contains a PyTorch implementation of **Bayesian Policy Distillation (BPD)** from the paper:
|
| 31 |
+
|
| 32 |
+
> **Bayesian policy distillation: Towards lightweight and fast neural policy networks**
|
| 33 |
+
> Jangwon Kim, Yoonsu Jang, Jonghyeok Park, Yoonhee Gil, Soohee Han
|
| 34 |
+
> *Engineering Applications of Artificial Intelligence*, Volume 166, 2026
|
| 35 |
+
|
| 36 |
+
## 📄 Paper Link
|
| 37 |
+
> **DOI:** https://doi.org/10.1016/j.engappai.2025.113539
|
| 38 |
+
> **Journal:** Engineering Applications of Artificial Intelligence
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
## Bayesian Policy Distillation
|
| 43 |
+
|
| 44 |
+
BPD achieves extreme policy compression through offline reinforcement learning by:
|
| 45 |
+
1. **Bayesian Neural Networks**: Uncertainty-driven dynamic weight pruning
|
| 46 |
+
2. **Sparse Variational Dropout**: Automatic sparsity induction via KL regularization
|
| 47 |
+
3. **Offline RL Framework**: Value optimization + behavior cloning
|
| 48 |
+
```math
|
| 49 |
+
\mathcal{L}_{BPD}(\theta, \alpha) = -\lambda Q_{\psi_1}(s, \pi_\omega(s)) + \frac{|\mathcal{D}|}{M}\sum_{m=1}^{M}(\pi_{\omega_m}(s_m) - a_m)^2 + \eta \cdot D_{KL}(q(\omega|\theta,\alpha) \| p(\omega))
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
**Key Results:**
|
| 53 |
+
- **~98% compression** (1.35–2.40% sparsity) while maintaining performance
|
| 54 |
+
- **4.5× faster inference** on embedded systems
|
| 55 |
+
- Successfully deployed on real inverted pendulum with **78% inference time reduction**
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## Quick Start
|
| 60 |
+
|
| 61 |
+
### Basic Training
|
| 62 |
+
```bash
|
| 63 |
+
python main.py --env-name Hopper-v3 --level expert --random-seed 1
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Custom Configuration
|
| 67 |
+
```bash
|
| 68 |
+
python main.py \
|
| 69 |
+
--env-name Walker2d-v3 \
|
| 70 |
+
--level medium \
|
| 71 |
+
--student-hidden-dims "(128, 128)" \
|
| 72 |
+
--alpha-threshold 2 \
|
| 73 |
+
--nu 4 \
|
| 74 |
+
--h 0.5
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Available Environments
|
| 78 |
+
- `Hopper-v3`, `Walker2d-v3`, `HalfCheetah-v3`, `Ant-v3`
|
| 79 |
+
|
| 80 |
+
### Teacher Policy Levels
|
| 81 |
+
- `expert`: High-performance teacher policy
|
| 82 |
+
- `medium`: Moderate-performance teacher policy
|
| 83 |
+
|
| 84 |
+
---
|
| 85 |
+
|
| 86 |
+
## Key Hyperparameters
|
| 87 |
+
|
| 88 |
+
| Parameter | Default | Description |
|
| 89 |
+
|-----------|---------|-------------|
|
| 90 |
+
| `--student-hidden-dims` | (128, 128) | Student network hidden layer sizes |
|
| 91 |
+
| `--alpha-threshold` | 2 | Pruning threshold for log(α) (higher = less compression) |
|
| 92 |
+
| `--nu` | 4 | KL weight annealing speed |
|
| 93 |
+
| `--h` | 0.5 | Q-value loss coefficient |
|
| 94 |
+
| `--batch-size` | 256 | Mini-batch size |
|
| 95 |
+
| `--max-teaching-count` | 1000000 | Total training iterations |
|
| 96 |
+
| `--eval-freq` | 5000 | Evaluation frequency |
|
| 97 |
+
|
| 98 |
+
**Adjusting Compression:**
|
| 99 |
+
- `--alpha-threshold 3-4`: Conservative pruning
|
| 100 |
+
- `--alpha-threshold 2`: Balanced [default]
|
| 101 |
+
- `--alpha-threshold 1`: Aggressive pruning
|
| 102 |
+
|
| 103 |
+
---
|
| 104 |
+
|
| 105 |
+
## Results
|
| 106 |
+
|
| 107 |
+
### MuJoCo Benchmark (Expert Teacher)
|
| 108 |
+
|
| 109 |
+
| Environment | Teacher | BPD (Ours) | Sparsity | Compression |
|
| 110 |
+
|------------|---------|------------|----------|-------------|
|
| 111 |
+
| Ant-v3 | 5364 | 5455 | 2.40% | **41.7×** |
|
| 112 |
+
| Walker2d-v3 | 5357 | 4817 | 1.68% | **59.5×** |
|
| 113 |
+
| Hopper-v3 | 3583 | 3134 | 1.35% | **74.1×** |
|
| 114 |
+
| HalfCheetah-v3 | 11432 | 10355 | 2.21% | **45.2×** |
|
| 115 |
+
|
| 116 |
+
### Real Hardware (Inverted Pendulum)
|
| 117 |
+
- **Inference**: 1.36ms → 0.30ms (**4.5× faster**)
|
| 118 |
+
- **Memory**: 290.82KB → 4.43KB (**98.5% reduction**)
|
| 119 |
+
- **Parameters**: 72,705 → 1,108 (**65.6× compression**)
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
## Citation
|
| 123 |
+
```bibtex
|
| 124 |
+
@article{kim2026bayesian,
|
| 125 |
+
title={Bayesian policy distillation: Towards lightweight and fast neural policy networks},
|
| 126 |
+
author={Kim, Jangwon and Jang, Yoonsu and Park, Jonghyeok and Gil, Yoonhee and Han, Soohee},
|
| 127 |
+
journal={Engineering Applications of Artificial Intelligence},
|
| 128 |
+
volume={166},
|
| 129 |
+
pages={113539},
|
| 130 |
+
year={2026},
|
| 131 |
+
publisher={Elsevier}
|
| 132 |
+
}
|
| 133 |
+
```
|
SGVLB.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SGVLB(nn.Module):
    """Stochastic Gradient Variational Lower Bound loss.

    Combines a data-fit term (cross-entropy or L2, scaled by the training-set
    size) with the summed KL regularizers of every child layer of ``net``
    that exposes a ``kl_reg`` method (e.g. sparse variational dropout layers).
    """

    def __init__(self, net, train_size, loss_type='cross_entropy', device='cuda'):
        super(SGVLB, self).__init__()
        self.train_size = train_size
        self.net = net
        self.loss_type = loss_type
        self.device = device

    def forward(self, input, target, kl_weight=1.0):
        # Targets are ground truth and must not carry gradients.
        assert not target.requires_grad

        # Accumulate the KL regularizer over all variational child layers.
        total_kl = torch.FloatTensor([0.0]).to(self.device)
        for child in self.net.children():
            if hasattr(child, 'kl_reg'):
                total_kl = total_kl + child.kl_reg()

        if self.loss_type == 'cross_entropy':
            data_term = F.cross_entropy(input, target) * self.train_size
        elif self.loss_type in ['l2', 'L2']:
            data_term = ((input - target) ** 2).mean() * self.train_size
        else:
            raise NotImplementedError
        return data_term + kl_weight * total_kl
|
gif_for_readme.gif
ADDED
|
Git LFS Details
|
layer.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.nn import Parameter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LinearSVDO(nn.Module):
    """Linear layer with Sparse Variational Dropout (Molchanov et al., 2017).

    Keeps a learnable weight mean ``W`` and per-weight log standard deviation
    ``log_sigma``.  The derived log dropout rate ``log_alpha`` governs
    pruning: at eval time, weights whose log_alpha reaches ``alpha_threshold``
    are treated as zero.
    """

    def __init__(self, in_features, out_features, alpha_threshold, theta_threshold, device):
        super(LinearSVDO, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha_threshold = alpha_threshold  # prune when log_alpha exceeds this
        self.theta_threshold = theta_threshold  # minimum |W| to count as "remained"
        self.device = device

        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.log_sigma = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize: zero bias, small Gaussian weight means, tiny noise."""
        self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)
        self.log_sigma.data.fill_(-5)

    def forward(self, x):
        # log_alpha = log(sigma^2 / W^2), clamped for numerical stability.
        # NOTE: stored on self, so the pruning helpers below are only valid
        # after at least one forward pass.
        self.log_alpha = self.log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(self.W))
        self.log_alpha = torch.clamp(self.log_alpha, -10, 10)

        if self.training:
            # Training: sample noisy pre-activations (local reparameterization).
            # NOTE(review): the canonical LRT std is sqrt(linear(x^2, sigma^2));
            # here it is linear(|x|, sigma^2) — confirm against the reference
            # implementation before relying on exact output variances.
            lrt_mean = F.linear(x, self.W) + self.bias
            lrt_std = F.linear(torch.sqrt(x * x), torch.exp(2*self.log_sigma)+ 1e-8)
            eps = torch.randn_like(lrt_std)
            return lrt_mean + lrt_std * eps

        # Eval: deterministic forward with high-alpha weights zeroed out.
        out = self.W * (self.log_alpha < self.alpha_threshold).float()
        out = F.linear(x, out) + self.bias
        return out

    def get_pruned_weights(self):
        # Weight matrix with dropped (high-alpha) entries zeroed out.
        W = self.W * (self.log_alpha < self.alpha_threshold).float()
        return W

    def get_num_remained_weights(self):
        # Count weights surviving both the alpha and the magnitude thresholds.
        num = ((self.log_alpha < self.alpha_threshold) * (torch.abs(self.W) > self.theta_threshold)).sum().item()
        return num

    def kl_reg(self):
        """Approximate KL(q||p) for the log-uniform prior (Molchanov et al.)."""
        k1, k2, k3 = torch.FloatTensor([0.63576]).to(self.device), torch.FloatTensor([1.8732]).to(self.device), torch.FloatTensor([1.48695]).to(self.device)
        KL = k1 * torch.sigmoid(k2 + k3 * self.log_alpha) - 0.5 * torch.log1p(torch.exp(-self.log_alpha))
        KL = - torch.sum(KL)
        return KL
|
logger.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
from tabulate import tabulate
|
| 8 |
+
from pandas import DataFrame
|
| 9 |
+
from time import gmtime, strftime
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Logger:
    """Lightweight experiment logger.

    Accumulates scalar metrics keyed by name, pretty-prints the latest row
    via ``tabulate``, mirrors all console output to a ``.out`` file, and can
    save every metric series as a CSV under ``./logs``.
    """

    def __init__(self, env_info, fmt=None):
        self.handler = True  # True until the first table row (with header) is printed
        self.scalar_metrics = OrderedDict()
        self.fmt = fmt if fmt else dict()

        base = './logs'
        if not os.path.exists(base): os.mkdir(base)
        self.path = '%s/%s-%s' % (base, env_info['name'], env_info['seed'])

        self.logs = self.path + '.csv'
        self.output = self.path + '.out'
        self.checkpoint = self.path + '.cpt'

        def prin(*args):
            # Tee every message to the .out file and to stdout.
            str_to_write = ' '.join(map(str, args))
            with open(self.output, 'a') as f:
                f.write(str_to_write + '\n')
                f.flush()

            print(str_to_write)
            sys.stdout.flush()

        self.print = prin

    def add_scalar(self, t, key, value):
        """Append (t, value) to the series of metric `key`."""
        if key not in self.scalar_metrics:
            self.scalar_metrics[key] = []
        self.scalar_metrics[key] += [(t, value)]

    def add_dict(self, t, d):
        """Record every (key, value) pair of dict `d` at time `t`."""
        # BUGFIX: dict.iteritems() is Python 2 only and raises
        # AttributeError on Python 3; use items() instead.
        for key, value in d.items():
            self.add_scalar(t, key, value)

    def add(self, t, **args):
        """Record keyword metrics at time `t`, e.g. log.add(5, loss=0.1)."""
        for key, value in args.items():
            self.add_scalar(t, key, value)

    def iter_info(self, order=None):
        """Print the latest value of every metric as one table row.

        Args:
            order: optional explicit list of metric names to display.
        """
        names = list(self.scalar_metrics.keys())
        if order:
            names = order
        values = [self.scalar_metrics[name][-1][1] for name in names]
        # Use the most recent timestep among the displayed metrics.
        t = int(np.max([self.scalar_metrics[name][-1][0] for name in names]))
        fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.1f' for name in names]

        if self.handler:
            # First call prints the header; subsequent calls print bare rows.
            self.handler = False
            self.print(tabulate([[t] + values], ['epoch'] + names, floatfmt=fmt))
        else:
            self.print(tabulate([[t] + values], ['epoch'] + names, tablefmt='plain', floatfmt=fmt).split('\n')[1])

    def save(self, silent=False):
        """Join all metric series on the timestep index and write the CSV."""
        result = None
        for key in self.scalar_metrics.keys():
            if result is None:
                result = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t')
            else:
                df = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t')
                result = result.join(df, how='outer')
        result.to_csv(self.logs)
        if not silent:
            self.print('The log/output/model have been saved to: ' + self.path + ' + .csv/.out/.cpt')
|
main.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from BPD import BPDAgent
|
| 2 |
+
from utils import set_seed, get_learning_info, get_compression_ratio, load_buffer
|
| 3 |
+
import pickle
|
| 4 |
+
import inspect
|
| 5 |
+
import os
|
| 6 |
+
import argparse
|
| 7 |
+
import gdown
|
| 8 |
+
import time
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--env-name", default="Hopper-v3")  # OpenAI gym environment name
    parser.add_argument("--level", default="expert")  # expert or medium
    parser.add_argument("--random-seed", default=1, type=int)
    parser.add_argument("--eval-freq", default=5000, type=int)
    parser.add_argument("--max-teaching-count", default=1000000, type=int)
    parser.add_argument('--num-test-epi', default=10, type=int)
    parser.add_argument("--teacher-hidden-dims", default=(400, 300), type=tuple)
    parser.add_argument("--student-hidden-dims", default=(128, 128), type=tuple)

    parser.add_argument("--batch-size", default=256, type=int)  # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99)  # Discount factor
    parser.add_argument("--tau", default=0.005)  # Target network update rate
    parser.add_argument("--noise-clip", default=0.5)  # Range to clip target policy noise
    parser.add_argument("--policy-freq", default=2, type=int)  # Frequency of delayed policy updates

    # BPD-specific hyperparameters: h scales the Q term of the actor loss,
    # nu controls KL-weight annealing speed, alpha/theta thresholds drive pruning.
    parser.add_argument("--h", default=0.5, type=float)
    parser.add_argument("--nu", default=4, type=float)
    parser.add_argument("--theta-threshold", default=0, type=float)
    parser.add_argument("--alpha-threshold", default=2, type=float)
    parser.add_argument("--init-kl-weight", default=0, type=float)
    parser.add_argument("--kl-max-coef", default=2, type=int)
    parser.add_argument("--datasize", default=1000000, type=int)

    args = parser.parse_args()

    # MuJoCo Environment Variable & Device Setting
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # STEP 1: Make Instances & Variables
    max_avg_return = 0
    seed = set_seed(args.random_seed)
    args.random_seed = seed
    learning_info = get_learning_info(args, seed)

    agent = BPDAgent(**learning_info)
    kl_weight = args.init_kl_weight

    # STEP 2: Load Dataset (=teacher buffer)
    buffer = load_buffer(args.env_name, args.level, args.datasize)

    # STEP 3: Training
    print(f"Distilling Start! | env_name: {args.env_name} | level: {args.level} | seed: {seed}")
    time_start = time.time()
    return_list = []
    for teaching_cnt in range(1, args.max_teaching_count + 1):
        # Linearly anneal the SGVLB KL weight with slope nu/max_teaching_count,
        # capped at kl_max_coef.
        kl_weight = (args.nu / args.max_teaching_count) * teaching_cnt
        kl_weight = min(kl_weight, args.kl_max_coef)
        agent.set_kl_weight(kl_weight)
        transitions = buffer.sample(batch_size=args.batch_size)
        agent.train(transitions)

        if teaching_cnt % args.eval_freq == 0:
            avg_student_return, max_student_return, min_student_return = agent.test()
            return_list.append(avg_student_return)
            print(f"[INFO] Teaching Count: [{teaching_cnt}/{args.max_teaching_count}] | Average Student Return:"
                  f" {avg_student_return:.3f} | Max Student Return: {max_student_return:.3f} | Min Student Return:"
                  f" {min_student_return:.3f}", end='')

            # Report per-layer sparsity (fraction of surviving weights).
            # NOTE(review): get_pruned_weights() is called before the kl_reg
            # check, so this assumes every child of the actor supports it
            # (i.e. all layers are LinearSVDO) — confirm if layers change.
            for i, c in enumerate(agent.actor.children()):
                temp = (torch.abs(c.get_pruned_weights()) == 0).float().data.cpu().numpy().mean()
                if hasattr(c, 'kl_reg'):
                    print(f" | sp_{i}: {1-temp:.3f}", end='')
                del temp
            print()

    # Average of the last 10 evaluation returns.
    # NOTE(review): assumes max_teaching_count / eval_freq >= 10 evaluations.
    return_sum = 0
    for i in range(10):
        return_sum += return_list[-1 - i]
    return_avg = return_sum / 10

    time_end = time.time()
    print(f"\nDistilling Finish! | Seed: {seed} | Consumed Time (sec): {time_end - time_start}")
    print("Average Return of the Last 10 Episode: {}".format(return_avg))
    cr = get_compression_ratio(learning_info["num_teacher_param"], agent)
    print('Compression ratio (kep_w/all_w)=', cr)
    print("-----------------------------------------------------------\n")
|
network.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from layer import LinearSVDO
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Define a simple 2 layer Network
|
| 8 |
+
class Net(nn.Module):
    """Sparse variational-dropout policy network (the student actor).

    Three ``LinearSVDO`` layers with ReLU activations; the tanh output is
    rescaled from [-1, 1] onto the environment's action bounds.
    """

    def __init__(self, state_dim, action_dim, action_bound, hidden_dims, alpha_threshold, theta_threshold, device):
        super(Net, self).__init__()
        self.fc1 = LinearSVDO(state_dim, hidden_dims[0], alpha_threshold, theta_threshold, device)
        self.fc2 = LinearSVDO(hidden_dims[0], hidden_dims[1], alpha_threshold, theta_threshold, device)
        self.fc3 = LinearSVDO(hidden_dims[1], action_dim, alpha_threshold, theta_threshold, device)
        # Affine map from tanh's [-1, 1] range onto [action_low, action_high].
        self.action_rescale = torch.as_tensor((action_bound[1] - action_bound[0]) / 2., dtype=torch.float32)
        self.action_rescale_bias = torch.as_tensor((action_bound[1] + action_bound[0]) / 2., dtype=torch.float32)
        self.device = device
        self.alpha_threshold = alpha_threshold

    def _format(self, state):
        """Coerce a raw observation (e.g. numpy array) to a batched float tensor."""
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, x):
        x = self._format(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # BUGFIX: F.tanh is deprecated (and removed in recent PyTorch
        # releases); use torch.tanh instead.
        x = torch.tanh(self.fc3(x))
        x = x * self.action_rescale + self.action_rescale_bias
        return x
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Actor(nn.Module):
    """Dense deterministic actor (teacher-policy architecture).

    Two hidden ReLU layers; tanh output scaled by ``max_action``.
    """

    def __init__(self, state_dim, action_dim, student_hidden_dims, max_action):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, student_hidden_dims[0])
        self.l2 = nn.Linear(student_hidden_dims[0], student_hidden_dims[1])
        self.l3 = nn.Linear(student_hidden_dims[1], action_dim)
        # BUGFIX: the device was hard-coded to 'cuda', which makes _format
        # crash on CPU-only machines; fall back to CPU when CUDA is absent.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.max_action = max_action

    def _format(self, state):
        """Coerce a raw observation to a batched float tensor on self.device."""
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, state):
        x = self._format(state)
        a = F.relu(self.l1(x))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Critic(nn.Module):
    """Twin Q-networks (TD3): two independent critics over (state, action)."""

    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()

        # BUGFIX: the device was hard-coded to 'cuda', which makes _format
        # crash on CPU-only machines; fall back to CPU when CUDA is absent.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def _format(self, state, action):
        """Coerce raw state/action to batched float tensors on self.device."""
        x, u = state, action
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)

        if not isinstance(u, torch.Tensor):
            u = torch.tensor(u, device=self.device, dtype=torch.float32)
            u = u.unsqueeze(0)

        return x, u

    def forward(self, state, action):
        """Return the pair (Q1, Q2) of value estimates for the batch."""
        x, u = self._format(state, action)
        sa = torch.cat([x, u], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def Q1(self, state, action):
        """Return only the first critic's estimate (used for the actor loss).

        Note: unlike forward(), this expects tensor inputs (no _format call).
        """
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1
|
replay_memory.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import copy
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ReplayMemory:
    """Fixed-capacity transition buffer backed by preallocated numpy arrays.

    Stores (state, action, reward, next_state, done) rows and samples uniform
    random mini-batches as torch tensors on the configured device.
    """

    def __init__(self, state_dim, action_dim, device='cuda', capacity=5e6):
        self.capacity = int(capacity)
        self.size = 0
        self.position = 0
        # BUGFIX: the `device` argument was previously ignored and sample()
        # hard-coded 'cuda'; honor the constructor argument instead.
        self.device = device

        self.state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
        self.action_buffer = np.empty(shape=(self.capacity, action_dim), dtype=np.float32)
        self.reward_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)
        self.next_state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
        self.done_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)

    def normalize_states(self, eps=1e-3):
        """Standardize states/next-states in place using buffer-wide mean/std.

        Statistics are computed in float64 to avoid precision loss on large
        buffers, then the data is cast back to float32.
        NOTE(review): statistics cover the entire capacity, including any
        uninitialized rows — only call after the buffer is fully loaded.
        """
        mean = np.mean(copy.deepcopy(self.state_buffer).astype('float64'), axis=0)
        std = np.std(copy.deepcopy(self.state_buffer).astype('float64'), axis=0) + eps
        self.state_buffer = (self.state_buffer.astype('float64') - mean) / std
        self.next_state_buffer = (self.next_state_buffer.astype('float64') - mean) / std

        self.state_buffer = self.state_buffer.astype('float32')
        self.next_state_buffer = self.next_state_buffer.astype('float32')
        return

    def sample(self, batch_size):
        """Sample a uniform random mini-batch of `batch_size` transitions."""
        idxs = np.random.randint(0, self.size, size=batch_size)

        states = torch.FloatTensor(self.state_buffer[idxs]).to(self.device)
        actions = torch.FloatTensor(self.action_buffer[idxs]).to(self.device)
        rewards = torch.FloatTensor(self.reward_buffer[idxs]).to(self.device)
        next_states = torch.FloatTensor(self.next_state_buffer[idxs]).to(self.device)
        dones = torch.FloatTensor(self.done_buffer[idxs]).to(self.device)

        return states, actions, rewards, next_states, dones
|
| 38 |
+
|
| 39 |
+
|
teacher_buffer/tmp
ADDED
|
File without changes
|
utils.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import os
|
| 6 |
+
import inspect
|
| 7 |
+
import pickle
|
| 8 |
+
import gdown
|
| 9 |
+
from network import Actor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Linear layers get orthogonal weights and zero biases; conv layers get the
    delta-orthogonal init (all-zero kernel except an orthogonal center tap).
    Intended to be applied with ``net.apply(weight_init)``.

    Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        # BUGFIX: layers built with bias=False have m.bias is None;
        # unconditionally touching m.bias.data crashed on them.
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)  # requires square kernels
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def set_seed(random_seed):
    """Seed torch, numpy and random for reproducibility.

    A non-positive ``random_seed`` is replaced by a randomly drawn seed in
    [1, 9998]. Returns the seed actually used so it can be logged.
    """
    if random_seed <= 0:
        random_seed = np.random.randint(1, 9999)
    # (the redundant `else: random_seed = random_seed` branch was removed)

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

    return random_seed
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def make_env(env_name, seed):
    """Build a gymnasium environment and return it with a metadata dict.

    The dict carries the dimensions and (symmetric, per the first component)
    action bound that downstream networks need, plus the name and seed.
    """
    import gymnasium as gym

    env = gym.make(env_name)
    env.action_space.seed(seed)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    bound = [env.action_space.low[0], env.action_space.high[0]]

    env_info = {
        'name': env_name,
        'state_dim': obs_dim,
        'action_dim': act_dim,
        'action_bound': bound,
        'seed': seed,
    }
    return env, env_info
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_learning_info(args, seed):
    """Assemble the agent's constructor kwargs from CLI args and env metadata.

    Builds the environment, the pruning thresholds, and counts the teacher
    actor's parameters so compression ratios can be reported later.
    Raises KeyError for environments outside the four supported MuJoCo tasks.
    """
    env, env_info = make_env(args.env_name, seed)
    device = 'cuda'

    # Per-environment alpha thresholds (currently all share args.alpha_threshold,
    # but the table keeps the lookup — and the KeyError on unknown envs — intact).
    alpha_dict = {name: args.alpha_threshold
                  for name in ('HalfCheetah-v3', 'Walker2d-v3', 'Ant-v3', 'Hopper-v3')}

    thresholds = {"ALPHA_THRESHOLD": alpha_dict[args.env_name],
                  "THETA_THRESHOLD": args.theta_threshold}
    max_action = 1

    # Teacher actor is instantiated only to count its parameters.
    teacher = Actor(env_info['state_dim'], env_info['action_dim'], (400, 300), 1)
    num_teacher_param = sum(p.numel() for p in teacher.parameters())

    return {
        "env": env,
        "args": args,
        "env_info": env_info,
        "thresholds": thresholds,
        "discount": args.discount,
        "datasize": args.datasize,
        "tau": args.tau,
        "device": device,
        "num_teacher_param": num_teacher_param,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "h": args.h,
    }
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_compression_ratio(num_teacher_param, agent):
    """Return the fraction of the teacher's parameters kept by the student actor.

    Sums each child layer's surviving-weight count and divides by the teacher
    parameter count (so smaller is more compressed).
    """
    remaining = sum(layer.get_num_remained_weights()
                    for layer in agent.actor.children())
    return remaining / num_teacher_param
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def load_buffer(env_name, level, datasize):
    """Load a teacher replay buffer pickle, downloading it first if missing.

    Looks for ``teacher_buffer/[<level>_buffer]_<env_name>.pickle`` next to
    this module; when absent, fetches it from Google Drive via gdown.
    The loaded buffer's ``size`` is overridden to ``datasize``.

    Raises:
        ValueError: for an unknown ``level`` or ``env_name`` when a download
            would be required.

    NOTE(security): pickle.load executes arbitrary code from the file; only
    load buffers from a trusted source.
    """
    # Google Drive file ids, keyed by dataset level then environment name.
    file_ids = {
        'expert': {
            "Ant-v3": "10VBf3bM38bNw9WsniQvirpNjRFWp8HZO",
            "Walker2d-v3": "1ungLoqNKS4NIldZ9H2mswwGh-3Ipgy0D",
            "HalfCheetah-v3": "1wO0HwDi1GNf9d9SrDJrf9x8XMZDOTkzl",
            "Hopper-v3": "10pqCliJSM_Iyb05dxHZfYs9VlmCmPryE",
        },
        'medium': {
            "Ant-v3": "1-SKleNu6l-tY2awkx3tgVDUKbjkOaj_D",
            "Walker2d-v3": "1x6nkBBSWMRb3bENxUzcntHT1WlSNJmoh",
            "HalfCheetah-v3": "1OHkB6yVK3QcqbuJH0B_iNW_2cBnv96mR",
            "Hopper-v3": "1uqH2pgKKrhadsCXCwQWrvDvZ4ZyYFkM-",
        },
    }

    current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    file_path = os.path.join(current_dir, "teacher_buffer",
                             "[" + level + "_buffer]_" + env_name + ".pickle")

    if not os.path.isfile(file_path):
        # Validate inputs before hitting the network (the dict lookup replaces
        # the two duplicated if/elif chains of the original implementation).
        if level not in file_ids:
            raise ValueError("Invalid Level. Choose from ['expert', 'medium']")
        if env_name not in file_ids[level]:
            raise ValueError("Invalid Environment Name")
        print("Downloading the teacher buffer...")
        url = f"https://drive.google.com/uc?id={file_ids[level][env_name]}"
        gdown.download(url, file_path, quiet=False)
        print("Download Complete!")

    # BUGFIX: the original loaded the pickle twice on the cache-hit path
    # (once in the try block and again after it); load exactly once here.
    with open(file_path, "rb") as fr:
        buffer = pickle.load(fr)
    buffer.size = datasize

    return buffer
|