Spaces:
Runtime error
Runtime error
| import re | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from argparse import ArgumentParser | |
| import pytorch_lightning as pl | |
| from .lsegmentation_module_zs import LSegmentationModuleZS | |
| from .models.lseg_net_zs import LSegNetZS, LSegRNNetZS | |
| from encoding.models.sseg.base import up_kwargs | |
| import os | |
| import clip | |
| import numpy as np | |
| from scipy import signal | |
| import glob | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
class LSegModuleZS(LSegmentationModuleZS):
    """Zero-shot LSeg module.

    Builds ``self.net`` from the label list of a few-shot dataset and exposes
    the model-specific command-line arguments via
    :meth:`add_model_specific_args`.
    """

    def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
        """Initialise the base module and construct the segmentation network.

        Args:
            data_path, dataset, batch_size, base_lr, max_epochs: forwarded
                unchanged to ``LSegmentationModuleZS.__init__``. ``dataset``
                also selects the label file read by :meth:`get_labels`.
            **kwargs: forwarded to the base class as well. This method reads
                ``use_pretrained``, ``backbone``, ``num_features``, ``aux``,
                ``arch_option``, ``block_depth`` and ``activation`` from it.

        Raises:
            KeyError: if one of the keys listed above is missing from kwargs.
            ValueError: if ``use_pretrained`` is not a recognisable boolean.
        """
        super().__init__(
            data_path, dataset, batch_size, base_lr, max_epochs, **kwargs
        )
        label_list = self.get_labels(dataset)
        # NOTE(review): named like a dataloader length but holds the number
        # of labels; presumably consumed elsewhere in the module — confirm.
        self.len_dataloader = len(label_list)

        # --use_pretrained is declared with type=str, so from the CLI it
        # arrives as a string; programmatic callers may pass a real bool.
        use_pretrained = self._parse_bool_flag(kwargs["use_pretrained"])

        # The CLIP ResNet-101 backbone uses its own network class; every other
        # backbone goes through LSegNetZS. Both take the same arguments.
        if kwargs["backbone"] in ["clip_resnet101"]:
            net_cls = LSegRNNetZS
        else:
            net_cls = LSegNetZS
        self.net = net_cls(
            label_list=label_list,
            backbone=kwargs["backbone"],
            features=kwargs["num_features"],
            aux=kwargs["aux"],
            use_pretrained=use_pretrained,
            arch_option=kwargs["arch_option"],
            block_depth=kwargs["block_depth"],
            activation=kwargs["activation"],
        )

    @staticmethod
    def _parse_bool_flag(value):
        """Convert a ``'True'``/``'False'`` string or a boolean to ``bool``.

        String matching ignores case and surrounding whitespace. Non-string
        values equal to ``True`` or ``False`` (which includes ``1``/``0``) are
        accepted and converted with ``bool``.

        Raises:
            ValueError: for any other value, instead of leaving the flag unset.
        """
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized == "true":
                return True
            if normalized == "false":
                return False
        elif value in (True, False):
            return bool(value)
        raise ValueError(
            "use_pretrained must be 'True' or 'False', got {!r}".format(value)
        )

    def get_labels(self, dataset):
        """Read and return the class labels for ``dataset``.

        Labels are read from ``label_files/fewshot_<dataset>.txt`` (a path
        relative to the current working directory), one label per line with
        surrounding whitespace stripped. The resulting list is also printed.

        Raises:
            AssertionError: if the label file does not exist.
        """
        path = 'label_files/fewshot_{}.txt'.format(dataset)
        assert os.path.exists(path), '*** Error : {} not exist !!!'.format(path)
        # ``with`` guarantees the handle is closed even if reading fails.
        with open(path, 'r') as f:
            labels = [line.strip() for line in f]
        print(labels)
        return labels

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending the base module's args with LSeg-ZS options.

        The base class's ``add_model_specific_args`` is applied to
        ``parent_parser`` first; its result becomes the parent of a new
        ``ArgumentParser`` that adds the options defined here.
        """
        parser = LSegmentationModuleZS.add_model_specific_args(parent_parser)
        parser = ArgumentParser(parents=[parser])
        parser.add_argument(
            "--backbone",
            type=str,
            default="vitb16_384",
            help="backbone network",
        )
        parser.add_argument(
            "--num_features",
            type=int,
            default=256,
            help="number of featurs that go from encoder to decoder",
        )
        parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")
        parser.add_argument(
            "--finetune_weights", type=str, help="load weights to finetune from"
        )
        # store_false: passing --no-scaleinv sets no_scaleinv to False.
        parser.add_argument(
            "--no-scaleinv",
            default=True,
            action="store_false",
            help="turn off scaleinv layers",
        )
        parser.add_argument(
            "--no-batchnorm",
            default=False,
            action="store_true",
            help="turn off batchnorm",
        )
        parser.add_argument(
            "--widehead", default=False, action="store_true", help="wider output head"
        )
        parser.add_argument(
            "--widehead_hr",
            default=False,
            action="store_true",
            help="wider output head",
        )
        # Kept as a string for CLI compatibility; parsed in __init__.
        parser.add_argument(
            "--use_pretrained",
            type=str,
            default="True",
            help="whether use the default model to intialize the model",
        )
        parser.add_argument(
            "--arch_option",
            type=int,
            default=0,
            help="which kind of architecture to be used",
        )
        parser.add_argument(
            "--block_depth",
            type=int,
            default=0,
            help="how many blocks should be used",
        )
        parser.add_argument(
            "--activation",
            choices=['relu', 'lrelu', 'tanh'],
            default="relu",
            help="use which activation to activate the block",
        )
        return parser