Spaces:
Runtime error
Runtime error
| import os | |
| import argparse | |
| import numpy as np | |
| from tqdm import tqdm | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| from modules.lseg_module_zs import LSegModuleZS | |
| from additional_utils.models import LSeg_MultiEvalModule | |
| from fewshot_data.common.logger import Logger, AverageMeter | |
| from fewshot_data.common.vis import Visualizer | |
| from fewshot_data.common.evaluation import Evaluator | |
| from fewshot_data.common import utils | |
| from fewshot_data.data.dataset import FSSDataset | |
class Options:
    """Command-line options for few-shot evaluation of an LSeg zero-shot checkpoint.

    Wraps an ``argparse.ArgumentParser``; call :meth:`parse` to obtain the parsed
    namespace, which additionally carries a boolean ``cuda`` attribute.
    """

    def __init__(self):
        parser = argparse.ArgumentParser(description="PyTorch Segmentation")
        # model and dataset
        parser.add_argument(
            "--model", type=str, default="encnet", help="model name (default: encnet)"
        )
        parser.add_argument(
            "--backbone",
            type=str,
            default="resnet50",
            help="backbone name (default: resnet50)",
        )
        parser.add_argument(
            "--dataset",
            type=str,
            default="ade20k",
            help="dataset name (default: ade20k)",
        )
        parser.add_argument(
            "--workers", type=int, default=16, metavar="N", help="dataloader threads"
        )
        parser.add_argument(
            "--base-size", type=int, default=520, help="base image size"
        )
        parser.add_argument(
            "--crop-size", type=int, default=480, help="crop image size"
        )
        parser.add_argument(
            "--train-split",
            type=str,
            default="train",
            help="dataset train split (default: train)",
        )
        # training hyper params
        parser.add_argument(
            "--aux", action="store_true", default=False, help="Auxiliary Loss"
        )
        parser.add_argument(
            "--se-loss",
            action="store_true",
            default=False,
            help="Semantic Encoding Loss SE-loss",
        )
        parser.add_argument(
            "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=16,
            metavar="N",
            help="input batch size for training (default: 16)",
        )
        parser.add_argument(
            "--test-batch-size",
            type=int,
            default=16,
            metavar="N",
            help="input batch size for testing (default: 16)",
        )
        # cuda, seed and logging
        parser.add_argument(
            "--no-cuda",
            action="store_true",
            default=False,
            help="disables CUDA training",
        )
        parser.add_argument(
            "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
        )
        # checkpoint
        parser.add_argument(
            "--weights", type=str, default=None, help="checkpoint to test"
        )
        # evaluation option
        parser.add_argument(
            "--eval", action="store_true", default=False, help="evaluating mIoU"
        )
        parser.add_argument(
            "--acc-bn",
            action="store_true",
            default=False,
            help="Re-accumulate BN statistics",
        )
        parser.add_argument(
            "--test-val",
            action="store_true",
            default=False,
            help="generate masks on val set",
        )
        parser.add_argument(
            "--no-val",
            action="store_true",
            default=False,
            help="skip validation during training",
        )
        parser.add_argument(
            "--module",
            default='',
            help="select model definition",
        )
        # test option
        parser.add_argument(
            "--no-scaleinv",
            dest="scale_inv",
            default=True,
            action="store_false",
            help="turn off scaleinv layers",
        )
        parser.add_argument(
            "--widehead", default=False, action="store_true", help="wider output head"
        )
        parser.add_argument(
            "--widehead_hr",
            default=False,
            action="store_true",
            help="wider output head",
        )
        parser.add_argument(
            "--ignore_index",
            type=int,
            default=-1,
            help="numeric value of ignore label in gt",
        )
        parser.add_argument(
            "--jobname",
            type=str,
            default="default",
            help="select which dataset",
        )
        parser.add_argument(
            "--no-strict",
            dest="strict",
            default=True,
            action="store_false",
            help="no-strict copy the model",
        )
        # NOTE: parsed as a string ("True" by default), not a bool; it is passed
        # through unchanged to whatever consumes it.
        parser.add_argument(
            "--use_pretrained",
            type=str,
            default="True",
            help="whether use the default model to initialize the model",
        )
        parser.add_argument(
            "--arch_option",
            type=int,
            default=0,
            help="which kind of architecture to be used",
        )
        # fewshot options
        parser.add_argument(
            '--nshot',
            type=int,
            default=1
        )
        parser.add_argument(
            '--fold',
            type=int,
            default=0,
            choices=[0, 1, 2, 3]
        )
        parser.add_argument(
            '--nworker',
            type=int,
            default=0
        )
        parser.add_argument(
            '--bsz',
            type=int,
            default=1
        )
        parser.add_argument(
            '--benchmark',
            type=str,
            default='pascal',
            choices=['pascal', 'coco', 'fss', 'c2p']
        )
        parser.add_argument(
            '--datapath',
            type=str,
            default='fewshot_data/Datasets_HSN'
        )
        parser.add_argument(
            "--activation",
            choices=['relu', 'lrelu', 'tanh'],
            default="relu",
            help="use which activation to activate the block",
        )
        self.parser = parser

    def parse(self, argv=None):
        """Parse command-line arguments and return the resulting namespace.

        Args:
            argv: List of argument strings to parse. ``None`` (the default)
                parses ``sys.argv[1:]``, matching ``ArgumentParser.parse_args``.

        Returns:
            The parsed ``argparse.Namespace``, extended with ``cuda``: True only
            when ``--no-cuda`` was not given and ``torch.cuda.is_available()``.
            The namespace is printed before being returned.
        """
        args = self.parser.parse_args(argv)
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        print(args)
        return args
def _load_module(args):
    """Restore the ``LSegModuleZS`` checkpoint at ``args.weights`` for evaluation.

    Architecture settings come from ``args``. Optimisation-related settings
    (learning rate, epochs, loss weights, dropout, augmentation) are set to
    zero/off because this script only runs inference with the module.
    """
    return LSegModuleZS.load_from_checkpoint(
        checkpoint_path=args.weights,
        # ``map_location`` is consumed by Lightning's checkpoint loader: load the
        # stored tensors on the CPU; the network is moved to CUDA by the caller.
        map_location="cpu",
        strict=args.strict,
        data_path=args.datapath,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        num_features=256,
        aux_weight=0,
        se_loss=False,
        se_weight=0,
        base_lr=0,
        batch_size=1,
        max_epochs=0,
        ignore_index=args.ignore_index,
        dropout=0.0,
        scale_inv=args.scale_inv,
        augment=False,
        no_batchnorm=False,
        widehead=args.widehead,
        widehead_hr=args.widehead_hr,
        arch_option=args.arch_option,
        use_pretrained=args.use_pretrained,
        logpath='fewshot/logpath_4T/',
        fold=args.fold,
        block_depth=0,
        nshot=args.nshot,
        finetune_mode=False,
        activation=args.activation,
    )


def test(args):
    """Evaluate an LSeg zero-shot checkpoint on a few-shot segmentation benchmark.

    Loads ``args.weights``, runs it over the ``'test'`` split of ``args.dataset``
    for fold ``args.fold`` with ``args.nshot`` support shots, reports running and
    final mIoU / FB-IoU through ``AverageMeter`` / ``Logger``, and appends the
    checkpoint path and final scores to
    ``logs/fewshot/log_fewshot-test_nshot{nshot}_{dataset}.txt``.

    Note: the network and every batch are moved to CUDA, so a GPU is required.
    Side effect: ``args.benchmark`` is overwritten with ``args.dataset``.
    """
    module = _load_module(args)

    Evaluator.initialize()
    # Only the clip_resnet101 backbone is fed ImageNet-normalised images.
    dataset_kwargs = dict(img_size=480, datapath=args.datapath, use_original_imgsize=False)
    if args.backbone in ["clip_resnet101"]:
        dataset_kwargs["imagenet_norm"] = True
    FSSDataset.initialize(**dataset_kwargs)

    # dataloader: the benchmark is taken from --dataset, overriding --benchmark.
    args.benchmark = args.dataset
    dataloader = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)

    model = module.net.eval().cuda()
    print(model)

    # Results are written only after the (long) evaluation loop, so make sure
    # the log directory exists before starting it.
    log_path = "logs/fewshot/log_fewshot-test_nshot{}_{}.txt".format(args.nshot, args.dataset)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    utils.fix_randseed(0)
    average_meter = AverageMeter(dataloader.dataset)
    # Pure inference: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            batch = utils.to_cuda(batch)
            image = batch['query_img']
            target = batch['query_mask']
            class_info = batch['class_id']
            pred_mask = model(image, class_info)
            pred_label = pred_mask.argmax(dim=1)
            # Evaluate prediction; PASCAL may supply per-query ignore regions.
            if args.benchmark == 'pascal' and batch['query_ignore_idx'] is not None:
                area_inter, area_union = Evaluator.classify_prediction(
                    pred_label, target, batch['query_ignore_idx']
                )
            else:
                area_inter, area_union = Evaluator.classify_prediction(pred_label, target)
            average_meter.update(area_inter, area_union, class_info, loss=None)
            average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)

    # Write evaluation results
    average_meter.write_result('Test', 0)
    test_miou, test_fb_iou = average_meter.compute_iou()
    summary = 'Fold %d, %d-shot ==> mIoU: %5.2f \t FB-IoU: %5.2f' % (
        args.fold, args.nshot, test_miou.item(), test_fb_iou.item()
    )
    Logger.info(summary)
    Logger.info('==================== Finished Testing ====================')
    # Append mode keeps results from earlier runs; ``with`` guarantees the file is closed.
    with open(log_path, "a", encoding="utf-8") as log_file:
        log_file.write('{}\n'.format(args.weights))
        log_file.write(summary + '\n')
if __name__ == "__main__":
    # Script entry point: build the option parser, parse the CLI, seed torch's
    # global RNG from --seed, then run the few-shot evaluation.
    options = Options()
    args = options.parse()
    torch.manual_seed(args.seed)
    test(args)