Spaces:
Runtime error
Runtime error
| import types | |
| import time | |
| import random | |
| import clip | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from argparse import ArgumentParser | |
| import pytorch_lightning as pl | |
| from encoding.models import get_segmentation_model | |
| from encoding.nn import SegmentationLosses | |
| from encoding.utils import batch_pix_accuracy, batch_intersection_union | |
| # add mixed precision | |
| import torch.cuda.amp as amp | |
| import numpy as np | |
| from encoding.utils.metrics import SegmentationMetric | |
| # get fewshot dataloader | |
| from fewshot_data.model.hsnet import HypercorrSqueezeNetwork | |
| from fewshot_data.common.logger import Logger, AverageMeter | |
| from fewshot_data.common.evaluation import Evaluator | |
| from fewshot_data.common import utils | |
| from fewshot_data.data.dataset import FSSDataset | |
class Fewshot_args:
    """Default few-shot settings, used as a mutable argument namespace.

    Defaults live on the class; callers instantiate it and overwrite fields on
    the instance (``LSegmentationModuleZS`` does so and also adds ``nshot`` and
    ``finetune_mode``), so the class-level defaults are never mutated.
    """

    # Root directory handed to FSSDataset.
    datapath: str = 'fewshot_data/Datasets_HSN'
    # Benchmark name passed to FSSDataset.build_dataloader.
    benchmark: str = 'pascal'
    # Log path consumed by Logger.initialize; empty by default.
    logpath: str = ''
    # Worker count passed to FSSDataset.build_dataloader.
    nworker: int = 8
    # Episodes per batch.
    bsz: int = 20
    # Data fold index (the CLI restricts it to 0-3).
    fold: int = 0
class LSegmentationModuleZS(pl.LightningModule):
    """Few-shot binary-segmentation LightningModule driving an LSeg-style net.

    The network itself is expected on ``self.net`` (it is not created here —
    presumably a subclass assigns it; TODO confirm). Data loading, metric
    bookkeeping and logging reuse the HSNet few-shot utilities
    (``FSSDataset``, ``Evaluator``, ``AverageMeter``, ``Logger``). Every
    episode is scored as a 2-class problem (``self.num_classes = 2``).

    Args:
        data_path: Dataset root; when falsy the ``Fewshot_args`` default is kept.
        dataset: Benchmark name, stored as ``self.args.benchmark``.
        batch_size: Episodes per batch; also scales the learning rate.
        base_lr: Learning rate for a reference batch size of 16.
        max_epochs: Epoch count used by the poly learning-rate schedule.
        **kwargs: Read keys: ``use_pretrained``, ``logpath``, ``fold``,
            ``nshot``, ``finetune_mode``, ``backbone``, ``midasproto`` and
            ``weight_decay``.
    """

    def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
        super().__init__()
        self.batch_size = batch_size
        # Linear LR scaling relative to a batch size of 16.
        self.base_lr = base_lr / 16 * batch_size
        self.lr = self.base_lr
        self.epochs = max_epochs
        self.other_kwargs = kwargs
        # Mixed precision is deliberately off: per the original author it
        # complicated training and led to NaN errors.
        self.enabled = False
        self.scaler = amp.GradScaler(enabled=self.enabled)
        # 'clip_fixed' freezes the encoder (lr 0 in configure_optimizers).
        self.fixed_encoder = kwargs["use_pretrained"] == 'clip_fixed'

        # Few-shot hyperparameters.
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.args = self.get_fewshot_args()
        if data_path:
            self.args.datapath = data_path
        self.args.logpath = self.other_kwargs["logpath"]
        self.args.benchmark = dataset
        self.args.bsz = self.batch_size
        self.args.fold = self.other_kwargs["fold"]
        self.args.nshot = self.other_kwargs["nshot"]
        self.args.finetune_mode = self.other_kwargs["finetune_mode"]
        Logger.initialize(self.args, training=True)
        Evaluator.initialize()
        if kwargs["backbone"] in ["clip_resnet101"]:
            FSSDataset.initialize(img_size=480, datapath=self.args.datapath,
                                  use_original_imgsize=False, imagenet_norm=True)
        else:
            FSSDataset.initialize(img_size=480, datapath=self.args.datapath,
                                  use_original_imgsize=False)
        self.best_val_miou = float('-inf')
        self.num_classes = 2
        self.labels = ['others', '']
        # Placeholders until the first training_epoch_end overwrites them
        # (validation_epoch_end reads them for the TensorBoard plots).
        self.fewshot_trn_loss = 100
        self.fewshot_trn_miou = 0
        self.fewshot_trn_fb_iou = 0

    def get_fewshot_args(self):
        """Return a fresh ``Fewshot_args`` namespace with default settings."""
        return Fewshot_args()

    def forward(self, x, class_info):
        """Run ``self.net`` on images ``x`` with per-image class ids ``class_info``."""
        return self.net(x, class_info)

    @staticmethod
    def _repeat_class_ids(class_id, nshot):
        """Repeat each episode's class id ``nshot`` times, episode-major.

        ``t.view(-1, ...)`` on a ``(batch, nshot, ...)`` tensor yields rows in
        the order (ep0 shot0, ep0 shot1, ..., ep1 shot0, ...), so every id must
        be repeated contiguously — tiling the whole id vector would pair images
        with the wrong class whenever batch > 1.
        """
        return class_id.repeat_interleave(nshot, dim=0)

    @staticmethod
    def _classify(out, target, ignore_idx=None):
        """Return ``(area_inter, area_union)`` for the argmax prediction of ``out``.

        ``ignore_idx`` is forwarded to ``Evaluator.classify_prediction`` only
        when it is not None.
        """
        pred = out.argmax(dim=1)
        if ignore_idx is None:
            return Evaluator.classify_prediction(pred, target)
        return Evaluator.classify_prediction(pred, target, ignore_idx)

    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch and accumulate IoU statistics.

        Input layouts:
          * finetune, 5-shot: all support shots are flattened into one image
            batch (assumes ``support_imgs`` is (B, nshot, C, H, W), as the 5-D
            indexing implies);
          * finetune, other shot counts: the single support shot (dim 1
            squeezed) is used;
          * regular training: support and query images are concatenated.

        Returns:
            The loss after ``self.scaler.scale`` (a no-op while AMP is disabled).
        """
        is_pascal = self.args.benchmark == 'pascal'
        ignore_idx = None
        if self.args.finetune_mode:
            if self.args.nshot == 5:
                bshape = batch['support_imgs'].shape
                img = batch['support_imgs'].view(-1, bshape[2], bshape[3], bshape[4])
                target = batch['support_masks'].view(-1, bshape[3], bshape[4])
                class_info = self._repeat_class_ids(batch['class_id'], bshape[1])
                if is_pascal and batch['support_ignore_idxs'] is not None:
                    ignore_idx = batch['support_ignore_idxs'].view(-1, bshape[3], bshape[4])
            else:
                img = batch['support_imgs'].squeeze(1)
                target = batch['support_masks'].squeeze(1)
                class_info = batch['class_id']
                if is_pascal and batch['support_ignore_idxs'] is not None:
                    ignore_idx = batch['support_ignore_idxs'].squeeze(1)
        else:
            img = torch.cat([batch['support_imgs'].squeeze(1), batch['query_img']], dim=0)
            target = torch.cat([batch['support_masks'].squeeze(1), batch['query_mask']], dim=0)
            class_info = torch.cat([batch['class_id'], batch['class_id']], dim=0)
            if is_pascal and batch['query_ignore_idx'] is not None:
                ignore_idx = torch.cat(
                    [batch['support_ignore_idxs'].squeeze(1), batch['query_ignore_idx']], dim=0
                )

        with amp.autocast(enabled=self.enabled):
            out = self(img, class_info)
            loss = self.criterion(out, target)
            loss = self.scaler.scale(loss)
        self.log("train_loss", loss)

        # Evaluate the prediction.
        area_inter, area_union = self._classify(out, target, ignore_idx)
        self.train_average_meter.update(area_inter, area_union, class_info, loss.detach().clone())
        if self.global_rank == 0:
            return_value = self.train_average_meter.write_process(
                batch_nb, self.len_train_dataloader, self.current_epoch, write_batch_idx=50
            )
            if return_value is not None:
                iou, fb_iou = return_value
                self.log("fewshot_train_iou", iou)
                self.log("fewshot_train_fb_iou", fb_iou)
        return loss

    def training_epoch_end(self, outs):
        """Summarise and log the epoch's few-shot training metrics.

        Rank 0 writes the text summary; every rank computes the mean loss,
        mIoU and FB-IoU, logs them, and caches them for validation_epoch_end.
        """
        if self.global_rank == 0:
            self.train_average_meter.write_result('Training', self.current_epoch)
        self.fewshot_trn_loss = utils.mean(self.train_average_meter.loss_buf)
        self.fewshot_trn_miou, self.fewshot_trn_fb_iou = self.train_average_meter.compute_iou()
        self.log("fewshot_trn_loss", self.fewshot_trn_loss)
        self.log("fewshot_trn_miou", self.fewshot_trn_miou)
        self.log("fewshot_trn_fb_iou", self.fewshot_trn_fb_iou)

    def validation_step(self, batch, batch_nb):
        """Score the query images of one batch and accumulate IoU statistics.

        In finetune 5-shot mode ``query_img`` is flattened like the training
        supports. NOTE(review): that assumes ``query_img`` is 5-D
        (B, nshot, C, H, W) in this mode, as the indexing implies — confirm
        against the dataset.
        """
        is_pascal = self.args.benchmark == 'pascal'
        ignore_idx = None
        if self.args.finetune_mode and self.args.nshot == 5:
            bshape = batch['query_img'].shape
            img = batch['query_img'].view(-1, bshape[2], bshape[3], bshape[4])
            target = batch['query_mask'].view(-1, bshape[3], bshape[4])
            class_info = self._repeat_class_ids(batch['class_id'], bshape[1])
            if is_pascal and batch['query_ignore_idx'] is not None:
                ignore_idx = batch['query_ignore_idx'].view(-1, bshape[3], bshape[4])
        else:
            img = batch['query_img'].squeeze(1)
            target = batch['query_mask'].squeeze(1)
            class_info = batch['class_id']
            if is_pascal and batch['query_ignore_idx'] is not None:
                ignore_idx = batch['query_ignore_idx'].squeeze(1)

        out = self(img, class_info)
        val_loss = self.criterion(out, target)
        # Evaluate the prediction.
        area_inter, area_union = self._classify(out, target, ignore_idx)
        self.val_average_meter.update(area_inter, area_union, class_info, val_loss.detach().clone())
        if self.global_rank == 0:
            return_value = self.val_average_meter.write_process(
                batch_nb, self.len_val_dataloader, self.current_epoch, write_batch_idx=50
            )
            if return_value is not None:
                iou, fb_iou = return_value
                self.log("fewshot_val_iou", iou)
                self.log("fewshot_val_fb_iou", fb_iou)

    def validation_epoch_end(self, outs):
        """Log validation metrics, write train/val curves, and enforce an epoch cap.

        Rank 0 also writes train-vs-val scalars through ``Logger.tbd_writer``
        (presumably a TensorBoard writer, given add_scalars/flush/close) and
        closes it after the final epoch. For the pascal and coco benchmarks the
        process is terminated once ``current_epoch >= 3``, regardless of
        ``max_epochs``.
        """
        if self.global_rank == 0:
            self.val_average_meter.write_result('Validation', self.current_epoch)
        val_loss = utils.mean(self.val_average_meter.loss_buf)
        val_miou, val_fb_iou = self.val_average_meter.compute_iou()
        self.log("fewshot_val_loss", val_loss)
        self.log("fewshot_val_miou", val_miou)
        self.log("fewshot_val_fb_iou", val_fb_iou)

        if self.global_rank == 0:
            writer = Logger.tbd_writer
            writer.add_scalars('fewshot_data/data/loss',
                               {'trn_loss': self.fewshot_trn_loss, 'val_loss': val_loss},
                               self.current_epoch)
            writer.add_scalars('fewshot_data/data/miou',
                               {'trn_miou': self.fewshot_trn_miou, 'val_miou': val_miou},
                               self.current_epoch)
            writer.add_scalars('fewshot_data/data/fb_iou',
                               {'trn_fb_iou': self.fewshot_trn_fb_iou, 'val_fb_iou': val_fb_iou},
                               self.current_epoch)
            writer.flush()
            if self.current_epoch + 1 == self.epochs:
                writer.close()
                Logger.info('==================== Finished Training ====================')

        threshold_epoch = 3
        if self.args.benchmark in ['pascal', 'coco'] and self.current_epoch >= threshold_epoch:
            print('End this loop!')
            # Same effect as sys.exit(); unlike the site-provided exit() it
            # does not depend on the `site` module being loaded.
            raise SystemExit

    def configure_optimizers(self):
        """Build the optimizer and a per-epoch poly learning-rate schedule.

        Parameter groups:
          * encoder: ``self.net.pretrained`` at ``base_lr``; with a fixed
            encoder, ``pretrained.model`` gets lr 0 and only
            ``act_postprocess1..4`` train at ``base_lr``;
          * heads (when present): ``scratch``, ``auxlayer`` and the
            ``scale*_conv`` layers at ``10 * base_lr``.
        Adam is used with ``--midasproto``, otherwise SGD (momentum 0.9).
        """
        pretrained = self.net.pretrained
        if self.fixed_encoder:
            params_list = [{"params": pretrained.model.parameters(), "lr": 0}]
            for idx in range(1, 5):
                head = getattr(pretrained, f"act_postprocess{idx}")
                params_list.append({"params": head.parameters(), "lr": self.base_lr})
        else:
            params_list = [{"params": pretrained.parameters(), "lr": self.base_lr}]

        head_lr = self.base_lr * 10
        if hasattr(self.net, "scratch"):
            print("Found output scratch")
            params_list.append({"params": self.net.scratch.parameters(), "lr": head_lr})
        if hasattr(self.net, "auxlayer"):
            print("Found auxlayer")
            params_list.append({"params": self.net.auxlayer.parameters(), "lr": head_lr})
        if hasattr(self.net, "scale_inv_conv"):
            print(self.net.scale_inv_conv)
            print("Found scaleinv layers")
            for name in ("scale_inv_conv", "scale2_conv", "scale3_conv", "scale4_conv"):
                params_list.append({"params": getattr(self.net, name).parameters(), "lr": head_lr})

        if self.other_kwargs["midasproto"]:
            print("Using midas optimization protocol")
            opt = torch.optim.Adam(
                params_list,
                lr=self.base_lr,
                betas=(0.9, 0.999),
                weight_decay=self.other_kwargs["weight_decay"],
            )
        else:
            opt = torch.optim.SGD(
                params_list,
                lr=self.base_lr,
                momentum=0.9,
                weight_decay=self.other_kwargs["weight_decay"],
            )
        # Poly decay (1 - epoch/max_epochs) ** 0.9. The base is clamped at 0:
        # a negative float raised to 0.9 is a complex number in Python 3, which
        # would otherwise appear if the scheduler is stepped past max_epochs.
        sch = torch.optim.lr_scheduler.LambdaLR(
            opt, lambda x: pow(max(1.0 - x / self.epochs, 0.0), 0.9)
        )
        return [opt], [sch]

    def _build_fewshot_dataloader(self, split):
        """Build the FSS dataloader for ``split`` ('test' split with nshot in finetune mode)."""
        if self.args.finetune_mode:
            return FSSDataset.build_dataloader(
                self.args.benchmark,
                self.args.bsz,
                self.args.nworker,
                self.args.fold,
                'test',
                self.args.nshot)
        return FSSDataset.build_dataloader(
            self.args.benchmark,
            self.args.bsz,
            self.args.nworker,
            self.args.fold,
            split)

    @staticmethod
    def _device_count():
        """Number of CUDA devices, at least 1 so CPU-only runs do not divide by zero."""
        return max(torch.cuda.device_count(), 1)

    def train_dataloader(self):
        """Return the training loader ('trn' split, or 'test' supports when fine-tuning).

        Also records the per-device batch count and a fresh ``AverageMeter``.
        """
        dataloader = self._build_fewshot_dataloader('trn')
        self.len_train_dataloader = len(dataloader) // self._device_count()
        self.train_average_meter = AverageMeter(dataloader.dataset)
        return dataloader

    def val_dataloader(self):
        """Return the validation loader ('val' split, or 'test' when fine-tuning).

        Also records the per-device batch count and a fresh ``AverageMeter``.
        """
        self.val_iou = SegmentationMetric(self.num_classes)
        dataloader = self._build_fewshot_dataloader('val')
        self.len_val_dataloader = len(dataloader) // self._device_count()
        self.val_average_meter = AverageMeter(dataloader.dataset)
        return dataloader

    def criterion(self, logit_mask, gt_mask):
        """Pixel-wise 2-class cross-entropy.

        ``logit_mask`` is reshaped to (bsz, 2, -1), so it must hold exactly two
        channels per pixel; ``gt_mask`` is flattened to (bsz, -1) and cast to long.
        """
        bsz = logit_mask.size(0)
        logit_mask = logit_mask.view(bsz, 2, -1)
        gt_mask = gt_mask.view(bsz, -1).long()
        return self.cross_entropy_loss(logit_mask, gt_mask)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return an ``ArgumentParser`` extending ``parent_parser`` with this module's options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--data_path", type=str, default='', help="path where dataset is stored"
        )
        parser.add_argument(
            "--dataset", type=str, default='pascal', choices=['pascal', 'coco', 'fss'],
        )
        parser.add_argument(
            "--batch_size", type=int, default=20, help="size of the batches"
        )
        parser.add_argument(
            "--base_lr", type=float, default=0.004, help="learning rate"
        )
        parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
        parser.add_argument(
            "--weight_decay", type=float, default=1e-4, help="weight_decay"
        )
        parser.add_argument(
            "--aux", action="store_true", default=False, help="Auxilary Loss"
        )
        parser.add_argument(
            "--aux-weight",
            type=float,
            default=0.2,
            help="Auxilary loss weight (default: 0.2)",
        )
        parser.add_argument(
            "--se-loss",
            action="store_true",
            default=False,
            help="Semantic Encoding Loss SE-loss",
        )
        parser.add_argument(
            "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
        )
        parser.add_argument(
            "--midasproto", action="store_true", default=False, help="midasprotocol"
        )
        parser.add_argument(
            "--ignore_index",
            type=int,
            default=-1,
            help="numeric value of ignore label in gt",
        )
        parser.add_argument(
            "--augment",
            action="store_true",
            default=False,
            help="Use extended augmentations",
        )
        # NOTE(review): this help text duplicates --augment's and looks like a
        # copy-paste slip; confirm the intended meaning of --use_relabeled.
        parser.add_argument(
            "--use_relabeled",
            action="store_true",
            default=False,
            help="Use extended augmentations",
        )
        parser.add_argument("--nworker", type=int, default=8)
        parser.add_argument("--fold", type=int, default=0, choices=[0, 1, 2, 3])
        parser.add_argument("--logpath", type=str, default='')
        # Default was previously 1 (per the original inline comment).
        parser.add_argument("--nshot", type=int, default=0)
        parser.add_argument(
            "--finetune_mode",
            action="store_true",
            default=False,
            help="whether finetune or not"
        )
        return parser