# NOTE: the original export contained three lines of scraper residue here
# ("Spaces:" / "Runtime error" x2) — not Python; replaced with this comment.
| # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B1. Training.ipynb. | |
| # %% auto 0 | |
| __all__ = ['SimpleVisual', 'validate', 'train'] | |
| # %% ../nbs/B1. Training.ipynb 2 | |
| import io | |
| import time | |
| import random | |
| from pathlib import Path | |
| from fastprogress import progress_bar, master_bar | |
| import fastprogress | |
| import numpy as np | |
| import pylab as plt | |
| import math | |
| import IPython | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data.dataloader import DataLoader | |
| from torch.profiler import record_function | |
| import webdataset as wds | |
| torch.backends.cudnn.benchmark = True | |
| torch.backends.cudnn.enabled = True | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.set_float32_matmul_precision('medium') | |
| # %% ../nbs/B1. Training.ipynb 3 | |
class SimpleVisual:
    """Live notebook training dashboard.

    Maintains a two-panel matplotlib figure (log-scale train/val loss on top,
    learning rate below) plus a fastprogress results table, both updated in
    place as training progresses.  Driven by `train` below via `show`,
    `add_data`, `add_table_row`, `on_iter` and `hide`.
    """
    def __init__ (self, model, masterbar, total_steps):
        self.model = model
        self.masterbar = masterbar
        self.total_steps = total_steps  # total planned samples; used as the x-axis limit
        # epochs inferred from the master bar's planned iteration count
        self.epochs = total_steps // masterbar.main_bar.total
        # top panel (loss) gets 3x the height of the bottom panel (lr)
        gs = plt.GridSpec(2, 1, height_ratios=[3,1])
        graph_fig = plt.figure(figsize=(10,6))
        self.graph_fig = graph_fig
        self.loss_p = graph_fig.add_subplot(gs[0])
        self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
        # NOTE(review): this hides the x tick labels on the *lr* (bottom) panel,
        # leaving the loss panel's visible — usually it's the other way round
        # on a shared axis; confirm intended.
        self.lr_p.tick_params('x', labelbottom=False)
        self.graph_out = None  # display handle; None until show() is called
        # accumulated history, one entry per add_data() call
        self.its = []
        self.train_losses = []
        self.val_losses = []
        self.lr_history = []
    def show(self):
        """Start the wall clock, print the table header and create the live figure display."""
        self.start_t = time.time()
        self.masterbar.write(["samples", "train", "val", "time"], table=True)
        # `display` is the IPython notebook built-in; keep the handle so plot()
        # can update the same output cell in place
        self.graph_out = display(self.graph_fig, display_id=True, clear=True)
    def hide(self):
        """Blank out the live figure (no-op if show() was never called)."""
        if self.graph_out is not None:
            self.graph_out.update(IPython.display.HTML(''))
    def plot(self):
        """Redraw both panels from the accumulated history and push to the notebook."""
        loss_p, lr_p = self.loss_p, self.lr_p
        loss_p.clear()
        loss_p.plot(self.its, self.train_losses)
        loss_p.plot(self.its, self.val_losses)
        loss_p.set_xlim(0, self.total_steps)
        loss_p.set_yscale('log')
        lr_p.clear()
        lrs = np.array(self.lr_history)
        lr_p.plot(self.its, lrs)
        self.graph_out.update(self.graph_fig)
    def add_data(self, it, lr, train_loss, val_los):
        """Append one (iteration, lr, train loss, val loss) sample and redraw."""
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_los)
        self.lr_history.append(lr)
        self.plot()
    def add_table_row(self, it, avg_train_loss, val_loss):
        """Write one formatted row (samples, losses, elapsed time) into the progress table."""
        elapsed_t = time.time() - self.start_t
        self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
    def on_iter(self, bar, it, avg_train_loss, val_loss):
        """Update the inner progress bar's comment with the epoch number and current losses."""
        epoch = math.ceil(it / self.total_steps * self.epochs)
        bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
| # %% ../nbs/B1. Training.ipynb 4 | |
| # FIXME: we need to keep this synchronised with the validation code below... | |
| def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"): | |
| if isinstance(val, torch.utils.data.IterableDataset): | |
| val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \ | |
| .unbatched().shuffle(1024).batched(bs) | |
| else: | |
| val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last) | |
| with torch.no_grad(): | |
| val_loss = 0 | |
| val_samples = 0 | |
| for args in val_loader: | |
| args = [x.to(device, non_blocking=True) for x in args] | |
| with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'): | |
| ps, loss = model(*args) | |
| N = args[0].shape[0] | |
| val_loss += loss.mean().item() * N | |
| val_samples += N | |
| val_loss = val_loss / val_samples | |
| return val_loss | |
| # %% ../nbs/B1. Training.ipynb 5 | |
def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False,
          weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None,
          dl_workers=8, visual_class = SimpleVisual, profiler=None,
          run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None,
          device="cuda", trainable_params=None):
    """Run a mixed-precision AdamW + OneCycleLR training loop over `train`, validating on `val`.

    State-dict checkpoints named `{it:08d}.pt` are written into `checkpoint_path`.
    The counter `it` advances by *samples seen* (`+= bs` per batch), so
    `run_valid_every_iters`, `table_row_every_iters` and `chkpt_every_iters`
    are all expressed in samples, not optimizer steps.

    `model(*batch)` must return `(predictions, loss)`.  Submodules may opt into
    custom optimization by exposing a `no_weight_decay` attribute (disables
    weight decay for that module) and/or `lr_scale` (multiplies the base `lr`).
    `trainable_params` restricts optimization to the given parameters
    (defaults to all of `model.parameters()`).  `visual_class` provides the
    live dashboard (see `SimpleVisual`); `profiler`, if given, gets `.step()`
    called once per batch.
    """
    if chkpt_every_iters is None:
        chkpt_every_iters = table_row_every_iters

    mb = master_bar(range(epochs))
    # Derive the OneCycle warmup fraction and the total sample budget for the dashboard.
    if isinstance(train, torch.utils.data.IterableDataset):
        # webdataset-style pipeline: expected to carry a `total_samples` attribute
        pct_start = min(0.3, warmup_steps / (epochs * (train.total_samples//bs)))
        visual = visual_class(model, mb, epochs * train.total_samples)
        # pct_start = min(0.3, warmup_steps / (epochs * len(train)))
        # visual = visual_class(model, mb, epochs*len(train)*bs)
    else:
        pct_start = min(0.3, warmup_steps / (epochs * len(train) / bs))
        visual = visual_class(model, mb, epochs*len(train))
    model.visual = visual  # let the model (e.g. its hooks) reach the dashboard

    Path(checkpoint_path).mkdir(exist_ok=True)

    if isinstance(train, torch.utils.data.IterableDataset):
        # train_loader = DataLoader(train, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False, shuffle=False)
        # val_loader = DataLoader(val, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False)
        # re-batch the pre-batched webdataset stream to `bs` through a shuffle buffer;
        # partial=False drops the ragged final training batch
        train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs, partial=False)
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs)
    else:
        train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
        val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)

    # placeholders until the first measurement
    val_loss = torch.nan
    avg_train_loss = torch.nan

    if hasattr(model, 'setup'):
        model.setup(device)

    try:
        scheduler = None

        if trainable_params is None: trainable_params = model.parameters()
        all_params = set(trainable_params)
        customized_params = set()
        groups = []
        group_map = {}
        # Build optimizer parameter groups: every module flagged with
        # `no_weight_decay`/`lr_scale` routes its trainable params into a
        # shared group keyed by the resulting (weight_decay, lr) pair.
        for name,m in model.named_modules():
            if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
                m_trainable = [x for x in m.parameters() if x in all_params]
                if not m_trainable: continue
                customized_params |= set(m_trainable)
                m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
                m_lr = lr * getattr(m, 'lr_scale', 1)
                group = group_map.get((m_wd, m_lr), None)
                if not group:
                    group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
                    groups.append(group)
                    group_map[(m_wd, m_lr)] = group
                group['params'] += m_trainable
                group['names'].append(name)
        # everything not claimed above falls into a default group
        other_params = all_params - customized_params
        if other_params:
            groups = groups + [
                {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
            ]

        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups)
        model._optimizer = optimizer  # exposed for debugging / external hooks

        scaler = torch.cuda.amp.GradScaler(enabled=half)

        # NOTE(review): `train.total_samples` is used unconditionally here (and
        # for the progress bar below) even though the map-style `else` branch
        # above uses len(train) — a plain Dataset would raise AttributeError;
        # confirm only IterableDataset inputs are actually supported.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, pct_start=pct_start, steps_per_epoch=math.ceil(train.total_samples/bs), epochs=epochs,
            max_lr=[pg.get('lr', lr) for pg in groups],
            final_div_factor=25)

        it = 0
        next_val_it = it + 50  # validate early so the dashboard gets a first datapoint
        next_chkpt_it = chkpt_every_iters
        next_table_it = table_row_every_iters

        visual.show()

        running_loss = [0]  # rolling window of recent batch losses (last 5)

        for epoch in mb:
            bar = progress_bar(train_loader, total=train.total_samples//bs, parent=mb)
            for args in bar:
                with record_function("forward"):
                    args = [x.to(device, non_blocking=True) for x in args]

                    # zero the parameter gradients
                    optimizer.zero_grad(set_to_none=True)

                    with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                        ps, loss = model(*args)
                        loss = loss.mean()

                with record_function("backward"):
                    scaler.scale(loss).backward()

                    if clip_gradient_norm:
                        scaler.unscale_(optimizer)
                        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)

                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()

                    if profiler is not None: profiler.step()

                with record_function("running_loss"):
                    # smoothed training loss: mean of the last 5 mini-batch losses
                    running_loss.append(loss.item())
                    running_loss = running_loss[-5:]
                    avg_train_loss = sum(running_loss)/len(running_loss)

                if it >= next_chkpt_it:
                    with record_function("checkpoint"):
                        next_chkpt_it += chkpt_every_iters
                        torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt')

                if it >= next_val_it:
                    next_val_it += run_valid_every_iters
                    with record_function("validation"):
                        with record_function("model.eval"):
                            model.eval()
                        with torch.no_grad():
                            # per-sample weighted validation loss; keep in sync with `validate` above
                            val_loss = 0
                            val_samples = 0
                            for args in val_loader:
                                args = [x.to(device, non_blocking=True) for x in args]
                                with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                                    ps, loss = model(*args)
                                    N = args[0].shape[0]
                                    val_loss += loss.mean().item() * N
                                    val_samples += N
                            val_loss = val_loss / val_samples
                        with record_function("model.train"):
                            model.train()
                        with record_function("plotting"):
                            visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss)
                        if it >= next_table_it:
                            visual.add_table_row(it, avg_train_loss, val_loss)
                            next_table_it += table_row_every_iters

                it += bs  # `it` counts samples, not optimizer steps
                visual.on_iter(bar, it, avg_train_loss, val_loss)
    except KeyboardInterrupt:
        # Ctrl-C stops training but still falls through to the final table row below
        mb.write(f"interrupted")
        mb.show()
        pass
    finally:
        visual.add_table_row(it, avg_train_loss, val_loss)
        mb.show()
        visual.hide()