jangwon-kim-cocel commited on
Commit
96170c3
·
verified ·
1 Parent(s): e9da7fa

Upload 11 files

Browse files
Files changed (12) hide show
  1. .gitattributes +1 -0
  2. BPD.py +124 -0
  3. README.md +132 -13
  4. SGVLB.py +27 -0
  5. gif_for_readme.gif +3 -0
  6. layer.py +56 -0
  7. logger.py +74 -0
  8. main.py +90 -0
  9. network.py +106 -0
  10. replay_memory.py +39 -0
  11. teacher_buffer/tmp +0 -0
  12. utils.py +142 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ gif_for_readme.gif filter=lfs diff=lfs merge=lfs -text
BPD.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from SGVLB import SGVLB
5
+ from network import Net, Critic
6
+
7
+
8
class BPDAgent(object):
    """Bayesian Policy Distillation (BPD) agent.

    Distills a teacher policy, given only its replay buffer, into a sparse
    Bayesian student actor (``Net``) using a TD3-style twin critic and the
    SGVLB loss (behavior cloning + KL sparsity regularization).
    """

    def __init__(
        self,
        env,
        args,
        env_info,
        thresholds,
        datasize,
        device,
        discount,
        tau,
        noise_clip,
        policy_freq,
        h,
        num_teacher_param,
    ):
        self.args = args
        self.env = env
        self.env_info = env_info

        # Sparse variational student actor and its frozen target copy.
        self.actor = Net(env_info['state_dim'], env_info['action_dim'], env_info['action_bound'],
                         args.student_hidden_dims, thresholds['ALPHA_THRESHOLD'], thresholds['THETA_THRESHOLD'],
                         device=device).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.sgvlb = SGVLB(self.actor, datasize, loss_type='l2', device=device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        # Twin critic (clipped double-Q) and its frozen target copy.
        self.critic = Critic(env_info['state_dim'], env_info['action_dim']).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.discount = discount        # reward discount factor (gamma)
        self.tau = tau                  # Polyak averaging rate for target nets
        self.noise_clip = noise_clip    # stored for interface parity; unused in train()
        self.policy_freq = policy_freq  # delayed policy update period
        self.datasize = datasize
        self.h = h                      # coefficient of the Q-maximization term

        self.total_it = 0
        self.kl_weight = 0

    def set_kl_weight(self, kl_weight):
        """Set the (annealed) KL coefficient passed to the SGVLB loss."""
        self.kl_weight = kl_weight
        return

    def test(self):
        """Evaluate the deterministic student policy.

        Returns:
            (avg_return, max_return, min_return) over ``args.num_test_epi``
            evaluation episodes.
        """
        self.actor.eval()
        with torch.no_grad():
            return_list = []
            # BUGFIX: `range(1, n)` evaluated only n-1 episodes; run exactly n.
            for _ in range(self.args.num_test_epi):
                episode_return = 0
                done = False
                state, _ = self.env.reset()
                while not done:
                    action = self.actor(state)
                    action = action.cpu().numpy()[0]
                    next_state, reward, terminated, truncated, _ = self.env.step(action)
                    done = terminated or truncated
                    episode_return += reward
                    state = next_state
                return_list.append(episode_return)

            avg_return = sum(return_list) / len(return_list)
            max_return = max(return_list)
            min_return = min(return_list)

        return avg_return, max_return, min_return

    def train(self, transition):
        """One distillation step on a sampled batch of teacher transitions.

        Args:
            transition: tuple (states, actions, rewards, next_states, dones)
                of batched tensors already on the training device.
        """
        self.actor.train()

        self.total_it += 1

        states, actions, rewards, next_states, dones = transition

        with torch.no_grad():
            # Target action from the frozen student, clamped to the valid range.
            next_actions = (
                self.actor_target(next_states)
            ).clamp(self.env_info['action_bound'][0], self.env_info['action_bound'][1])

            # Clipped double-Q bootstrapped target.
            target_Q1, target_Q2 = self.critic_target(next_states, next_actions)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = rewards + (1 - dones) * self.discount * target_Q

        current_Q1, current_Q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed (TD3-style) actor and target-network update.
        if self.total_it % self.policy_freq == 0:
            pi = self.actor(states)
            Q = self.critic.Q1(states, pi)
            lmbda = (self.h * self.datasize) / Q.abs().mean().detach()  # lambda = h*|D|/avg(|Q|)

            # Q-maximization + SGVLB (behavior cloning + annealed KL sparsity).
            actor_loss = -lmbda * Q.mean() + self.sgvlb(pi, actions, self.kl_weight)

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models via Polyak averaging.
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def __del__(self):
        # Explicitly drop network references on agent destruction.
        del self.actor
        del self.actor_target
        del self.critic
        del self.critic_target
        return
124
+
README.md CHANGED
@@ -1,14 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- license: mit
3
- language:
4
- - en
5
- pipeline_tag: reinforcement-learning
6
- tags:
7
- - offline
8
- - policy
9
- - bayesian
10
- - distillation
11
- - offline-rl
12
- - rl
13
- - pruning
14
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>Bayesian Policy Distillation</h1>
3
+ <h3>Towards Lightweight and Fast Neural Policy Networks</h3>
4
+
5
+ <a href="https://www.python.org/">
6
+ <img src="https://img.shields.io/badge/Python-3.7+-blue?logo=python&style=flat-square" alt="Python Badge"/>
7
+ </a>
8
+ &nbsp;&nbsp;
9
+ <a href="https://pytorch.org/">
10
+ <img src="https://img.shields.io/badge/PyTorch-1.8+-EE4C2C?logo=pytorch&style=flat-square" alt="PyTorch Badge"/>
11
+ </a>
12
+ &nbsp;&nbsp;
13
+ <a href="https://doi.org/10.1016/j.engappai.2025.113539">
14
+ <img src="https://img.shields.io/badge/EAAI%202026-Published-success?style=flat-square" alt="EAAI Badge"/>
15
+ </a>
16
+ &nbsp;&nbsp;
17
+ <a href="https://www.elsevier.com/">
18
+ <img src="https://img.shields.io/badge/Elsevier-Journal-orange?style=flat-square" alt="Elsevier Badge"/>
19
+ </a>
20
+ <br/><br/>
21
+ <img src="./gif_for_readme.gif" width="550px"/>
22
+
23
+ </div>
24
+
25
  ---
26
+
27
+ ## Engineering Applications of Artificial Intelligence (EAAI 2026)
28
+ ### PyTorch Implementation
29
+
30
+ This repository contains a PyTorch implementation of **Bayesian Policy Distillation (BPD)** from the paper:
31
+
32
+ > **Bayesian policy distillation: Towards lightweight and fast neural policy networks**
33
+ > Jangwon Kim, Yoonsu Jang, Jonghyeok Park, Yoonhee Gil, Soohee Han
34
+ > *Engineering Applications of Artificial Intelligence*, Volume 166, 2026
35
+
36
+ ## 📄 Paper Link
37
+ > **DOI:** https://doi.org/10.1016/j.engappai.2025.113539
38
+ > **Journal:** Engineering Applications of Artificial Intelligence
39
+
40
+ ---
41
+
42
+ ## Bayesian Policy Distillation
43
+
44
+ BPD achieves extreme policy compression through offline reinforcement learning by:
45
+ 1. **Bayesian Neural Networks**: Uncertainty-driven dynamic weight pruning
46
+ 2. **Sparse Variational Dropout**: Automatic sparsity induction via KL regularization
47
+ 3. **Offline RL Framework**: Value optimization + behavior cloning
48
+ ```math
49
+ \mathcal{L}_{BPD}(\theta, \alpha) = -\lambda Q_{\psi_1}(s, \pi_\omega(s)) + \frac{|\mathcal{D}|}{M}\sum_{m=1}^{M}(\pi_{\omega_m}(s_m) - a_m)^2 + \eta \cdot D_{KL}(q(\omega|\theta,\alpha) \| p(\omega))
50
+ ```
51
+
52
+ **Key Results:**
53
+ - **~98% compression** (1.5-2.5% sparsity) while maintaining performance
54
+ - **4.5× faster inference** on embedded systems
55
+ - Successfully deployed on a real inverted pendulum with **78% inference time reduction**
56
+
57
+ ---
58
+
59
+ ## Quick Start
60
+
61
+ ### Basic Training
62
+ ```bash
63
+ python main.py --env-name Hopper-v3 --level expert --random-seed 1
64
+ ```
65
+
66
+ ### Custom Configuration
67
+ ```bash
68
+ python main.py \
69
+ --env-name Walker2d-v3 \
70
+ --level medium \
71
+ --student-hidden-dims "(128, 128)" \
72
+ --alpha-threshold 2 \
73
+ --nu 4 \
74
+ --h 0.5
75
+ ```
76
+
77
+ ### Available Environments
78
+ - `Hopper-v3`, `Walker2d-v3`, `HalfCheetah-v3`, `Ant-v3`
79
+
80
+ ### Teacher Policy Levels
81
+ - `expert`: High-performance teacher policy
82
+ - `medium`: Moderate-performance teacher policy
83
+
84
+ ---
85
+
86
+ ## Key Hyperparameters
87
+
88
+ | Parameter | Default | Description |
89
+ |-----------|---------|-------------|
90
+ | `--student-hidden-dims` | (128, 128) | Student network hidden layer sizes |
91
+ | `--alpha-threshold` | 2 | Pruning threshold for log(α) (higher = less compression) |
92
+ | `--nu` | 4 | KL weight annealing speed |
93
+ | `--h` | 0.5 | Q-value loss coefficient |
94
+ | `--batch-size` | 256 | Mini-batch size |
95
+ | `--max-teaching-count` | 1000000 | Total training iterations |
96
+ | `--eval-freq` | 5000 | Evaluation frequency |
97
+
98
+ **Adjusting Compression:**
99
+ - `--alpha-threshold 3-4`: Conservative pruning
100
+ - `--alpha-threshold 2`: Balanced [default]
101
+ - `--alpha-threshold 1`: Aggressive pruning
102
+
103
+ ---
104
+
105
+ ## Results
106
+
107
+ ### MuJoCo Benchmark (Expert Teacher)
108
+
109
+ | Environment | Teacher | BPD (Ours) | Sparsity | Compression |
110
+ |------------|---------|------------|----------|-------------|
111
+ | Ant-v3 | 5364 | 5455 | 2.40% | **41.7×** |
112
+ | Walker2d-v3 | 5357 | 4817 | 1.68% | **59.5×** |
113
+ | Hopper-v3 | 3583 | 3134 | 1.35% | **74.1×** |
114
+ | HalfCheetah-v3 | 11432 | 10355 | 2.21% | **45.2×** |
115
+
116
+ ### Real Hardware (Inverted Pendulum)
117
+ - **Inference**: 1.36ms → 0.30ms (**4.5× faster**)
118
+ - **Memory**: 290.82KB → 4.43KB (**98.5% reduction**)
119
+ - **Parameters**: 72,705 → 1,108 (**65.6× compression**)
120
+ ---
121
+
122
+ ## Citation
123
+ ```bibtex
124
+ @article{kim2026bayesian,
125
+ title={Bayesian policy distillation: Towards lightweight and fast neural policy networks},
126
+ author={Kim, Jangwon and Jang, Yoonsu and Park, Jonghyeok and Gil, Yoonhee and Han, Soohee},
127
+ journal={Engineering Applications of Artificial Intelligence},
128
+ volume={166},
129
+ pages={113539},
130
+ year={2026},
131
+ publisher={Elsevier}
132
+ }
133
+ ```
SGVLB.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class SGVLB(nn.Module):
    """Stochastic Gradient Variational Lower Bound loss.

    Combines a data-fit term (cross-entropy or L2) scaled by the dataset size
    with the summed KL regularizers of all child modules of ``net`` that
    expose a ``kl_reg()`` method (e.g. LinearSVDO layers).
    """

    def __init__(self, net, train_size, loss_type='cross_entropy', device='cuda'):
        super(SGVLB, self).__init__()
        self.train_size = train_size  # |D|: scales the data-fit term
        self.net = net                # network whose children provide kl_reg()
        self.loss_type = loss_type    # 'cross_entropy' or 'l2'/'L2'
        self.device = device

    def forward(self, input, target, kl_weight=1.0):
        """Return the SGVLB loss (shape (1,)) for predictions vs targets.

        Raises:
            NotImplementedError: for an unsupported ``loss_type``.
        """
        assert not target.requires_grad
        # Modernized from torch.FloatTensor([0.0]).to(device): same (1,)-shaped
        # accumulator, created directly on the target device.
        kl = torch.zeros(1, device=self.device)
        for module in self.net.children():
            if hasattr(module, 'kl_reg'):
                kl = kl + module.kl_reg()

        if self.loss_type == 'cross_entropy':
            loss = F.cross_entropy(input, target) * self.train_size + kl_weight * kl
        elif self.loss_type in ['l2', 'L2']:
            loss = ((input - target) ** 2).mean() * self.train_size + kl_weight * kl
        else:
            raise NotImplementedError
        return loss
gif_for_readme.gif ADDED

Git LFS Details

  • SHA256: 56905a7305067be4489c80f29ff08072ad62e78762069a4230c05c929f6f6c79
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
layer.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import Parameter
7
+
8
+
9
class LinearSVDO(nn.Module):
    """Linear layer with Sparse Variational Dropout (SVDO).

    Learns a Gaussian posterior N(W, sigma^2) per weight; weights whose
    log dropout-rate log(alpha) exceeds ``alpha_threshold`` are pruned at
    evaluation time.
    """

    def __init__(self, in_features, out_features, alpha_threshold, theta_threshold, device):
        super(LinearSVDO, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Prune weights with log(alpha) >= alpha_threshold at eval time.
        self.alpha_threshold = alpha_threshold
        # Additional magnitude cutoff used only when counting kept weights.
        self.theta_threshold = theta_threshold
        self.device = device

        # W: posterior mean; log_sigma: log of posterior std (per weight).
        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.log_sigma = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Zero bias, small-normal means, and very small initial sigma (e^-5)."""
        self.bias.data.zero_()
        self.W.data.normal_(0, 0.02)
        self.log_sigma.data.fill_(-5)

    def forward(self, x):
        # log(alpha) = log(sigma^2 / W^2); clamped to [-10, 10] for stability.
        self.log_alpha = self.log_sigma * 2.0 - 2.0 * torch.log(1e-16 + torch.abs(self.W))
        self.log_alpha = torch.clamp(self.log_alpha, -10, 10)

        if self.training:
            # Local reparameterization: sample the pre-activation directly.
            lrt_mean = F.linear(x, self.W) + self.bias
            # NOTE(review): the canonical sparse-VD formulation uses
            # sqrt(F.linear(x*x, exp(2*log_sigma)) + eps) for the std; here the
            # sqrt is applied to x and not to the linear output — confirm this
            # deviation is intentional before changing it.
            lrt_std = F.linear(torch.sqrt(x * x), torch.exp(2*self.log_sigma)+ 1e-8)
            eps = torch.randn_like(lrt_std)
            return lrt_mean + lrt_std * eps

        # Eval: deterministic pass through the pruned mean weights.
        out = self.W * (self.log_alpha < self.alpha_threshold).float()
        out = F.linear(x, out) + self.bias
        return out

    def get_pruned_weights(self):
        """Return the mean weights with pruned entries zeroed out."""
        W = self.W * (self.log_alpha < self.alpha_threshold).float()
        return W

    def get_num_remained_weights(self):
        """Count weights surviving both the alpha and magnitude thresholds."""
        num = ((self.log_alpha < self.alpha_threshold) * (torch.abs(self.W) > self.theta_threshold)).sum().item()
        return num

    def kl_reg(self):
        """KL divergence approximation of Molchanov et al. (constants k1-k3).

        Returns the (negated-sum) KL regularizer to be *added* to the loss.
        """
        k1, k2, k3 = torch.FloatTensor([0.63576]).to(self.device), torch.FloatTensor([1.8732]).to(self.device), torch.FloatTensor([1.48695]).to(self.device)
        KL = k1 * torch.sigmoid(k2 + k3 * self.log_alpha) - 0.5 * torch.log1p(torch.exp(-self.log_alpha))
        KL = - torch.sum(KL)
        return KL
56
+
logger.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import random
4
+ import numpy as np
5
+
6
+ from collections import OrderedDict
7
+ from tabulate import tabulate
8
+ from pandas import DataFrame
9
+ from time import gmtime, strftime
10
+
11
+
12
class Logger:
    """Lightweight experiment logger.

    Accumulates scalar metric series, pretty-prints the latest values as a
    table (via tabulate), mirrors everything printed to a ``.out`` file, and
    can dump all series joined on the step index to a ``.csv``.
    """

    def __init__(self, env_info, fmt=None):
        # True until the first table print so the header row is emitted once.
        self.handler = True
        self.scalar_metrics = OrderedDict()  # key -> list of (t, value)
        self.fmt = fmt if fmt else dict()    # optional per-key float formats

        base = './logs'
        if not os.path.exists(base):
            os.mkdir(base)
        self.path = '%s/%s-%s' % (base, env_info['name'], env_info['seed'])

        self.logs = self.path + '.csv'
        self.output = self.path + '.out'
        self.checkpoint = self.path + '.cpt'

        def prin(*args):
            # Print to stdout AND append the same line to the .out file.
            str_to_write = ' '.join(map(str, args))
            with open(self.output, 'a') as f:
                f.write(str_to_write + '\n')
                f.flush()

            print(str_to_write)
            sys.stdout.flush()

        self.print = prin

    def add_scalar(self, t, key, value):
        """Append (t, value) to the series stored under ``key``."""
        if key not in self.scalar_metrics:
            self.scalar_metrics[key] = []
        self.scalar_metrics[key] += [(t, value)]

    def add_dict(self, t, d):
        """Record every key/value pair of ``d`` at step ``t``."""
        # BUGFIX: dict.iteritems() is Python 2 only and raises AttributeError
        # on Python 3; use items().
        for key, value in d.items():
            self.add_scalar(t, key, value)

    def add(self, t, **args):
        """Record keyword-argument metrics at step ``t``."""
        for key, value in args.items():
            self.add_scalar(t, key, value)

    def iter_info(self, order=None):
        """Print the latest value of each metric as one table row.

        The header is printed only on the first call; ``order`` optionally
        restricts/reorders the columns.
        """
        names = list(self.scalar_metrics.keys())
        if order:
            names = order
        values = [self.scalar_metrics[name][-1][1] for name in names]
        t = int(np.max([self.scalar_metrics[name][-1][0] for name in names]))
        fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.1f' for name in names]

        if self.handler:
            self.handler = False
            self.print(tabulate([[t] + values], ['epoch'] + names, floatfmt=fmt))
        else:
            self.print(tabulate([[t] + values], ['epoch'] + names, tablefmt='plain', floatfmt=fmt).split('\n')[1])

    def save(self, silent=False):
        """Outer-join all metric series on ``t`` and write them to the .csv."""
        result = None
        for key in self.scalar_metrics.keys():
            if result is None:
                result = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t')
            else:
                df = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t')
                result = result.join(df, how='outer')
        result.to_csv(self.logs)
        if not silent:
            self.print('The log/output/model have been saved to: ' + self.path + ' + .csv/.out/.cpt')
main.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from BPD import BPDAgent
2
+ from utils import set_seed, get_learning_info, get_compression_ratio, load_buffer
3
+ import pickle
4
+ import inspect
5
+ import os
6
+ import argparse
7
+ import gdown
8
+ import time
9
+ import torch
10
+
11
+
12
if __name__ == "__main__":
    import ast

    def _int_tuple(text):
        """Parse CLI values like "(128, 128)" or "128, 128" into an int tuple.

        BUGFIX: the original used argparse's ``type=tuple``, which iterates
        the raw string character by character, making the README-documented
        override (e.g. --student-hidden-dims "(128, 128)") unusable.
        """
        if isinstance(text, tuple):
            return text
        parsed = ast.literal_eval(text)
        if isinstance(parsed, int):
            parsed = (parsed,)
        return tuple(int(v) for v in parsed)

    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--env-name", default="Hopper-v3")  # OpenAI gym environment name
    parser.add_argument("--level", default="expert")        # expert or medium
    parser.add_argument("--random-seed", default=1, type=int)
    parser.add_argument("--eval-freq", default=5000, type=int)
    parser.add_argument("--max-teaching-count", default=1000000, type=int)
    parser.add_argument('--num-test-epi', default=10, type=int)
    parser.add_argument("--teacher-hidden-dims", default=(400, 300), type=_int_tuple)
    parser.add_argument("--student-hidden-dims", default=(128, 128), type=_int_tuple)

    parser.add_argument("--batch-size", default=256, type=int)  # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99)             # Discount factor
    parser.add_argument("--tau", default=0.005)                 # Target network update rate
    parser.add_argument("--noise-clip", default=0.5)            # Range to clip target policy noise
    parser.add_argument("--policy-freq", default=2, type=int)   # Frequency of delayed policy updates

    parser.add_argument("--h", default=0.5, type=float)
    parser.add_argument("--nu", default=4, type=float)
    parser.add_argument("--theta-threshold", default=0, type=float)
    parser.add_argument("--alpha-threshold", default=2, type=float)
    parser.add_argument("--init-kl-weight", default=0, type=float)
    parser.add_argument("--kl-max-coef", default=2, type=int)
    parser.add_argument("--datasize", default=1000000, type=int)

    args = parser.parse_args()

    # MuJoCo Environment Variable & Device Setting
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # STEP 1: Make Instances & Variables
    max_avg_return = 0
    seed = set_seed(args.random_seed)
    args.random_seed = seed
    learning_info = get_learning_info(args, seed)

    agent = BPDAgent(**learning_info)
    kl_weight = args.init_kl_weight

    # STEP 2: Load Dataset (=teacher buffer)
    buffer = load_buffer(args.env_name, args.level, args.datasize)

    # STEP 3: Training
    print(f"Distilling Start! | env_name: {args.env_name} | level: {args.level} | seed: {seed}")
    time_start = time.time()
    return_list = []
    for teaching_cnt in range(1, args.max_teaching_count + 1):
        # Linearly anneal the KL weight (slope nu) up to kl_max_coef.
        kl_weight = (args.nu / args.max_teaching_count) * teaching_cnt
        kl_weight = min(kl_weight, args.kl_max_coef)
        agent.set_kl_weight(kl_weight)
        transitions = buffer.sample(batch_size=args.batch_size)
        agent.train(transitions)

        if teaching_cnt % args.eval_freq == 0:
            avg_student_return, max_student_return, min_student_return = agent.test()
            return_list.append(avg_student_return)
            print(f"[INFO] Teaching Count: [{teaching_cnt}/{args.max_teaching_count}] | Average Student Return:"
                  f" {avg_student_return:.3f} | Max Student Return: {max_student_return:.3f} | Min Student Return:"
                  f" {min_student_return:.3f}", end='')

            # Report per-layer sparsity of the pruned student actor.
            for i, c in enumerate(agent.actor.children()):
                if hasattr(c, 'kl_reg'):
                    # BUGFIX: only query pruned weights on SVDO layers; the
                    # original computed them before the hasattr guard.
                    zero_frac = (torch.abs(c.get_pruned_weights()) == 0).float().data.cpu().numpy().mean()
                    print(f" | sp_{i}: {1-zero_frac:.3f}", end='')
            print()

    # Average the last (up to) 10 evaluations; the original fixed "last 10"
    # indexing crashed when fewer than 10 evaluations had run.
    last_returns = return_list[-10:]
    return_avg = sum(last_returns) / len(last_returns) if last_returns else float('nan')

    time_end = time.time()
    print(f"\nDistilling Finish! | Seed: {seed} | Consumed Time (sec): {time_end - time_start}")
    print("Average Return of the Last 10 Episode: {}".format(return_avg))
    cr = get_compression_ratio(learning_info["num_teacher_param"], agent)
    print('Compression ratio (kep_w/all_w)=', cr)
    print("-----------------------------------------------------------\n")
network.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from layer import LinearSVDO
5
+
6
+
7
+ # Define a simple 2 layer Network
8
class Net(nn.Module):
    """Sparse variational student policy: three LinearSVDO layers with a tanh
    head rescaled to the environment's action bounds."""

    def __init__(self, state_dim, action_dim, action_bound, hidden_dims, alpha_threshold, theta_threshold, device):
        super(Net, self).__init__()
        self.fc1 = LinearSVDO(state_dim, hidden_dims[0], alpha_threshold, theta_threshold, device)
        self.fc2 = LinearSVDO(hidden_dims[0], hidden_dims[1], alpha_threshold, theta_threshold, device)
        self.fc3 = LinearSVDO(hidden_dims[1], action_dim, alpha_threshold, theta_threshold, device)
        # Affine map from tanh's [-1, 1] onto [action_bound[0], action_bound[1]].
        self.action_rescale = torch.as_tensor((action_bound[1] - action_bound[0]) / 2., dtype=torch.float32)
        self.action_rescale_bias = torch.as_tensor((action_bound[1] + action_bound[0]) / 2., dtype=torch.float32)
        self.device = device
        self.alpha_threshold = alpha_threshold

    def _format(self, state):
        """Promote a raw (non-tensor) state to a batched float32 tensor on self.device."""
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, x):
        x = self._format(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # MODERNIZED: F.tanh is deprecated in current PyTorch; torch.tanh is identical.
        x = torch.tanh(self.fc3(x))
        x = x * self.action_rescale + self.action_rescale_bias
        return x
33
+
34
+
35
class Actor(nn.Module):
    """Dense (teacher-style) deterministic actor: two hidden layers + tanh head.

    The ``device`` parameter generalizes the previously hard-coded 'cuda'
    (default unchanged, so existing callers keep working on GPU machines).
    """

    def __init__(self, state_dim, action_dim, student_hidden_dims, max_action, device='cuda'):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, student_hidden_dims[0])
        self.l2 = nn.Linear(student_hidden_dims[0], student_hidden_dims[1])
        self.l3 = nn.Linear(student_hidden_dims[1], action_dim)
        self.device = device  # was hard-coded 'cuda'; now configurable

        self.max_action = max_action  # scales the tanh output

    def _format(self, state):
        """Promote a raw (non-tensor) state to a batched float32 tensor on self.device."""
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def forward(self, state):
        x = self._format(state)
        a = F.relu(self.l1(x))
        a = F.relu(self.l2(a))
        return self.max_action * torch.tanh(self.l3(a))
57
+
58
+
59
class Critic(nn.Module):
    """Twin Q-network (TD3-style): two independent 256-256 Q heads.

    The ``device`` parameter generalizes the previously hard-coded 'cuda'
    (default unchanged, so existing callers keep working on GPU machines).
    """

    def __init__(self, state_dim, action_dim, device='cuda'):
        super(Critic, self).__init__()

        self.device = device  # was hard-coded 'cuda'; now configurable

        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def _format(self, state, action):
        """Promote raw (non-tensor) state/action to batched float32 tensors."""
        x, u = state, action
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)

        if not isinstance(u, torch.Tensor):
            u = torch.tensor(u, device=self.device, dtype=torch.float32)
            u = u.unsqueeze(0)

        return x, u

    def forward(self, state, action):
        """Return (Q1, Q2) for the given state-action batch."""
        x, u = self._format(state, action)
        sa = torch.cat([x, u], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def Q1(self, state, action):
        """Return only the first Q head (expects tensor inputs; no _format)."""
        sa = torch.cat([state, action], 1)

        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1
replay_memory.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import copy
4
+
5
+
6
class ReplayMemory:
    """Fixed-capacity numpy-backed transition buffer with random sampling."""

    def __init__(self, state_dim, action_dim, device='cuda', capacity=5e6):
        self.capacity = int(capacity)
        self.size = 0      # number of valid transitions stored
        self.position = 0  # next write index
        # BUGFIX: the device argument was accepted but never stored; sample()
        # hard-coded 'cuda', which broke CPU-only machines. Default unchanged.
        self.device = device

        self.state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
        self.action_buffer = np.empty(shape=(self.capacity, action_dim), dtype=np.float32)
        self.reward_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)
        self.next_state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
        self.done_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)

    def normalize_states(self, eps=1e-3):
        """Standardize state and next-state buffers in place (mean 0, std 1).

        Statistics are computed in float64 for accuracy, then cast back.
        """
        mean = np.mean(copy.deepcopy(self.state_buffer).astype('float64'), axis=0)
        std = np.std(copy.deepcopy(self.state_buffer).astype('float64'), axis=0) + eps
        self.state_buffer = (self.state_buffer.astype('float64') - mean) / std
        self.next_state_buffer = (self.next_state_buffer.astype('float64') - mean) / std

        self.state_buffer = self.state_buffer.astype('float32')
        self.next_state_buffer = self.next_state_buffer.astype('float32')
        return

    def sample(self, batch_size):
        """Uniformly sample a batch; returns float32 tensors on self.device."""
        idxs = np.random.randint(0, self.size, size=batch_size)

        states = torch.FloatTensor(self.state_buffer[idxs]).to(self.device)
        actions = torch.FloatTensor(self.action_buffer[idxs]).to(self.device)
        rewards = torch.FloatTensor(self.reward_buffer[idxs]).to(self.device)
        next_states = torch.FloatTensor(self.next_state_buffer[idxs]).to(self.device)
        dones = torch.FloatTensor(self.done_buffer[idxs]).to(self.device)

        return states, actions, rewards, next_states, dones
38
+
39
+
teacher_buffer/tmp ADDED
File without changes
utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+ import os
6
+ import inspect
7
+ import pickle
8
+ import gdown
9
+ from network import Actor
10
+
11
+
12
def weight_init(m):
    """Orthogonal initialization for Linear and (transposed) Conv2d modules.

    Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf:
        # everything zero except an orthogonal center tap of the kernel.
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        center = m.weight.size(2) // 2
        nn.init.orthogonal_(m.weight.data[:, :, center, center],
                            nn.init.calculate_gain('relu'))
27
+
28
+
29
def set_seed(random_seed):
    """Seed torch, numpy and random; return the seed actually used.

    A non-positive seed requests a random one drawn from [1, 9999).
    (Removed the original's no-op ``else: random_seed = random_seed`` branch.)
    """
    if random_seed <= 0:
        random_seed = np.random.randint(1, 9999)

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

    return random_seed
40
+
41
+
42
def make_env(env_name, seed):
    """Build a gymnasium environment with a seeded action space.

    Returns the env plus a dict of its basic specs (dims, bounds, seed).
    """
    import gymnasium as gym
    # openai gym
    env = gym.make(env_name)
    env.action_space.seed(seed)

    env_info = {
        'name': env_name,
        'state_dim': env.observation_space.shape[0],
        'action_dim': env.action_space.shape[0],
        'action_bound': [env.action_space.low[0], env.action_space.high[0]],
        'seed': seed,
    }

    return env, env_info
54
+
55
+
56
def get_learning_info(args, seed):
    """Build the environment and assemble the BPDAgent constructor kwargs."""
    env, env_info = make_env(args.env_name, seed)
    # BUGFIX: was hard-coded 'cuda' while main.py selects cuda/cpu; keep the
    # two consistent so CPU-only machines work.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The original per-env dict mapped every supported env to the same
    # args.alpha_threshold and raised KeyError for any other env name; using
    # the argument directly is identical for supported envs and generalizes.
    thresholds = {"ALPHA_THRESHOLD": args.alpha_threshold, "THETA_THRESHOLD": args.theta_threshold}
    max_action = 1

    # Teacher is a (400, 300) TD3-style actor; its parameter count is used
    # later to report the compression ratio.
    t_p = Actor(env_info['state_dim'], env_info['action_dim'], (400, 300), 1)
    num_teacher_param = sum(p2.numel() for p2 in t_p.parameters())

    kwargs = {
        "env": env,
        "args": args,
        "env_info": env_info,
        "thresholds": thresholds,
        "discount": args.discount,
        "datasize": args.datasize,
        "tau": args.tau,
        "device": device,
        "num_teacher_param": num_teacher_param,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "h": args.h,
    }
    return kwargs
84
+
85
+
86
def get_compression_ratio(num_teacher_param, agent):
    """Ratio of student-actor weights surviving pruning to teacher parameters."""
    kept = sum(layer.get_num_remained_weights() for layer in agent.actor.children())
    return kept / num_teacher_param
93
+
94
+
95
def load_buffer(env_name, level, datasize):
    """Load the teacher replay buffer, downloading it from Google Drive first
    if it is not cached locally.

    Args:
        env_name: one of the supported MuJoCo env names.
        level: 'expert' or 'medium' teacher quality.
        datasize: number of transitions to expose via buffer.size.

    Raises:
        ValueError: unknown level or env name.

    Refactored from two near-identical copy-pasted branches into a lookup
    table; also removes the redundant second pickle load that the original
    performed even when the file already existed.
    """
    # SECURITY NOTE: pickle.load on a downloaded file executes arbitrary code
    # if the file is tampered with — kept for compatibility with the published
    # buffers, but only load buffers from this trusted source.
    file_ids = {
        'expert': {
            "Ant-v3": "10VBf3bM38bNw9WsniQvirpNjRFWp8HZO",
            "Walker2d-v3": "1ungLoqNKS4NIldZ9H2mswwGh-3Ipgy0D",
            "HalfCheetah-v3": "1wO0HwDi1GNf9d9SrDJrf9x8XMZDOTkzl",
            "Hopper-v3": "10pqCliJSM_Iyb05dxHZfYs9VlmCmPryE",
        },
        'medium': {
            "Ant-v3": "1-SKleNu6l-tY2awkx3tgVDUKbjkOaj_D",
            "Walker2d-v3": "1x6nkBBSWMRb3bENxUzcntHT1WlSNJmoh",
            "HalfCheetah-v3": "1OHkB6yVK3QcqbuJH0B_iNW_2cBnv96mR",
            "Hopper-v3": "1uqH2pgKKrhadsCXCwQWrvDvZ4ZyYFkM-",
        },
    }

    current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    file_path = os.path.join(current_dir, "teacher_buffer", "[" + level + "_buffer]_" + env_name + ".pickle")

    if not os.path.exists(file_path):
        if level not in file_ids:
            raise ValueError("Invalid Level. Choose from ['expert', 'medium']")
        if env_name not in file_ids[level]:
            raise ValueError("Invalid Environment Name")

        # Download the file (messages preserved from the original expert path).
        if level == 'expert':
            print("Downloading the teacher buffer...")
        url = f"https://drive.google.com/uc?id={file_ids[level][env_name]}"
        gdown.download(url, file_path, quiet=False)
        if level == 'expert':
            print("Download Complete!")

    with open(file_path, "rb") as fr:
        buffer = pickle.load(fr)
    # Cap sampling to the first `datasize` transitions of the pickled buffer.
    buffer.size = datasize

    return buffer