| "Callbacks provides extensibility to the `basic_train` loop. See `train` for examples of custom callbacks." | |
```python
from .basic_data import *
from .torch_core import *
import torch.distributed as dist

__all__ = ['AverageMetric', 'Callback', 'CallbackHandler', 'OptimWrapper', 'SmoothenValue', 'Scheduler', 'annealing_cos',
           'CallbackList', 'annealing_exp', 'annealing_linear', 'annealing_no', 'annealing_poly']
```
```python
class OptimWrapper():
    "Basic wrapper around `opt` to simplify hyper-parameters changes."
    def __init__(self, opt:optim.Optimizer, wd:Floats=0., true_wd:bool=False, bn_wd:bool=True):
        assert not isinstance(opt, OptimWrapper)
        self.opt,self.true_wd,self.bn_wd = opt,true_wd,bn_wd
        self.opt_keys = list(self.opt.param_groups[0].keys())
        self.opt_keys.remove('params')
        self.read_defaults()
        self.wd = wd

    @classmethod
    def create(cls, opt_func:Union[type,Callable], lr:Union[float,Tuple,List], layer_groups:ModuleList, wd:Floats=0.,
               true_wd:bool=False, bn_wd:bool=True)->optim.Optimizer:
        "Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`."
        split_params = split_no_wd_params(layer_groups)
        opt = opt_func([{'params': p, 'lr':0} for p in split_params])
        opt = cls(opt, wd=wd, true_wd=true_wd, bn_wd=bn_wd)
        opt.lr,opt.opt_func = listify(lr, layer_groups),opt_func
        return opt

    def new(self, layer_groups:Collection[nn.Module], split_no_wd:bool=True):
        "Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters."
        opt_func = getattr(self, 'opt_func', self.opt.__class__)
        res = self.create(opt_func, self.lr, layer_groups, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
        res.mom,res.beta = self.mom,self.beta
        return res

    def new_with_params(self, param_groups:Collection[Collection[nn.Parameter]]):
        "Create a new `OptimWrapper` from `self` with another `param_groups` but the same hyper-parameters."
        opt_func = getattr(self, 'opt_func', self.opt.__class__)
        opt = opt_func([{'params': p, 'lr':0} for p in param_groups])
        opt = self.__class__(opt, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
        opt.lr,opt.opt_func,opt.mom,opt.beta = self.lr,opt_func,self.mom,self.beta
        return opt

    def __repr__(self)->str:
        return f'OptimWrapper over {repr(self.opt)}.\nTrue weight decay: {self.true_wd}'

    #Pytorch optimizer methods
    def step(self)->None:
        "Set weight decay and step optimizer."
        # weight decay outside of optimizer step (AdamW)
        if self.true_wd:
            for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
                for p in pg1['params']: p.data.mul_(1 - wd*lr)
                if self.bn_wd:
                    for p in pg2['params']: p.data.mul_(1 - wd*lr)
            self.set_val('weight_decay', listify(0, self._wd))
        self.opt.step()

    def zero_grad(self)->None:
        "Clear optimizer gradients."
        self.opt.zero_grad()

    #Passthrough to the inner opt.
    def __getattr__(self, k:str)->Any: return getattr(self.opt, k, None)
    def __setstate__(self,data:Any): self.__dict__.update(data)

    def clear(self):
        "Reset the state of the inner optimizer."
        sd = self.state_dict()
        sd['state'] = {}
        self.load_state_dict(sd)

    def n_params(self): return sum([len(pg['params']) for pg in self.opt.param_groups])

    #Hyperparameters as properties
    @property
    def lr(self)->float: return self._lr[-1]
    @lr.setter
    def lr(self, val:float)->None:
        self._lr = self.set_val('lr', listify(val, self._lr))

    @property
    def mom(self)->float: return self._mom[-1]
    @mom.setter
    def mom(self, val:float)->None:
        if 'momentum' in self.opt_keys: self.set_val('momentum', listify(val, self._mom))
        elif 'betas' in self.opt_keys:  self.set_val('betas', (listify(val, self._mom), self._beta))
        self._mom = listify(val, self._mom)

    @property
    def beta(self)->float: return None if self._beta is None else self._beta[-1]
    @beta.setter
    def beta(self, val:float)->None:
        "Set beta (or alpha as makes sense for given optimizer)."
        if val is None: return
        if 'betas' in self.opt_keys:   self.set_val('betas', (self._mom, listify(val, self._beta)))
        elif 'alpha' in self.opt_keys: self.set_val('alpha', listify(val, self._beta))
        self._beta = listify(val, self._beta)

    @property
    def wd(self)->float: return self._wd[-1]
    @wd.setter
    def wd(self, val:float)->None:
        "Set weight decay."
        if not self.true_wd: self.set_val('weight_decay', listify(val, self._wd), bn_groups=self.bn_wd)
        self._wd = listify(val, self._wd)

    #Helper functions
    def read_defaults(self)->None:
        "Read the values inside the optimizer for the hyper-parameters."
        self._beta = None
        if 'lr' in self.opt_keys: self._lr = self.read_val('lr')
        if 'momentum' in self.opt_keys: self._mom = self.read_val('momentum')
        if 'alpha' in self.opt_keys: self._beta = self.read_val('alpha')
        if 'betas' in self.opt_keys: self._mom,self._beta = self.read_val('betas')
        if 'weight_decay' in self.opt_keys: self._wd = self.read_val('weight_decay')
        reserved_names = ['params', 'lr', 'momentum', 'alpha', 'betas', 'weight_decay']
        stat_names = [n for n in self.opt_keys if n not in reserved_names]
        self._stats = {n:self.read_val(n) for n in stat_names}

    def get_stat(self, name:str)->float:
        if name in ['lr', 'mom', 'beta', 'wd']: return getattr(self, name)
        else: return self._stats[name][-1]

    def set_stat(self, name:str, value:Union[float, Collection[float]])->None:
        if name in ['lr', 'mom', 'beta', 'wd']: setattr(self, name, value)
        else:
            val = listify(value, self._stats[name])
            self.set_val(name, val)
            self._stats[name] = val

    def set_val(self, key:str, val:Any, bn_groups:bool=True)->Any:
        "Set `val` inside the optimizer dictionary at `key`."
        if is_tuple(val): val = [(v1,v2) for v1,v2 in zip(*val)]
        for v,pg1,pg2 in zip(val,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
            pg1[key] = v
            if bn_groups: pg2[key] = v
        return val

    def read_val(self, key:str) -> Union[List[float],Tuple[List[float],List[float]]]:
        "Read a hyperparameter `key` in the optimizer dictionary."
        val = [pg[key] for pg in self.opt.param_groups[::2]]
        if is_tuple(val[0]): val = [o[0] for o in val], [o[1] for o in val]
        return val

    def get_state(self):
        "Return the inner state minus the layer groups."
        return {'opt_state':self.opt.state_dict(), 'lr':self._lr, 'wd':self._wd, 'beta':self._beta, 'mom':self._mom,
                'opt_func':self.opt_func, 'true_wd':self.true_wd, 'bn_wd':self.bn_wd}

    @classmethod
    def load_with_state_and_layer_group(cls, state:dict, layer_groups:Collection[nn.Module]):
        res = cls.create(state['opt_func'], state['lr'], layer_groups, wd=state['wd'], true_wd=state['true_wd'],
                         bn_wd=state['bn_wd'])
        res._mom,res._beta = state['mom'],state['beta']
        res.load_state_dict(state['opt_state'])
        return res
```
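For orientation, here is a minimal sketch (not from the library docs) of how `OptimWrapper.create` is typically driven: the layer groups are split into weight-decay / no-weight-decay parameter groups, the inner PyTorch optimizer is built on them, and `lr`, `mom`, `beta` and `wd` are then set through the properties, which broadcast across groups. The model, layer groups and hyper-parameter values below are invented for illustration, and it assumes a fastai v1 environment where `split_no_wd_params` and `listify` are available.

```python
# Illustrative sketch only: two hand-made layer groups, a wrapped SGD, and
# hyper-parameters driven through the properties above.
import torch
from torch import nn, optim
from fastai.callback import OptimWrapper  # assumes fastai v1 is installed

layer_groups = [nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10)),
                nn.Sequential(nn.Linear(10, 2))]
opt = OptimWrapper.create(optim.SGD, [1e-3, 1e-2], layer_groups, wd=1e-2, true_wd=True)

print(opt.lr)    # 0.01 -- the property reports the last layer group's value
opt.lr  = 1e-4   # a scalar is broadcast to every layer group
opt.mom = 0.9    # mapped onto SGD's `momentum` key

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(layer_groups[1](layer_groups[0](x)), y)
loss.backward()
opt.step()       # decays weights in place (true_wd), then steps the inner SGD
opt.zero_grad()
```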
```python
class Callback():
    "Base class for callbacks that want to record values, dynamically change learner params, etc."
    _order=0
    def on_train_begin(self, **kwargs:Any)->None:
        "To initialize constants in the callback."
        pass
    def on_epoch_begin(self, **kwargs:Any)->None:
        "At the beginning of each epoch."
        pass
    def on_batch_begin(self, **kwargs:Any)->None:
        "Set HP before the output and loss are computed."
        pass
    def on_loss_begin(self, **kwargs:Any)->None:
        "Called after forward pass but before loss has been computed."
        pass
    def on_backward_begin(self, **kwargs:Any)->None:
        "Called after the forward pass and the loss has been computed, but before backprop."
        pass
    def on_backward_end(self, **kwargs:Any)->None:
        "Called after backprop but before optimizer step. Useful for true weight decay in AdamW."
        pass
    def on_step_end(self, **kwargs:Any)->None:
        "Called after the step of the optimizer but before the gradients are zeroed."
        pass
    def on_batch_end(self, **kwargs:Any)->None:
        "Called at the end of the batch."
        pass
    def on_epoch_end(self, **kwargs:Any)->None:
        "Called at the end of an epoch."
        pass
    def on_train_end(self, **kwargs:Any)->None:
        "Useful for cleaning up things and saving files/models."
        pass
    def jump_to_epoch(self, epoch)->None:
        "To resume training at `epoch` directly."
        pass

    def get_state(self, minimal:bool=True):
        "Return the inner state of the `Callback`, `minimal` or not."
        to_remove = ['exclude', 'not_min'] + getattr(self, 'exclude', []).copy()
        if minimal: to_remove += getattr(self, 'not_min', []).copy()
        return {k:v for k,v in self.__dict__.items() if k not in to_remove}

    def __repr__(self):
        attrs = func_args(self.__init__)
        to_remove = getattr(self, 'exclude', [])
        list_repr = [self.__class__.__name__] + [f'{k}: {getattr(self, k)}' for k in attrs if k != 'self' and k not in to_remove]
        return '\n'.join(list_repr)
```
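As a hedged illustration of the callback contract (not an example from the library source): a callback can change the shared training state by returning a dict of updates from any of its hooks, which `CallbackHandler._call_and_update` below merges back into `state_dict`. The `StopAfterNIters` name and its threshold are invented for this sketch; the `iteration` kwarg and the `stop_training`/`stop_epoch` keys are supplied by the handler.

```python
# Hypothetical custom callback: ask the training loop to stop after `max_iters`
# mini-batches by returning state updates from on_batch_end.
class StopAfterNIters(Callback):
    _order = 0
    def __init__(self, max_iters:int=100): self.max_iters = max_iters
    def on_batch_end(self, iteration:int, **kwargs)->dict:
        "Stop the current epoch and the whole run once `max_iters` batches have been seen."
        if iteration >= self.max_iters: return {'stop_epoch': True, 'stop_training': True}
```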
```python
class SmoothenValue():
    "Create a smooth moving average for a value (loss, etc) using `beta`."
    def __init__(self, beta:float):
        self.beta,self.n,self.mov_avg = beta,0,0

    def add_value(self, val:float)->None:
        "Add `val` to calculate updated smoothed value."
        self.n += 1
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)
```
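A quick worked check of the bias correction (illustrative only): the raw exponential moving average is divided by `1 - beta**n` so early values are not dragged toward the zero initialisation.

```python
# Sketch: three values through the smoother with beta = 0.9.
sv = SmoothenValue(beta=0.9)
for v in [1.0, 2.0, 3.0]: sv.add_value(v)
print(round(sv.smooth, 4))   # 2.0701 = raw EMA 0.561 / (1 - 0.9**3)
```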
```python
CallbackList = Collection[Callback]

def _get_init_state(): return {'epoch':0, 'iteration':0, 'num_batch':0, 'skip_validate': False}

@dataclass
class CallbackHandler():
    "Manage all of the registered `callbacks` and `metrics`, smoothing loss by momentum `beta`."
    callbacks:CallbackList=None
    metrics:CallbackList=None
    beta:float=0.98

    def __post_init__(self)->None:
        "Initialize smoother and learning stats."
        self.callbacks = ifnone(self.callbacks, [])
        self.metrics = ifnone(self.metrics, [])
        self.metrics = [(met if isinstance(met, Callback) else AverageMetric(met)) for met in self.metrics]
        self.callbacks = sorted(self.callbacks, key=lambda o: getattr(o, '_order', 0))
        self.smoothener = SmoothenValue(self.beta)
        self.state_dict:Dict[str,Union[int,float,Tensor]]=_get_init_state()

    def _call_and_update(self, cb, cb_name, **kwargs)->None:
        "Call `cb_name` on `cb` and update the inner state."
        new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
        for k,v in new.items():
            if k not in self.state_dict:
                raise Exception(f"{k} isn't a valid key in the state of the callbacks.")
            else: self.state_dict[k] = v

    def __call__(self, cb_name, call_mets=True, **kwargs)->None:
        "Call through to all of the `CallbackHandler` functions."
        if call_mets:
            for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
        for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)

    def set_dl(self, dl:DataLoader):
        "Set the current `dl` used."
        if hasattr(self, 'cb_dl'): self.callbacks.remove(self.cb_dl)
        if isinstance(dl.dataset, Callback):
            self.callbacks.append(dl.dataset)
            self.cb_dl = dl.dataset

    def on_train_begin(self, epochs:int, pbar:PBar, metrics:MetricFuncList)->None:
        "About to start learning."
        self.state_dict = _get_init_state()
        self.state_dict.update(dict(n_epochs=epochs, pbar=pbar, metrics=metrics))
        names = [(met.name if hasattr(met, 'name') else camel2snake(met.__class__.__name__)) for met in self.metrics]
        self('train_begin', metrics_names=names)
        if self.state_dict['epoch'] != 0:
            self.state_dict['pbar'].first_bar.total -= self.state_dict['epoch']
            for cb in self.callbacks: cb.jump_to_epoch(self.state_dict['epoch'])

    def on_epoch_begin(self)->None:
        "Handle new epoch."
        self.state_dict['num_batch'],self.state_dict['stop_training'] = 0,False
        self('epoch_begin')
    def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->Tuple[Any,Any]:
        "Handle new batch `xb`,`yb` in `train` or validation."
        self.state_dict.update(dict(last_input=xb, last_target=yb, train=train,
                                    stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
        self('batch_begin', call_mets = not self.state_dict['train'])
        return self.state_dict['last_input'], self.state_dict['last_target']
    def on_loss_begin(self, out:Tensor)->Any:
        "Handle start of loss calculation with model output `out`."
        self.state_dict['last_output'] = out
        self('loss_begin', call_mets=False)
        return self.state_dict['last_output']

    def on_backward_begin(self, loss:Tensor)->Tuple[Any,Any]:
        "Handle gradient calculation on `loss`."
        self.smoothener.add_value(loss.detach().cpu())
        self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth
        self('backward_begin', call_mets=False)
        return self.state_dict['last_loss'], self.state_dict['skip_bwd']

    def on_backward_end(self)->Any:
        "Handle end of gradient calculation."
        self('backward_end', call_mets=False)
        return self.state_dict['skip_step']

    def on_step_end(self)->Any:
        "Handle end of optimization step."
        self('step_end', call_mets=False)
        return self.state_dict['skip_zero']

    def on_batch_end(self, loss:Tensor)->Any:
        "Handle end of processing one batch with `loss`."
        self.state_dict['last_loss'] = loss
        self('batch_end', call_mets = not self.state_dict['train'])
        if self.state_dict['train']:
            self.state_dict['iteration'] += 1
            self.state_dict['num_batch'] += 1
        return self.state_dict['stop_epoch']

    def on_epoch_end(self, val_loss:Tensor)->bool:
        "Epoch is done, process `val_loss`."
        self.state_dict['last_metrics'] = [val_loss] if val_loss is not None else [None]
        self('epoch_end', call_mets = val_loss is not None)
        self.state_dict['epoch'] += 1
        return self.state_dict['stop_training']

    def on_train_end(self, exception:Union[bool,Exception])->None:
        "Handle end of training, `exception` is an `Exception` or False if no exceptions during training."
        self('train_end', exception=exception)

    @property
    def skip_validate(self): return self.state_dict['skip_validate']
```
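To make the dispatch order concrete, here is a hedged sketch of a hand-rolled loop driving the handler the way `basic_train` does. The toy model and data, and the `StopAfterNIters` callback from the sketch above, are invented for illustration; `pbar=None` stands in for the progress bar the real loop passes.

```python
# Sketch of a manual training loop built around CallbackHandler (illustrative only).
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

model, loss_func = nn.Linear(10, 2), nn.functional.cross_entropy
inner_opt = optim.SGD(model.parameters(), lr=0.1)
dl = DataLoader(TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,))), batch_size=8)

cb_handler = CallbackHandler(callbacks=[StopAfterNIters(100)], metrics=[])
cb_handler.on_train_begin(epochs=1, pbar=None, metrics=[])
cb_handler.on_epoch_begin()
for xb, yb in dl:
    xb, yb = cb_handler.on_batch_begin(xb, yb)
    out  = cb_handler.on_loss_begin(model(xb))
    loss = loss_func(out, yb)
    loss, skip_bwd = cb_handler.on_backward_begin(loss)
    if not skip_bwd: loss.backward()
    if not cb_handler.on_backward_end(): inner_opt.step()       # skip_step gates the step
    if not cb_handler.on_step_end():     inner_opt.zero_grad()  # skip_zero gates zeroing
    if cb_handler.on_batch_end(loss): break                     # stop_epoch ends the epoch
cb_handler.on_epoch_end(val_loss=None)
cb_handler.on_train_end(exception=False)
```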
```python
class AverageMetric(Callback):
    "Wrap a `func` in a callback for metrics computation."
    def __init__(self, func):
        # If func has a __name__ use this one else it should be a partial
        name = func.__name__ if hasattr(func, '__name__') else func.func.__name__
        self.func, self.name = func, name
        self.world = num_distrib()

    def on_epoch_begin(self, **kwargs):
        "Set the inner value to 0."
        self.val, self.count = 0.,0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Update metric computation with `last_output` and `last_target`."
        if not is_listy(last_target): last_target=[last_target]
        self.count += first_el(last_target).size(0)
        val = self.func(last_output, *last_target)
        if self.world:
            val = val.clone()
            dist.all_reduce(val, op=dist.ReduceOp.SUM)
            val /= self.world
        self.val += first_el(last_target).size(0) * val.detach().cpu()

    def on_epoch_end(self, last_metrics, **kwargs):
        "Set the final result in `last_metrics`."
        return add_metrics(last_metrics, self.val/self.count)
```
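As an illustrative (non-library) check, a plain function can be wrapped and two of its hooks fed by hand. `accuracy` here is a made-up metric, and the snippet assumes the non-distributed case where `num_distrib()` returns 0, so the all-reduce branch is skipped.

```python
# Sketch: wrap a plain metric so the handler can average it over an epoch.
import torch

def accuracy(out, targ): return (out.argmax(dim=-1) == targ).float().mean()

met = AverageMetric(accuracy)
met.on_epoch_begin()
met.on_batch_end(last_output=torch.tensor([[0.9, 0.1], [0.2, 0.8]]),
                 last_target=torch.tensor([0, 1]))
print(met.val / met.count)   # tensor(1.) -- batch-size-weighted running average
```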
```python
def annealing_no(start:Number, end:Number, pct:float)->Number:
    "No annealing, always return `start`."
    return start

def annealing_linear(start:Number, end:Number, pct:float)->Number:
    "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
    return start + pct * (end-start)

def annealing_exp(start:Number, end:Number, pct:float)->Number:
    "Exponentially anneal from `start` to `end` as pct goes from 0.0 to 1.0."
    return start * (end/start) ** pct

def annealing_cos(start:Number, end:Number, pct:float)->Number:
    "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
    cos_out = np.cos(np.pi * pct) + 1
    return end + (start-end)/2 * cos_out

def do_annealing_poly(start:Number, end:Number, pct:float, degree:Number)->Number:
    "Helper function for `annealing_poly`."
    return end + (start-end) * (1-pct)**degree

def annealing_poly(degree:Number)->Number:
    "Anneal polynomially from `start` to `end` as pct goes from 0.0 to 1.0."
    return functools.partial(do_annealing_poly, degree=degree)
```
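A quick numerical comparison of the schedules (illustrative only): each function maps `pct` in `[0, 1]` to a value between `start` and `end`.

```python
# Compare the linear and cosine schedules on the same start/end pair.
for pct in (0.0, 0.25, 0.5, 1.0):
    print(pct, annealing_linear(1.0, 0.0, pct), round(annealing_cos(1.0, 0.0, pct), 4))
# At pct=0.25 the linear schedule gives 0.75 while the cosine schedule gives
# (cos(pi/4) + 1)/2 ~ 0.8536: cosine decays more slowly at first and faster
# around the midpoint.
```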
```python
class Scheduler():
    "Used to \"step\" from start,end (`vals`) over `n_iter` iterations on a schedule defined by `func`"
    def __init__(self, vals:StartOptEnd, n_iter:int, func:Optional[AnnealFunc]=None):
        self.start,self.end = (vals[0],vals[1]) if is_tuple(vals) else (vals,0)
        self.n_iter = max(1,n_iter)
        if func is None: self.func = annealing_linear if is_tuple(vals) else annealing_no
        else:            self.func = func
        self.n = 0

    def restart(self): self.n = 0

    def step(self)->Number:
        "Return next value along annealed schedule."
        self.n += 1
        return self.func(self.start, self.end, self.n/self.n_iter)

    def is_done(self)->bool:
        "Return `True` if schedule completed."
        return self.n >= self.n_iter
```
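And a small sketch of `Scheduler` stepping a cosine schedule (the start/end values and iteration count are chosen for illustration; printed values are rounded).

```python
# Step a cosine schedule from 0.1 down to 0.001 over 5 iterations.
sched = Scheduler((0.1, 0.001), n_iter=5, func=annealing_cos)
while not sched.is_done():
    print(round(sched.step(), 4))
# -> 0.0905, 0.0658, 0.0352, 0.0105, 0.001
```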