# NOTE(review): the original three header lines here ("Spaces:" and two
# "Runtime error" lines) were paste/extraction artifacts, not Python source.
| import types | |
| import time | |
| import random | |
| import clip | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from argparse import ArgumentParser | |
| import pytorch_lightning as pl | |
| from data import get_dataset, get_available_datasets | |
| from encoding.models import get_segmentation_model | |
| from encoding.nn import SegmentationLosses | |
| from encoding.utils import batch_pix_accuracy, batch_intersection_union | |
| # add mixed precision | |
| import torch.cuda.amp as amp | |
| import numpy as np | |
| from encoding.utils import SegmentationMetric | |
class LSegmentationModule(pl.LightningModule):
    """Base Lightning module shared by the LSeg segmentation models.

    Provides the optimization protocol, dataloaders, and train/val metric
    bookkeeping. Subclasses are expected to create ``self.net`` (the
    segmentation network), ``self.criterion``, ``self.train_transform``,
    ``self.val_transform`` and (for :meth:`evaluate`) ``self.nclass``
    before training starts.
    """

    def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        # Linear LR scaling rule: the configured base_lr is defined for a
        # reference batch size of 16.
        self.base_lr = base_lr / 16 * batch_size
        self.lr = self.base_lr
        self.epochs = max_epochs
        self.other_kwargs = kwargs
        # Mixed precision is deliberately disabled: enabling it was observed
        # to complicate training and lead to NaN losses.
        self.enabled = False
        self.scaler = amp.GradScaler(enabled=self.enabled)

    def forward(self, x):
        return self.net(x)

    def _score_prediction(self, pred, target):
        """Return (correct, labeled, inter, union) pixel statistics.

        Shared scoring path for :meth:`evaluate` and
        :meth:`evaluate_random`; uses ``self.nclass`` for the
        intersection/union histogram.
        """
        correct, labeled = batch_pix_accuracy(pred.data, target.data)
        inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
        return correct, labeled, inter, union

    def evaluate(self, x, target=None):
        """Run inference on *x*.

        Returns the raw prediction when *target* is None, otherwise the
        (correct, labeled, inter, union) statistics against *target*.
        """
        pred = self.net.forward(x)
        if isinstance(pred, (tuple, list)):
            # Multi-output networks (aux/SE heads) put the main map first.
            pred = pred[0]
        if target is None:
            return pred
        return self._score_prediction(pred, target)

    def evaluate_random(self, x, labelset, target=None):
        """Same as :meth:`evaluate`, but conditions the network on *labelset*."""
        pred = self.net.forward(x, labelset)
        if isinstance(pred, (tuple, list)):
            pred = pred[0]
        if target is None:
            return pred
        return self._score_prediction(pred, target)

    def training_step(self, batch, batch_nb):
        img, target = batch
        with amp.autocast(enabled=self.enabled):
            out = self(img)
            multi_loss = isinstance(out, tuple)
            if multi_loss:
                loss = self.criterion(*out, target)
            else:
                loss = self.criterion(out, target)
            # With self.enabled == False, scale() is a pass-through, so the
            # logged value equals the raw loss.
            loss = self.scaler.scale(loss)
        final_output = out[0] if multi_loss else out
        train_pred, train_gt = self._filter_invalid(final_output, target)
        if train_gt.nelement() != 0:
            # Only update accuracy when at least one pixel is not ignored.
            self.train_accuracy(train_pred, train_gt)
        self.log("train_loss", loss)
        return loss

    def training_epoch_end(self, outs):
        self.log("train_acc_epoch", self.train_accuracy.compute())

    def validation_step(self, batch, batch_nb):
        img, target = batch
        out = self(img)
        multi_loss = isinstance(out, tuple)
        if multi_loss:
            val_loss = self.criterion(*out, target)
        else:
            val_loss = self.criterion(out, target)
        final_output = out[0] if multi_loss else out
        valid_pred, valid_gt = self._filter_invalid(final_output, target)
        self.val_iou.update(target, final_output)
        pixAcc, iou = self.val_iou.get()
        self.log("val_loss_step", val_loss)
        self.log("pix_acc_step", pixAcc)
        self.log(
            "val_acc_step",
            self.val_accuracy(valid_pred, valid_gt),
        )
        self.log("val_iou", iou)

    def validation_epoch_end(self, outs):
        pixAcc, iou = self.val_iou.get()
        self.log("val_acc_epoch", self.val_accuracy.compute())
        self.log("val_iou_epoch", iou)
        self.log("pix_acc_epoch", pixAcc)
        # Reset so the next epoch accumulates from a clean confusion state.
        self.val_iou.reset()

    def _filter_invalid(self, pred, target):
        """Drop ignore-index pixels; return (argmax predictions, labels)."""
        valid = target != self.other_kwargs["ignore_index"]
        _, mx = torch.max(pred, dim=1)
        return mx[valid], target[valid]

    def configure_optimizers(self):
        # Pretrained backbone trains at the base LR; freshly initialized
        # heads (scratch/aux/scale layers) train at 10x.
        params_list = [
            {"params": self.net.pretrained.parameters(), "lr": self.base_lr},
        ]
        if hasattr(self.net, "scratch"):
            print("Found output scratch")
            params_list.append(
                {"params": self.net.scratch.parameters(), "lr": self.base_lr * 10}
            )
        if hasattr(self.net, "auxlayer"):
            print("Found auxlayer")
            params_list.append(
                {"params": self.net.auxlayer.parameters(), "lr": self.base_lr * 10}
            )
        if hasattr(self.net, "scale_inv_conv"):
            print(self.net.scale_inv_conv)
            print("Found scaleinv layers")
            params_list.append(
                {
                    "params": self.net.scale_inv_conv.parameters(),
                    "lr": self.base_lr * 10,
                }
            )
            params_list.append(
                {"params": self.net.scale2_conv.parameters(), "lr": self.base_lr * 10}
            )
            params_list.append(
                {"params": self.net.scale3_conv.parameters(), "lr": self.base_lr * 10}
            )
            params_list.append(
                {"params": self.net.scale4_conv.parameters(), "lr": self.base_lr * 10}
            )
        if self.other_kwargs["midasproto"]:
            print("Using midas optimization protocol")
            opt = torch.optim.Adam(
                params_list,
                lr=self.base_lr,
                betas=(0.9, 0.999),
                weight_decay=self.other_kwargs["weight_decay"],
            )
        else:
            opt = torch.optim.SGD(
                params_list,
                lr=self.base_lr,
                momentum=0.9,
                weight_decay=self.other_kwargs["weight_decay"],
            )
        # Polynomial (power 0.9) LR decay over the epoch budget; it is
        # identical for both protocols, so it is built once here.
        sch = torch.optim.lr_scheduler.LambdaLR(
            opt, lambda x: pow(1.0 - x / self.epochs, 0.9)
        )
        return [opt], [sch]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.trainset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=16,
            # Re-seed each worker so random augmentations differ per worker
            # and per epoch. NOTE(review): this seeds only `random`, not
            # numpy/torch RNGs — confirm augmentations use `random`.
            worker_init_fn=lambda x: random.seed(time.time() + x),
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.valset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=16,
        )

    def get_trainset(self, dset, augment=False, **kwargs):
        """Build the training dataset and the training-accuracy metric.

        Also records ``self.num_classes`` from the dataset, which
        :meth:`get_valset` and :meth:`get_criterion` depend on.
        """
        print(kwargs)
        mode = "train_x" if augment else "train"
        print(mode)
        dset = get_dataset(
            dset,
            root=self.data_path,
            split="train",
            mode=mode,
            transform=self.train_transform,
            **kwargs
        )
        self.num_classes = dset.num_class
        self.train_accuracy = pl.metrics.Accuracy()
        return dset

    def get_valset(self, dset, augment=False, **kwargs):
        """Build the validation dataset plus the val accuracy/IoU metrics.

        Must be called after :meth:`get_trainset`, which sets
        ``self.num_classes``.
        """
        self.val_accuracy = pl.metrics.Accuracy()
        self.val_iou = SegmentationMetric(self.num_classes)
        mode = "val_x" if augment else "val"
        print(mode)
        return get_dataset(
            dset,
            root=self.data_path,
            split="val",
            mode=mode,
            transform=self.val_transform,
            **kwargs
        )

    def get_criterion(self, **kwargs):
        """Build the segmentation loss from the parsed CLI keyword options."""
        return SegmentationLosses(
            se_loss=kwargs["se_loss"],
            aux=kwargs["aux"],
            nclass=self.num_classes,
            se_weight=kwargs["se_weight"],
            aux_weight=kwargs["aux_weight"],
            ignore_index=kwargs["ignore_index"],
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Extend *parent_parser* with this module's CLI options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--data_path", type=str, help="path where dataset is stored"
        )
        parser.add_argument(
            "--dataset",
            choices=get_available_datasets(),
            default="ade20k",
            help="dataset to train on",
        )
        parser.add_argument(
            "--batch_size", type=int, default=16, help="size of the batches"
        )
        parser.add_argument(
            "--base_lr", type=float, default=0.004, help="learning rate"
        )
        parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
        parser.add_argument(
            "--weight_decay", type=float, default=1e-4, help="weight_decay"
        )
        parser.add_argument(
            "--aux", action="store_true", default=False, help="Auxilary Loss"
        )
        parser.add_argument(
            "--aux-weight",
            type=float,
            default=0.2,
            help="Auxilary loss weight (default: 0.2)",
        )
        parser.add_argument(
            "--se-loss",
            action="store_true",
            default=False,
            help="Semantic Encoding Loss SE-loss",
        )
        parser.add_argument(
            "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
        )
        parser.add_argument(
            "--midasproto", action="store_true", default=False, help="midasprotocol"
        )
        parser.add_argument(
            "--ignore_index",
            type=int,
            default=-1,
            help="numeric value of ignore label in gt",
        )
        parser.add_argument(
            "--augment",
            action="store_true",
            default=False,
            help="Use extended augmentations",
        )
        return parser