# General optimizer: composable per-scope statistics driving a custom update step.
from .torch_core import *  # provides Enum, Callable, listify, np, torch used below
from dataclasses import dataclass
from torch.optim import Optimizer
import types
| __all__ = ['StatScope', 'Statistic', 'ConstStatistic', 'AvgStatistic', 'AvgSquare', 'GeneralOptimizer'] | |
# Granularity at which a statistic is tracked: one value for the whole model,
# per param-group, per layer (parameter tensor), per channel, or per weight.
StatScope = Enum('StatScope', 'Global Group Layer Channel Weight')

@dataclass
class Statistic():
    "Base class for a statistic tracked by `GeneralOptimizer`; `scope` picks the granularity."
    name:str                          # key for the state buffer and the param-group hyper-parameter
    param:float=0.9                   # e.g. for exp moving average
    scope:StatScope=StatScope.Weight
    init:float=0.                     # starting value

    @property
    def buf(self):
        # Buffer key in the optimizer state; accessed as an attribute (no call) by callers.
        return f'{self.name}_buffer'

    def new_step(self):
        "Set state when computing statistics for Global or Group"
        raise NotImplementedError

    def accumulate(self, val):
        "Add `val` to statistic"
        raise NotImplementedError

    def update(self, state, param, val=None, step=None):
        "Update state with accumulated, or `val` (if `Weight` or `Layer` scope)"
        raise NotImplementedError
@dataclass
class ConstStatistic(Statistic):
    "A statistic that always yields its hyper-parameter `param` and keeps no state buffer."
    @property
    def buf(self): return None   # no buffer: `_init_stats`/`_set_bufs` skip buf-less stats
    def new_step(self): pass
    def accumulate(self, val): pass   # accept `val` to honor the base-class signature
    def update(self, state, param, val=None, step=None): return param
@dataclass
class CounterStat(Statistic):
    "Counts optimizer steps; used e.g. to debias running averages."
    def __post_init__(self):
        # Buffer is keyed directly by `name` (not `name_buffer`); `name` is then cleared
        # so this stat contributes no hyper-parameter to the param-group defaults.
        self.init,self._buf,self.name = 0,self.name,None
    @property
    def buf(self): return self._buf
    def new_step(self): pass
    def accumulate(self, val): pass
    def update(self, state, param, val=None, step=None): return state + 1
@dataclass
class AvgStatistic(Statistic):
    "Running (optionally decayed/debiased) average of the gradient at any scope."
    decay:bool=False    # True -> `state*param + val*(1-param)`; False -> `state*param + val`
    debias:bool=False   # divide by `1 - param**step` (Adam-style); needs a `step` counter

    def new_step(self): self.val,self.count = 0.,0

    def accumulate(self, val):
        # Global/Group scope: values are accumulated across params, averaged in `update`.
        self.count += 1
        self.val += self._get_val1(val)

    def _get_val1(self, val): return val.mean()

    # NOTE(review): `add_(scalar, tensor)` is the legacy torch signature
    # (alpha first); kept as-is to match the torch version this file targets.
    def _get_val2(self, state, val, param): return state.add_(1-param, val) if self.decay else state.add_(val)

    def _get_val3(self, state, val, param):
        v = val.view(val.size(0), -1).mean(1)   # one mean per output channel
        return state.add_(1-param, v) if self.decay else state.add_(v)

    def update(self, state, param, val=None, step=None):
        if self.scope == StatScope.Weight:
            # `state` is a tensor
            res = self._get_val2(state.mul_(param), val, param)
        elif self.scope == StatScope.Channel:
            # `state` is a tensor of size n_channels
            res = self._get_val3(state.mul_(param), val, param)
        # For everything else, `state` is a scalar
        elif self.scope == StatScope.Layer: res = state*param + self._get_val1(val) * (1-param if self.decay else 1.)
        elif self.count != 0: res = state*param + self.val/self.count * (1-param if self.decay else 1.)
        else: return state   # nothing accumulated this step: keep the old value
        if self.debias and step is not None: res /= (1 - param ** step)
        return res
class AvgSquare(AvgStatistic):
    "Running average of squared values; like `AvgStatistic` but `decay` defaults to True."
    def __init__(self, name:str, param:float=0.9, scope=StatScope.Weight, init:float=0., decay:bool=True, debias:bool=False):
        super().__init__(name, param=param, scope=scope, init=init, decay=decay, debias=debias)

    def _get_val1(self, val):
        # Mean of the squares, via the squared L2 norm.
        return torch.norm(val).pow(2)/val.numel()

    def _get_val2(self, state, val, param):
        if self.decay: return state.addcmul_(1-param, val, val)
        return state.addcmul_(val, val)

    def _get_val3(self, state, val, param):
        per_channel = val.view(val.size(0), -1).mean(1)
        if self.decay: return state.addcmul_(1-param, per_channel, per_channel)
        return state.addcmul_(per_channel, per_channel)
class GeneralOptimizer(Optimizer):
    "An `Optimizer` whose behavior is assembled from a list of `Statistic` objects."
    def __init__(self, params, stats=None, on_step:Callable=None):
        # Each named stat's hyper-parameter becomes a per-group default (e.g. betas).
        defaults = {s.name:s.param for s in listify(stats) if s.name is not None}
        super().__init__(params, defaults)
        self.global_stats,self.group_stats,self.layer_stats,self.channel_stats,self.weight_stats = self._split_stats(stats)
        self.init_stats()
        if on_step is not None: self.on_step = types.MethodType(on_step, self)

    def step(self, closure=None):
        "Refresh all statistics, then apply `on_step` to every param that has a grad."
        loss = None
        # Honor the standard `Optimizer.step` contract: re-evaluate the model if asked.
        if closure is not None: loss = closure()
        self.update_stats()
        for i,pg in enumerate(self.param_groups):
            for p in pg['params']:
                if p.grad is not None: self.on_step(p, pg, i)
        return loss

    def on_step(self, p, group, group_idx):
        "Default step: plain SGD (legacy `add_(scalar, tensor)` torch signature)."
        p.data.add_(-group['lr'], p.grad.data)

    def _split_stats(self, stats):
        "Split `stats` by scope; prepend a step `CounterStat` wherever debiasing is needed."
        splits = [[stat for stat in listify(stats) if stat.scope==scope] for scope in StatScope]
        # Layer/Channel/Weight stats all share the per-param state dict, so one
        # Layer-scoped counter serves all three; it must be inserted into a real
        # split (`splits[2]`), not into the throwaway concatenation, and at index 0
        # so it is updated before the stats that read `step`.
        checked = [splits[0], splits[1], splits[2]+splits[3]+splits[4]]
        targets = [splits[0], splits[1], splits[2]]
        for check,target,scope in zip(checked, targets, StatScope):
            if any(getattr(stat, 'debias', False) for stat in check):
                target.insert(0, CounterStat('step', scope=scope))
        return splits

    def _init_stats(self, stats, data=None):
        "Initial buffer dict for `stats`: scalars, or tensors shaped like `data` if given."
        return {stat.buf: stat.init if data is None
                else torch.zeros_like(data) + stat.init for stat in stats if stat.buf is not None}

    def init_stats(self):
        "Allocate state buffers at every scope (global, per-group, per-param)."
        self.state['global'] = self._init_stats(self.global_stats)
        for i,pg in enumerate(self.param_groups):
            self.state[f'group{i}'] = self._init_stats(self.group_stats)
            for p in pg['params']:
                self.state[p] = self._init_stats(self.layer_stats)
                self.state[p].update(self._init_stats(self.channel_stats, p.data.view(p.data.size(0), -1).mean(1)))
                self.state[p].update(self._init_stats(self.weight_stats, p.data))

    def _set_bufs(self, p, stats, pg, val=None):
        "Run `update` on every buffered stat in `stats`, storing results in `self.state[p]`."
        d = self.state[p]
        for stat in stats:
            # `pg.get`: stats with `name=None` (e.g. `CounterStat`) have no group hyper-param.
            if stat.buf is not None: d[stat.buf] = stat.update(d[stat.buf], pg.get(stat.name), val=val, step=d.get('step', None))

    def update_stats(self):
        "Accumulate/refresh every statistic from the current gradients."
        for stat in self.global_stats: stat.new_step()
        for i,pg in enumerate(self.param_groups):
            for stat in self.group_stats: stat.new_step()
            for p in pg['params']:
                if p.grad is not None:
                    for stat in self.global_stats + self.group_stats: stat.accumulate(p.grad.data)
                    self._set_bufs(p, self.layer_stats+self.channel_stats+self.weight_stats, pg, p.grad.data)
            self._set_bufs(f'group{i}', self.group_stats, pg)
        self._set_bufs('global', self.global_stats, self.param_groups[0])