| | |
| | |
| | from argparse import ArgumentParser |
| | import sys |
| | import os |
| |
|
| | sys.path.append('..') |
| | sys.path.append('.') |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.distributed as dist |
| | from torch.nn.parallel import DistributedDataParallel as DDP |
| | from torch.utils.data import DataLoader, Dataset |
| | from torch.utils.data.distributed import DistributedSampler |
| |
|
| | from vit.vision_transformer import VisionTransformer as ViT |
| | from vit.vit_triplane import ViTTriplane |
| | from guided_diffusion import dist_util, logger |
| |
|
| | import click |
| | import dnnlib |
| |
|
# Script-wide settings.
SEED: int = 42        # value handed to the torch RNG seeding call in main()
BATCH_SIZE: int = 8   # not referenced by the code visible in this file
NUM_EPOCHS: int = 1   # number of iterations of the loop in main()
| |
|
| |
|
class YourDataset(Dataset):
    """Placeholder ``Dataset`` subclass.

    The constructor sets up no state, and neither ``__len__`` nor
    ``__getitem__`` is overridden here, so the class must be filled in
    before it can be indexed or given to a ``DataLoader``.
    """

    def __init__(self) -> None:
        # Intentionally empty: nothing is stored on the instance.
        return None
| |
|
| |
|
def _build_rendering_options(opts):
    """Return the renderer keyword dict for the preset named by ``opts.cfg``.

    Args:
        opts: Attribute-style mapping of the parsed CLI options. Reads
            ``cfg``, ``c_scale``, ``density_reg``, ``density_reg_p_dist``,
            ``reg_type`` and ``decoder_lr_mul``.

    Returns:
        A new ``dict``: the shared defaults updated with the preset's
        camera / ray-sampling values.

    Raises:
        click.BadParameter: if ``opts.cfg`` is not a known preset.
    """
    # Built per call so the list values are never shared between callers.
    presets = {
        'ffhq': {
            'focal': 2985.29 / 700,
            'depth_resolution': 36,
            'depth_resolution_importance': 36,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, 0.2],
        },
        'afhq': {
            'focal': 4.2647,
            'depth_resolution': 48,
            'depth_resolution_importance': 48,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, -0.06],
        },
        'shapenet': {
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 0.1,
            'ray_end': 3.3,
            'box_warp': 1.6,
            'white_back': True,
            'avg_camera_radius': 1.7,
            'avg_camera_pivot': [0, 0, 0],
        },
    }

    # Validate with a real exception: an ``assert`` disappears under
    # ``python -O`` and would silently leave the preset keys unset.
    if opts.cfg not in presets:
        raise click.BadParameter(
            f"Need to specify config: expected one of "
            f"{', '.join(presets)}, got {opts.cfg!r}",
            param_hint='--cfg')

    rendering_options = {
        'image_resolution': 256,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'c_scale': opts.c_scale,
        'density_reg': opts.density_reg,
        'density_reg_p_dist': opts.density_reg_p_dist,
        'reg_type': opts.reg_type,
        'decoder_lr_mul': opts.decoder_lr_mul,
        'sr_antialias': True,
        'return_triplane_features': True,
        'return_sampling_details_flag': True,
    }
    rendering_options.update(presets[opts.cfg])
    return rendering_options


def _init_distributed():
    """Bind this process to its GPU and join the NCCL process group.

    Returns:
        The integer ``LOCAL_RANK`` of this process.

    Raises:
        click.UsageError: if ``LOCAL_RANK`` is missing from the environment.
    """
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
    except KeyError as err:
        raise click.UsageError(
            "LOCAL_RANK is not set; start this script through a distributed "
            "launcher such as torchrun.") from err

    # Select the device before creating the NCCL group so that collectives
    # issued by this process are tied to its own GPU.
    torch.cuda.set_device(local_rank)

    # Prefer the launcher-provided global RANK / WORLD_SIZE. LOCAL_RANK and
    # the local GPU count are only correct for a single node using every
    # GPU; they remain as fallbacks for that case.
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=int(os.environ.get("RANK", local_rank)),
        world_size=int(
            os.environ.get("WORLD_SIZE", torch.cuda.device_count())),
    )
    return local_rank


def _build_model(rendering_kwargs, device, local_rank):
    """Construct the ``ViTTriplane`` model on ``device`` and wrap it in DDP."""
    model = ViTTriplane(
        img_size=[224],
        patch_size=16,
        in_chans=384,
        num_classes=0,
        embed_dim=384,
        depth=2,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.1,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_layer=nn.LayerNorm,
        out_chans=96,
        c_dim=25,
        img_resolution=128,
        img_channels=3,
        cls_token=False,
        rendering_kwargs=rendering_kwargs,
    )
    model = model.to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # NOTE(review): named_parameters() yields (name, tensor) pairs; confirm
    # that dist_util.sync_params accepts pairs rather than bare tensors.
    # DDP also synchronises parameters from rank 0 when it wraps the module,
    # so this call may be redundant - verify before removing.
    dist_util.sync_params(model.named_parameters())
    return model


def _run_smoke_epochs(model, device):
    """Run ``NUM_EPOCHS`` forward passes on random input, printing the shape."""
    for _ in range(NUM_EPOCHS):
        model.train()

        # Line all ranks up before each forward pass.
        dist.barrier()

        # 14 * 14 = (224 / 16) ** 2 with 384 channels, matching the
        # img_size / patch_size / in_chans given to the model; presumably
        # one token per image patch - TODO confirm against ViTTriplane.
        noise = torch.randn(1, 14 * 14, 384, device=device)
        # Width 25 matches c_dim=25 passed to the model constructor.
        conditioning = torch.zeros(1, 25, device=device)

        img = model(noise, conditioning)
        print(img['image'].shape)


@click.command()
@click.option('--cfg', help='Base configuration', type=str, default='ffhq')
@click.option('--sr-module',
              help='Superresolution module override',
              metavar='STR',
              required=False,
              default=None)
@click.option('--density_reg',
              help='Density regularization strength.',
              metavar='FLOAT',
              type=click.FloatRange(min=0),
              default=0.25,
              required=False,
              show_default=True)
@click.option('--density_reg_every',
              help='lazy density reg',
              metavar='int',
              # An interval count: the option is documented as an int, so
              # parse it as one instead of as a float.
              type=click.IntRange(min=1),
              default=4,
              required=False,
              show_default=True)
@click.option('--density_reg_p_dist',
              help='density regularization strength.',
              metavar='FLOAT',
              type=click.FloatRange(min=0),
              default=0.004,
              required=False,
              show_default=True)
@click.option('--reg_type',
              help='Type of regularization',
              metavar='STR',
              type=click.Choice([
                  'l1', 'l1-alt', 'monotonic-detach', 'monotonic-fixed',
                  'total-variation'
              ]),
              required=False,
              default='l1')
@click.option('--decoder_lr_mul',
              help='decoder learning rate multiplier.',
              metavar='FLOAT',
              type=click.FloatRange(min=0),
              default=1,
              required=False,
              show_default=True)
@click.option('--c_scale',
              help='Scale factor for generator pose conditioning.',
              metavar='FLOAT',
              type=click.FloatRange(min=0),
              required=False,
              default=1)
def main(**kwargs):
    """Build a DDP-wrapped ViTTriplane and run a forward-pass smoke test.

    Must be started by a distributed launcher that sets LOCAL_RANK
    (for example torchrun); one process drives one CUDA device.
    """
    opts = dnnlib.EasyDict(kwargs)
    c = dnnlib.EasyDict()
    c.rendering_kwargs = _build_rendering_options(opts)

    # The parsed options double as the run-time argument namespace.
    args = opts
    args.local_rank = _init_distributed()
    args.is_master = args.local_rank == 0
    device = torch.device(f"cuda:{args.local_rank}")
    print(f"{args.local_rank=} init complete")

    try:
        # torch.manual_seed seeds the CPU generator as well as the CUDA
        # ones, so both weight initialisation and the random input below
        # are reproducible.
        torch.manual_seed(SEED)

        model = _build_model(c.rendering_kwargs, device, args.local_rank)
        _run_smoke_epochs(model, device)
    finally:
        # Always release the process group, even if the run failed.
        dist.destroy_process_group()
|
| |
|
# Run the click command only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
| |
|