# Code inspired by the fastai "Practical Deep Learning Part 2" Learner.
import math
import os
from functools import partial
from operator import attrgetter

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import wandb

# Control-flow exceptions: raising one of these inside a callback cancels the
# corresponding stage (see `with_cbs` below).
class CancelFitException(Exception):
    pass


class CancelBatchException(Exception):
    pass


class CancelEpochException(Exception):
    pass


class Callback:
    order = 0  # callbacks run in ascending `order`

class with_cbs:
    """Decorator that wraps a function and calls the matching `before_*`/`after_*`
    callbacks around it, plus `cleanup_*` in all cases."""

    def __init__(self, nm):
        self.nm = nm

    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.callback(f"before_{self.nm}")
                f(o, *args, **kwargs)
                o.callback(f"after_{self.nm}")
            except globals()[f"Cancel{self.nm.title()}Exception"]:
                pass  # e.g. CancelBatchException silently aborts the rest of the batch
            finally:
                o.callback(f"cleanup_{self.nm}")

        return _f

def run_cbs(cbs, method_nm, trainer=None):
    for cb in sorted(cbs, key=attrgetter("order")):  # sort callbacks by `order`
        method = getattr(cb, method_nm, None)  # look up e.g. `before_batch`
        if method is not None:
            method(trainer)  # call it only if the callback defines it

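# A minimal sketch of the dispatch mechanism (the `PrintCB` name is made up for
# illustration): `run_cbs` calls the named method on every callback that defines
# it, in ascending `order`, and silently skips callbacks that don't.
def _demo_run_cbs():
    class PrintCB(Callback):
        def before_batch(self, trainer):
            print("about to run a batch")

    run_cbs([PrintCB()], "before_batch")  # prints
    run_cbs([PrintCB()], "after_epoch")   # no such method -> no-op
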
class Trainer:
    """Trainer with callbacks"""

    def __init__(
        self,
        model,
        dls=(0,),
        loss_func=F.mse_loss,
        opt_func=torch.optim.SGD,
        lr=0.1,
        cbs=None,
        n_inp=1,
    ):
        self.model = model
        self.dls = dls
        self.loss_func = loss_func
        self.opt_func = opt_func
        self.lr = lr
        self.cbs = list(cbs) if cbs is not None else []  # avoid a shared mutable default
        self.n_inp = n_inp

    @with_cbs("batch")
    def _one_batch(self):
        self.predict()
        self.callback("after_predict")
        self.get_loss()
        self.callback("after_loss")
        if self.training:
            self.backward()
            self.callback("after_backward")
            self.step()
            self.callback("after_step")
            self.zero_grad()

    @with_cbs("epoch")
    def _one_epoch(self):
        for self.iter, self.batch in enumerate(self.dl):
            self._one_batch()

    def one_epoch(self, training):
        self.model.train(training)
        self.dl = self.dls.train if training else self.dls.valid
        self._one_epoch()

    @with_cbs("fit")
    def _fit(self, train, valid):
        for epoch in range(self.n_epochs):
            if train:
                self.one_epoch(True)
            if valid:
                # `torch.no_grad()` used as a decorator: run validation without gradients
                torch.no_grad()(self.one_epoch)(False)

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        cbs = list(cbs) if cbs is not None else []
        self.cbs.extend(cbs)  # callbacks passed here are active for this fit only
        try:
            self.n_epochs = n_epochs
            if lr is not None:
                self.lr = lr
            self.opt = self.opt_func(self.model.parameters(), self.lr)
            self._fit(train, valid)
        finally:
            for cb in cbs:
                self.cbs.remove(cb)

    def callback(self, method_nm):
        run_cbs(self.cbs, method_nm, self)

    def predict(self, x=None):
        if x is not None:
            return self.model(x)
        self.preds = self.model(*self.batch[: self.n_inp])

    def get_loss(self):
        self.loss = self.loss_func(self.preds, *self.batch[self.n_inp :])

    def backward(self):
        self.loss.backward()

    def step(self):
        self.opt.step()

    def zero_grad(self):
        self.opt.zero_grad()

    @property
    def training(self):  # read everywhere as an attribute, so expose it as a property
        return self.model.training

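# A minimal sketch of driving the Trainer directly. `dls` can be any object with
# `.train` and `.valid` iterables of batches; the `SimpleNamespace` below is just
# a stand-in for a real DataLoaders pair, and all numbers are made up.
def _demo_trainer():
    from types import SimpleNamespace

    x, y = torch.randn(64, 4), torch.randn(64, 1)
    dls = SimpleNamespace(train=[(x, y)], valid=[(x, y)])
    trainer = Trainer(torch.nn.Linear(4, 1), dls=dls)
    trainer.fit(n_epochs=2, lr=0.01)  # train + validate; callbacks fire around each stage
    print(trainer.loss.item())  # last computed (validation) loss
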
class ProgressCB(Callback):
    """Adds a progress bar to the Trainer and plots loss curves after training."""

    def __init__(self, in_notebook=False):
        super().__init__()
        self.train_loss = []
        self.valid_loss = []
        self.in_notebook = in_notebook

    def before_fit(self, trainer):
        if self.in_notebook:
            from tqdm.notebook import tqdm
        else:
            from tqdm import tqdm
        self.pbar = tqdm(total=trainer.n_epochs)

    def after_epoch(self, trainer):
        if trainer.training:
            self.pbar.update(1)

    def after_loss(self, trainer):
        if trainer.training:
            self.train_loss.append(trainer.loss.item())
            # smooth the running train loss over the last 10 batches
            tmp_train_loss = (
                np.mean(self.train_loss[-10:]) if len(self.train_loss) > 10 else 0
            )
            # average the valid loss over the most recent validation pass
            tmp_valid_loss = (
                np.mean(self.valid_loss[-len(trainer.dls.valid) :])
                if len(self.valid_loss) > 0
                else 0
            )
            self.pbar.set_description(
                f"train loss: {tmp_train_loss:.3f} | valid loss: {tmp_valid_loss:.3f}"
            )
        else:
            self.valid_loss.append(trainer.loss.item())

    def after_fit(self, trainer):
        self.pbar.close()

    def plot_losses(self, save=True):
        fig, ax = plt.subplots(1, 2, figsize=(12, 4))
        ax[0].plot(self.train_loss)
        ax[0].set_title("train loss")
        ax[1].plot(self.valid_loss)
        ax[1].set_title("valid loss")
        if save:
            os.makedirs("./plots", exist_ok=True)
            plt.savefig("./plots/losses.png")
        else:
            plt.show()

class DeviceCB(Callback):
    """Moves the model and each batch to the target device."""

    def __init__(self, device="cpu"):
        self.device = device

    def before_fit(self, trainer):
        if hasattr(trainer.model, "to"):
            trainer.model.to(self.device)

    def before_batch(self, trainer):
        trainer.batch = tuple(t.to(self.device) for t in trainer.batch)

class Hook:
    """Registers a PyTorch forward hook that calls the provided function `f` with
    the hook object and the module name prepended to PyTorch's usual
    (module, input, output) arguments."""

    def __init__(self, name, mod, f):
        self.hook = mod.register_forward_hook(partial(f, self, name))

    def remove(self):
        self.hook.remove()

    def __del__(self):
        self.remove()

class Hooks(list):
    """A list of hooks, usable as a context manager for automatic cleanup."""

    def __init__(self, mods, f):
        super().__init__([Hook(n, m, f) for n, m in mods])

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.remove()

    def __del__(self):
        self.remove()

    def __delitem__(self, i):
        self[i].remove()
        super().__delitem__(i)

    def remove(self):
        for h in self:
            h.remove()

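# A minimal sketch of using Hooks as a context manager (the model and the
# `print_shape` helper are illustrative): each hook function receives
# (hook, name, module, input, output), because Hook binds the hook object and
# the module name ahead of PyTorch's own arguments.
def _demo_hooks():
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())

    def print_shape(hook, name, mod, inp, outp):
        print(f"{name}: {tuple(outp.shape)}")

    with Hooks(model.named_modules(), print_shape) as hooks:
        model(torch.randn(2, 4))  # prints the output shape of every module
    # hooks are removed automatically on exit
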
class HooksCB(Callback):
    """Attaches hooks running `hookfunc` to the layers selected by `mod_filter`."""

    def __init__(self, hookfunc, mod_filter=lambda x: True):
        super().__init__()
        self.hookfunc = hookfunc
        self.mod_filter = mod_filter

    def before_fit(self, trainer):
        mods = [
            (name, mod)
            for name, mod in trainer.model.named_modules()
            if self.mod_filter(mod)
        ]
        # pass the trainer itself, not `trainer.training`: the mode must be
        # checked when the hook fires, not frozen at hook-creation time
        self.hooks = Hooks(mods, partial(self._hookfunc, trainer))

    def _hookfunc(self, trainer, *args, **kwargs):
        if trainer.training:
            self.hookfunc(*args, **kwargs)

    def after_fit(self, trainer):
        self.hooks.remove()

    def __iter__(self):
        return iter(self.hooks)

    def __len__(self):
        return len(self.hooks)

def append_stats(with_wandb, hook, name, mod, inp, outp):
    """Forward-hook function: records mean, std and an |activation| histogram."""
    if not hasattr(hook, "stats"):
        hook.stats = {"mean": [], "std": [], "abs": []}
    acts = outp.detach().cpu()
    hist = acts.abs().histc(40, 0, 10).tolist()  # 40 bins over [0, 10]
    hook.stats["mean"].append(acts.mean().item())
    hook.stats["std"].append(acts.std().item())
    hook.stats["abs"].append(hist)
    if with_wandb:
        wandb.log(
            {
                f"{name}/mean": acts.mean().item(),
                f"{name}/std": acts.std().item(),
                f"{name}/abs": wandb.Histogram(hist),
            },
            commit=False,
        )

def get_grid(n, figsize):
    # `ceil` (not `round`) so an odd `n` still gets enough subplots
    return plt.subplots(math.ceil(n / 2), 2, figsize=figsize)

class WandBCB(Callback):
    """Inits and logs to W&B. Every `wandb.log()` call outside this callback should
    pass `commit=False`, because this callback commits all logs gathered for a
    given batch in one step."""

    order = math.inf  # make sure this callback runs last

    def __init__(
        self, proj_name, model_path, run_name=None, notes=None, **additional_config
    ):
        self.proj_name = proj_name
        self.run_name = run_name
        self.model_path = model_path
        self.notes = notes
        self.additional_config = additional_config

    def before_fit(self, trainer):
        info = dict(
            project=self.proj_name,
            config={"lr": trainer.lr, "n_epochs": trainer.n_epochs},
        )
        if self.run_name is not None:
            info["name"] = self.run_name
        if self.notes is not None:
            info["notes"] = self.notes
        if self.additional_config:  # a kwargs dict: never None, possibly empty
            info["config"] = {**info["config"], **self.additional_config}
        wandb.init(**info)
        wandb.watch(trainer.model, log="all")

    def after_loss(self, trainer):
        if trainer.training:
            wandb.log({"loss/train": trainer.loss.item()}, commit=False)
        else:
            wandb.log({"loss/valid": trainer.loss.item()}, commit=False)

    def after_batch(self, trainer):
        wandb.log({}, commit=True)  # commit everything gathered for this batch

    def after_fit(self, trainer):
        torch.save(trainer.model.state_dict(), self.model_path)
        wandb.save(self.model_path)
        wandb.finish()

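# A minimal sketch of wiring WandBCB in (the project name, path and extra config
# key below are all made up): its `order = math.inf` already guarantees it runs
# after every callback that logs with `commit=False`.
def _demo_wandb():
    wandb_cb = WandBCB(
        proj_name="my-project",    # hypothetical W&B project
        model_path="./model.pt",   # where after_fit saves the state dict
        run_name="baseline",
        batch_size=64,             # extra kwargs land in the run config
    )
    # trainer = Trainer(model, dls=dls, cbs=[DeviceCB(), wandb_cb]); trainer.fit()
    return wandb_cb
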
class ActivationStatsCB(HooksCB):
    """Stores activation statistics of the selected modules. Recommended only for
    debugging or visualization, not for routine training, because it significantly
    slows training down."""

    def __init__(self, mod_filter=lambda x: True, with_wandb=False):
        super().__init__(partial(append_stats, with_wandb), mod_filter)

    def plot_stats(self, save=True):  # plot output means & std devs of each module
        fig, axes = get_grid(2, figsize=(20, 10))
        for h in self.hooks:
            for i, name in enumerate(["mean", "std"]):  # keys of `hook.stats`
                axes[i].plot(h.stats[name])
                axes[i].set_title(name)
        plt.legend(range(len(self.hooks)))
        if save:
            os.makedirs("./plots", exist_ok=True)
            plt.savefig("./plots/mean_std_stats.png")
        else:
            plt.show()

| # plot "color dim" that shows abs values of outputs through training time (should be normally distributed - uniform gradient) | |
| def color_dim(self, save=True): | |
| fig, axes = get_grid(len(self.hooks), figsize=(20, 10)) | |
| for ax, h in zip(axes.flatten(), self.hooks): | |
| ax.set_ylim(0, 40) | |
| ax.imshow(self.get_hist(h), aspect="auto") | |
| if save: | |
| if not os.path.exists("./plots"): | |
| os.makedirs("./plots") | |
| plt.savefig("./plots/color_dim.png") | |
| else: | |
| plt.show() | |
| # plot % of dead neurons | |
| def dead_chart(self, save=True): | |
| fig, axes = get_grid(len(self.hooks), figsize=(20, 10)) | |
| for ax, h in zip(axes.flatten(), self.hooks): | |
| ax.plot(self.get_min(h)) | |
| ax.set_ylim(0, 1) | |
| if save: | |
| if not os.path.exists("./plots"): | |
| os.makedirs("./plots") | |
| plt.savefig("./plots/dead_neurons_perc.png") | |
| else: | |
| plt.show() | |
    # Ratio of dead neurons: activations falling in the lowest |activation| bin (near 0).
    # Note `stats["abs"]` holds lists (from `.tolist()`), hence `torch.tensor`, not `stack`.
    def get_min(self, h):
        h1 = torch.tensor(h.stats["abs"]).t().float()
        return h1[0] / h1.sum(0)

    def get_hist(self, h):
        return torch.tensor(h.stats["abs"]).t().float().log1p()

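# A minimal sketch of collecting activation stats for just the Linear layers
# (the filter is illustrative; pass any predicate over modules):
def _demo_activation_stats(model, dls):
    astats = ActivationStatsCB(mod_filter=lambda m: isinstance(m, torch.nn.Linear))
    Trainer(model, dls=dls, cbs=[astats]).fit(n_epochs=1)
    astats.color_dim()   # histogram heatmap per layer
    astats.dead_chart()  # fraction of near-zero activations per layer
    astats.plot_stats()  # means and std devs per layer
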
class LRFinderCB(Callback):
    """Suggests an approximately good LR for a model. Pick a value where the loss
    is still decreasing steeply, not the one where the loss is lowest."""

    def __init__(self, min_lr=1e-6, max_lr=1, max_mult=3, num_iter=100):
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.max_mult = max_mult
        self.num_iter = num_iter
        # multiplying by this factor each batch sweeps [min_lr, max_lr] in `num_iter` steps
        self.lr_factor = (max_lr / min_lr) ** (1 / num_iter)

    def before_fit(self, trainer):
        self.lrs, self.losses = [], []
        self.min = math.inf
        self.i = 0
        trainer.opt.param_groups[0]["lr"] = self.min_lr

    def before_batch(self, trainer):
        trainer.opt.param_groups[0]["lr"] *= self.lr_factor

    def after_batch(self, trainer):
        if not trainer.training:
            raise CancelEpochException()  # skip validation entirely
        self.lrs.append(trainer.opt.param_groups[0]["lr"])
        loss = trainer.loss.to("cpu").item()
        self.losses.append(loss)
        if loss < self.min:
            self.min = loss
        self.i += 1
        # stop once the loss diverges or the sweep is finished
        if (
            math.isnan(loss)
            or (loss > self.min * self.max_mult)
            or (self.i > self.num_iter)
        ):
            raise CancelFitException()

    def plot_lrs(self, log=True, window=None):
        plt.plot(self.lrs, self.losses)  # raw loss curve
        plt.title("LR finder")
        if log:
            plt.xscale("log")
        if window is None:
            window = self.num_iter // 4
        smoothed_losses = np.convolve(
            self.losses, np.ones(window) / window, mode="valid"
        )
        gradients = np.gradient(smoothed_losses)
        min_gradient_idx = np.argmin(gradients)  # steepest downward slope
        # smoothed index i corresponds to raw index i + window // 2
        self.best_lr = self.lrs[min_gradient_idx + window // 2]
        plt.plot(
            self.best_lr, smoothed_losses[min_gradient_idx], "ro"
        )  # recommended LR value point
        plt.text(
            self.best_lr,
            smoothed_losses[min_gradient_idx],
            f"LR: {self.best_lr:.1e}",
            fontsize=12,
            ha="center",
            va="bottom",
            bbox=dict(facecolor="white"),
        )
        plt.plot(
            self.lrs[window // 2 : window // 2 + len(smoothed_losses)],
            smoothed_losses,
            alpha=0.5,
        )  # smoothed loss curve

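# A minimal sketch of an LR-finder run (`model` and `dls` are whatever you would
# pass to Trainer): the finder raises CancelFitException once the loss diverges,
# and passing it via `fit(cbs=...)` keeps it out of subsequent fits.
def _demo_lr_finder(model, dls):
    lrf = LRFinderCB(min_lr=1e-6, max_lr=1, num_iter=100)
    trainer = Trainer(model, dls=dls, cbs=[DeviceCB()])
    trainer.fit(n_epochs=1, valid=False, cbs=[lrf])
    lrf.plot_lrs()  # also sets lrf.best_lr
    return lrf.best_lr
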
class AugmentCB(Callback):
    """Applies augmentation transforms to the inputs on-device (e.g. on the GPU)
    for faster training; place it after DeviceCB in `cbs` so the batch is already
    on the device."""

    def __init__(self, device="cpu", transform=None):
        super().__init__()
        self.device = device
        self.transform = transform

    def before_batch(self, trainer):
        if self.transform is None:
            return
        # transform only the inputs; pass the targets through untouched
        trainer.batch = tuple(
            [self.transform(t) for t in trainer.batch[: trainer.n_inp]]
            + list(trainer.batch[trainer.n_inp :])
        )

class MultiClassAccuracyCB(Callback):
    """Tracks multi-class accuracy per epoch, for training and validation."""

    def __init__(self, with_wandb=False):
        self.all_acc = {"train": [], "valid": []}
        self.with_wandb = with_wandb

    def before_epoch(self, trainer):
        self.acc = []

    def after_predict(self, trainer):
        # accumulate per-batch correctness; resetting here would keep only the
        # last batch, so the reset lives in before_epoch/after_epoch instead
        with torch.inference_mode():
            self.acc.append(
                (
                    F.softmax(trainer.preds, dim=1).argmax(1)
                    == trainer.batch[trainer.n_inp]
                ).float()
            )

    def after_epoch(self, trainer):
        final_acc = torch.hstack(self.acc).mean().item()
        if trainer.training:
            if self.with_wandb:
                wandb.log({"accuracy/train": final_acc}, commit=False)
            self.all_acc["train"].append(final_acc)
        else:
            if self.with_wandb:
                wandb.log({"accuracy/valid": final_acc}, commit=False)
            self.all_acc["valid"].append(final_acc)
        self.acc = []

    def plot_acc(self):
        fig, axes = get_grid(2, (20, 10))
        axes[0].plot(self.all_acc["train"])
        axes[0].set_title("train acc")
        axes[1].plot(self.all_acc["valid"])
        axes[1].set_title("valid acc")

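
# A minimal end-to-end sketch tying the pieces together. Everything below is
# illustrative: the toy model, the SimpleNamespace standing in for a real
# DataLoaders object, and the hyperparameters are all assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    x = torch.randn(256, 20)
    y = torch.randint(0, 3, (256,))
    batches = [(x[i : i + 32], y[i : i + 32]) for i in range(0, 256, 32)]
    dls = SimpleNamespace(train=batches[:6], valid=batches[6:])

    model = torch.nn.Sequential(
        torch.nn.Linear(20, 32), torch.nn.ReLU(), torch.nn.Linear(32, 3)
    )
    acc_cb = MultiClassAccuracyCB()
    progress_cb = ProgressCB()
    trainer = Trainer(
        model,
        dls=dls,
        loss_func=F.cross_entropy,
        cbs=[DeviceCB(), progress_cb, acc_cb],
    )
    trainer.fit(n_epochs=5, lr=0.1)
    progress_cb.plot_losses()
    acc_cb.plot_acc()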