Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import torch | |
| import model | |
| from util import util | |
class BaseOptions:
    """Command-line options shared by training and testing.

    ``initialize`` registers the common arguments, ``gather_options`` adds the
    model-specific ones (via ``model.get_option_setter``), and ``parse``
    returns the final namespace with ``gpu_ids`` converted to a list of ints.

    ``self.isTrain`` is read by ``gather_options`` and ``parse`` but is never
    set in this class, so it must be assigned (e.g. by a subclass) before
    ``parse`` is called.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        # Flipped to True by initialize() so the base options are registered once.
        self.initialized = False

    def initialize(self, parser):
        """Register the base options on ``parser`` and return it.

        Args:
            parser (argparse.ArgumentParser): parser to extend in place.

        Returns:
            argparse.ArgumentParser: the same parser with the options added.
        """
        # base define
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment.')
        parser.add_argument('--model', type=str, default='tc', help='name of the model type. [pluralistic]')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        parser.add_argument('--which_iter', type=int, default=0, help='which iterations to load')
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0, 1, 2 use -1 for CPU')
        # data define
        # nargs='+' keeps a command-line value a list, like the default;
        # without it "--mask_type 2" produced a bare int instead of [2].
        parser.add_argument('--mask_type', type=int, nargs='+', default=[0, 1, 3],
                            help='one or more mask types: 0:center,1:regular,2:irregular,3:external')
        parser.add_argument('--img_file', type=str, default='/data/dataset/train', help='training and testing dataset')
        parser.add_argument('--mask_file', type=str, default='none', help='load test mask')
        parser.add_argument('--img_nc', type=int, default=3, help='# of image channels')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='preprocessing image at load time')
        parser.add_argument('--load_size', type=int, default=542, help='scale examples to this size')
        parser.add_argument('--fine_size', type=int, default=512, help='then crop to this size')
        parser.add_argument('--fixed_size', type=int, default=256, help='fixed the image size in S1 with transformer')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the image')
        parser.add_argument('--data_powers', type=int, default=5, help='# times of the scale to 2 times')
        parser.add_argument('--reverse_mask', action='store_true', help='if specified, random reverse the mask region')
        parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
        parser.add_argument('--nThreads', type=int, default=8, help='# threads for loading data')
        parser.add_argument('--no_shuffle', action='store_true', help='if true, takes examples serial')
        # display parameter define
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        parser.add_argument('--display_id', type=int, default=None, help='display id of the web')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='display name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8092, help='port of the web display')
        parser.add_argument('--display_single_pane_ncols', type=int, default=0,
                            help='if positive, display all examples in a single visdom web panel')
        # Encoder-Decoder define
        parser.add_argument('--ngf', type=int, default=32, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=32, help='# of dis filters in the first conv layer')
        parser.add_argument('--num_res_blocks', type=int, default=2, help='# of residual block in the encoder and decoder layer')
        parser.add_argument('--netD', type=str, default='style', help='specify discriminator architecture ')
        parser.add_argument('--netG', type=str, default='diff', help='specify decoder architecture')
        parser.add_argument('--netE', type=str, default='diff', help='specify encoder architecture')
        parser.add_argument('--kernel_G', type=int, default=3, help='kernel size for the decoder')
        parser.add_argument('--kernel_E', type=int, default=1, help='kernel size for the encoder')
        parser.add_argument('--add_noise', action='store_true', help='if true, add noise to the decoder')
        parser.add_argument('--attn_E', action='store_true', help='if true, use attention in the encoder')
        parser.add_argument('--attn_G', action='store_true', help='if true, use attention in the decoder')
        parser.add_argument('--attn_D', action='store_true', help='if true, use attention in the discriminator')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--n_layers_G', type=int, default=4, help='# of down sample layers in the Encoder and Decoder')
        parser.add_argument('--norm', type=str, default='pixel',
                            help='instance normalization or batch normalization [instance | batch | pixel | none]')
        parser.add_argument('--activation', type=str, default='leakyrelu', help='activation layer [relu | gelu | leakyrelu | none]')
        parser.add_argument('--init_type', type=str, default='kaiming',
                            help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--lipip_path', type=str, default='./model/lpips/vgg.pth', help='the pretrained LPIPS model')
        # Transformer define
        parser.add_argument('--netT', type=str, default='original', help='specify transformer architecture')
        parser.add_argument('--embed_dim', type=int, default=512, help='the numbers of embedding dimension')
        parser.add_argument('--dropout', type=float, default=0., help='the dropout probability in transformer')
        parser.add_argument('--kernel_T', type=int, default=1, help='kernel size for the transformer projection')
        parser.add_argument('--n_encoders', type=int, default=12, help='the numbers of encoder in transformer')
        parser.add_argument('--n_decoders', type=int, default=0, help='the numbers of decoder in transformer')
        parser.add_argument('--embed_type', type=str, default='learned', choices=['learned', 'sine'])
        parser.add_argument('--top_k', type=int, default=10, help='sample the results on top k value')
        # VQ define
        parser.add_argument('--num_embeds', type=int, default=1024, help='the numbers of words for image')
        parser.add_argument('--use_pos_G', action='store_true', help='if true, position embedding in G')
        parser.add_argument('--word_size', type=int, default=16, help='the numbers of word for each image')
        self.initialized = True
        return parser

    def gather_options(self):
        """Add model-specific options and parse the command line.

        The base options are parsed first (ignoring unknown arguments) to learn
        ``opt.model``; the setter returned by ``model.get_option_setter`` then
        extends the parser before the final strict parse.

        Returns:
            argparse.Namespace: the parsed options.
        """
        if not self.initialized:
            parser = self.initialize(self.parser)
        else:
            # The base options already live on self.parser; previously this
            # path left `parser` unbound and raised UnboundLocalError.
            parser = self.parser
        # get basic options
        opt, _ = parser.parse_known_args()
        # modify the options for different models
        model_option_set = model.get_option_setter(opt.model)
        parser = model_option_set(parser, self.isTrain)
        opt = parser.parse_args()
        return opt

    def parse(self):
        """Parse, print and save the options, then select the GPU.

        ``opt.gpu_ids`` is converted from its comma-separated string form to a
        list of non-negative ints; an empty list means CPU. If GPU ids were
        requested but CUDA is unavailable, a warning is printed and the list
        is emptied instead of crashing in ``torch.cuda.set_device``.

        Returns:
            argparse.Namespace: the options, also stored on ``self.opt``.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain
        # Printed/saved before conversion, so the saved file keeps the raw string.
        self.print_options(opt)
        # set gpu ids
        opt.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)
        if opt.gpu_ids and not torch.cuda.is_available():
            print('Warning: CUDA is not available; ignoring gpu_ids %s and running on CPU' % opt.gpu_ids)
            opt.gpu_ids = []
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])
        self.opt = opt
        return self.opt

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Turn a comma-separated id string into a list of non-negative ints.

        Negative ids (e.g. ``-1``) are dropped, so ``'-1'`` yields ``[]``.
        Empty entries such as a trailing comma are ignored.
        """
        ids = [int(token) for token in gpu_ids.split(',') if token.strip()]
        return [gpu_id for gpu_id in ids if gpu_id >= 0]

    @staticmethod
    def print_options(opt):
        """Print the options and save them under ``<checkpoints_dir>/<name>/``.

        The file is ``train_opt.txt`` when ``opt.isTrain`` is true, otherwise
        ``test_opt.txt``; it is overwritten on every call. Declared static
        because ``parse`` calls it as ``self.print_options(opt)`` and the
        method never uses ``self``.

        Args:
            opt (argparse.Namespace): parsed options; must provide
                ``checkpoints_dir``, ``name`` and ``isTrain``.
        """
        lines = ['--------------Options--------------']
        lines += ['%s: %s' % (str(k), str(v)) for k, v in sorted(vars(opt).items())]
        lines.append('----------------End----------------')
        print('\n'.join(lines))
        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'train_opt.txt' if opt.isTrain else 'test_opt.txt')
        with open(file_name, 'wt', encoding='utf-8') as opt_file:
            opt_file.write('\n'.join(lines) + '\n')