ShaswatRobotics committed
Commit 6db0f6c · verified · 1 parent: f7aba9e

Delete delta-iris/src/models/utils.py

Files changed (1)
  1. delta-iris/src/models/utils.py +0 -198
delta-iris/src/models/utils.py DELETED
@@ -1,198 +0,0 @@
- from collections import OrderedDict
- import cv2
- from pathlib import Path
- import random
- import shutil
- from typing import Callable, Dict
-
- import matplotlib.pyplot as plt
- import numpy as np
- from PIL import Image
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.optim import AdamW
-
- from src.data import Episode
-
-
- def configure_optimizer(model: nn.Module, learning_rate: float, weight_decay: float, *blacklist_module_names) -> AdamW:
-     """Credits to https://github.com/karpathy/minGPT"""
-     # separate out all parameters to those that will and won't experience regularizing weight decay
-     decay = set()
-     no_decay = set()
-     whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d)
-     blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, nn.Conv2d, nn.GroupNorm)
-     for mn, m in model.named_modules():
-         for pn, p in m.named_parameters():
-             fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
-             if any([fpn.startswith(module_name) for module_name in blacklist_module_names]):
-                 no_decay.add(fpn)
-             elif 'bias' in pn:
-                 # all biases will not be decayed
-                 no_decay.add(fpn)
-             elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
-                 # weights of whitelist modules will be weight decayed
-                 decay.add(fpn)
-             elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
-                 # weights of blacklist modules will NOT be weight decayed
-                 no_decay.add(fpn)
-
-     # validate that we considered every parameter
-     param_dict = {pn: p for pn, p in model.named_parameters()}
-     inter_params = decay & no_decay
-     union_params = decay | no_decay
-     assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
-     assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
-
-     # create the pytorch optimizer object
-     optim_groups = [
-         {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
-         {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
-     ]
-     optimizer = AdamW(optim_groups, lr=learning_rate)
-
-     return optimizer
-
-
- def init_weights(module: nn.Module) -> None:
-     if isinstance(module, (nn.Linear, nn.Embedding)):
-         module.weight.data.normal_(mean=0.0, std=0.02)
-         if isinstance(module, nn.Linear) and module.bias is not None:
-             module.bias.data.zero_()
-     elif isinstance(module, nn.LayerNorm):
-         module.bias.data.zero_()
-         module.weight.data.fill_(1.0)
-
-
- def extract_state_dict(state_dict: Dict, module_name: str) -> OrderedDict:
-     return OrderedDict({k.split('.', 1)[1]: v for k, v in state_dict.items() if k.startswith(module_name)})
-
-
- def set_seed(seed: int) -> None:
-     np.random.seed(seed)
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed(seed)
-     random.seed(seed)
-
-
- @torch.no_grad()
- def compute_discounted_returns(rewards: torch.FloatTensor, gamma: float) -> torch.FloatTensor:
-     assert 0 < gamma <= 1 and rewards.ndim == 2  # (B, T)
-     gammas = gamma ** torch.arange(rewards.size(1))
-     r = rewards * gammas
-
-     return (r + r.sum(dim=1, keepdim=True) - r.cumsum(dim=1)) / gammas
-
-
- class LossWithIntermediateLosses:
-     def __init__(self, **kwargs) -> None:
-         self.loss_total = sum(kwargs.values())
-         self.intermediate_losses = {k: v.item() for k, v in kwargs.items()}
-
-
- class EpisodeDirManager:
-     def __init__(self, episode_dir: Path, max_num_episodes: int) -> None:
-         self.episode_dir = episode_dir
-         self.episode_dir.mkdir(parents=False, exist_ok=True)
-         self.max_num_episodes = max_num_episodes
-         self.best_return = float('-inf')
-
-     def save(self, episode: Episode, episode_id: int, epoch: int) -> None:
-         if self.max_num_episodes is not None and self.max_num_episodes > 0:
-             self._save(episode, episode_id, epoch)
-
-     def _save(self, episode: Episode, episode_id: int, epoch: int) -> None:
-         ep_paths = [p for p in self.episode_dir.iterdir() if p.stem.startswith('episode_')]
-         assert len(ep_paths) <= self.max_num_episodes
-         if len(ep_paths) == self.max_num_episodes:
-             to_remove = min(ep_paths, key=lambda ep_path: int(ep_path.stem.split('_')[1]))
-             to_remove.unlink()
-         torch.save(episode.__dict__, self.episode_dir / f'episode_{episode_id}_epoch_{epoch}.pt')
-
-         ep_return = episode.compute_metrics().episode_return
-         if ep_return > self.best_return:
-             self.best_return = ep_return
-             path_best_ep = [p for p in self.episode_dir.iterdir() if p.stem.startswith('best_')]
-             assert len(path_best_ep) in (0, 1)
-             if len(path_best_ep) == 1:
-                 path_best_ep[0].unlink()
-             torch.save(episode.__dict__, self.episode_dir / f'best_episode_{episode_id}_epoch_{epoch}.pt')
-
-
- class RandomHeuristic:
-     def __init__(self, num_actions):
-         self.num_actions = num_actions
-
-     def act(self, obs):
-         assert obs.ndim == 4  # (N, H, W, C)
-         n = obs.size(0)
-
-         return torch.randint(low=0, high=self.num_actions, size=(n,))
-
-
- def make_video(fname, fps, frames):
-     assert frames.ndim == 4  # (T, H, W, C)
-     _, h, w, c = frames.shape
-     assert c == 3
-
-     video = cv2.VideoWriter(str(fname), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
-     for frame in frames:
-         video.write(frame[:, :, ::-1])
-     video.release()
-
-
- def try_until_no_except(fn: Callable):
-     while True:
-         try:
-             fn()
-         except:
-             continue
-         else:
-             break
-
-
- def symlog(x: torch.Tensor) -> torch.Tensor:
-     return torch.sign(x) * torch.log(torch.abs(x) + 1)
-
-
- def symexp(x: torch.Tensor) -> torch.Tensor:
-     return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
-
-
- def two_hot(x: torch.FloatTensor, x_min: int = -20, x_max: int = 20, num_buckets: int = 255) -> torch.FloatTensor:
-     x.clamp_(x_min, x_max - 1e-5)
-     buckets = torch.linspace(x_min, x_max, num_buckets).to(x.device)
-     k = torch.searchsorted(buckets, x) - 1
-     values = torch.stack((buckets[k + 1] - x, x - buckets[k]), dim=-1) / (buckets[k + 1] - buckets[k]).unsqueeze(-1)
-     two_hots = torch.scatter(x.new_zeros(*x.size(), num_buckets), dim=-1, index=torch.stack((k, k + 1), dim=-1), src=values)
-
-     return two_hots
-
-
- def compute_softmax_over_buckets(logits: torch.FloatTensor, x_min: int = -20, x_max: int = 20, num_buckets: int = 255) -> torch.FloatTensor:
-     buckets = torch.linspace(x_min, x_max, num_buckets).to(logits.device)
-     probs = F.softmax(logits, dim=-1)
-
-     return probs @ buckets
-
-
- def plot_counts(counts: np.ndarray) -> Image:
-     fig, ax = plt.subplots(figsize=(14, 7))
-     ax.plot(counts)
-     p = Path('priorities.png')
-     fig.savefig(p)
-     plt.close(fig)
-     im = Image.open(p)
-     p.unlink()
-
-     return im
-
-
- def compute_mask_after_first_done(ends: torch.LongTensor) -> torch.BoolTensor:
-     assert ends.ndim == 2
-     first_one_index = torch.argmax(ends, dim=1)
-     mask = torch.arange(ends.size(1), device=ends.device).unsqueeze(0) <= first_one_index.unsqueeze(1)
-     mask = torch.logical_or(mask, ends.sum(dim=1, keepdim=True) == 0)
-
-     return mask
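
For context on what this deletion removes: `symlog`/`symexp` and `two_hot`/`compute_softmax_over_buckets` form two inverse pairs, in the style popularized by DreamerV3 for squashing and discretizing reward/return targets. The sketch below is a standalone restatement of the deleted `two_hot` logic, checking that decoding by expectation over the bucket centers recovers the input; the test values and tolerance are illustrative choices, not anything from the repository.

```python
import torch

# Standalone restatement of the deleted two_hot helper, for sanity-checking only.
def two_hot(x: torch.Tensor, x_min: float = -20.0, x_max: float = 20.0, num_buckets: int = 255) -> torch.Tensor:
    x = x.clamp(x_min, x_max - 1e-5)
    buckets = torch.linspace(x_min, x_max, num_buckets)
    k = torch.searchsorted(buckets, x) - 1  # index of the left bracketing bucket
    # weight the two bracketing buckets by proximity to x (weights sum to 1)
    values = torch.stack((buckets[k + 1] - x, x - buckets[k]), dim=-1) / (buckets[k + 1] - buckets[k]).unsqueeze(-1)
    return torch.scatter(x.new_zeros(*x.size(), num_buckets), -1, torch.stack((k, k + 1), dim=-1), values)

x = torch.tensor([-3.7, 0.5, 11.25])        # arbitrary values inside (x_min, x_max)
probs = two_hot(x)                           # each row has exactly two nonzeros
buckets = torch.linspace(-20.0, 20.0, 255)
recovered = probs @ buckets                  # expectation over bucket centers
assert torch.allclose(recovered, x, atol=1e-4)
```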
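Similarly, the one-liner in the deleted `compute_discounted_returns` is easy to misread. With scaled rewards r'_k = gamma^k * r_k, the suffix sum at step t equals r'_t + total - cumsum_t, and dividing by gamma^t rescales it to the return G_t = sum over k >= t of gamma^(k-t) * r_k. A quick cross-check against a naive double loop (function names and test shapes here are mine, not the repo's):

```python
import torch

# Same identity as the deleted helper: vectorized suffix sums of discounted rewards.
def discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
    gammas = gamma ** torch.arange(rewards.size(1))
    r = rewards * gammas
    return (r + r.sum(dim=1, keepdim=True) - r.cumsum(dim=1)) / gammas

# Reference implementation: G_t = sum_{k >= t} gamma^(k - t) * r_k, computed directly.
def naive_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
    out = torch.zeros_like(rewards)
    T = rewards.size(1)
    for t in range(T):
        for k in range(t, T):
            out[:, t] += gamma ** (k - t) * rewards[:, k]
    return out

rewards = torch.randn(2, 5)
assert torch.allclose(discounted_returns(rewards, 0.99), naive_returns(rewards, 0.99), atol=1e-5)
```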