Spaces:
Runtime error
Runtime error
| from diffusers.schedulers.scheduling_ddpm import DDPMScheduler | |
| from src import (ContentEncoder, | |
| StyleEncoder, | |
| UNet) | |
def build_unet(args):
    """Construct the UNet denoiser from parsed configuration arguments.

    Values read from ``args`` (in this order): ``resolution``,
    ``unet_channels``, ``style_start_channel``, ``channel_attn``,
    ``content_encoder_downsample_size`` and ``content_start_channel``.
    Every other UNet hyperparameter is fixed here.

    Args:
        args: Namespace-like object exposing the attributes listed above.

    Returns:
        The newly constructed ``UNet`` instance.
    """
    # Per-level block type names, passed to UNet unchanged.
    down_blocks = (
        'DownBlock2D',
        'MCADownBlock2D',
        'MCADownBlock2D',
        'DownBlock2D',
    )
    up_blocks = (
        'UpBlock2D',
        'StyleRSIUpBlock2D',
        'StyleRSIUpBlock2D',
        'UpBlock2D',
    )

    unet_kwargs = dict(
        sample_size=args.resolution,
        # 3 input/output image channels (presumably RGB — confirm).
        in_channels=3,
        out_channels=3,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=down_blocks,
        up_block_types=up_blocks,
        block_out_channels=args.unet_channels,
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        act_fn='silu',
        norm_num_groups=32,
        norm_eps=1e-05,
        # 16x the style encoder's base width; presumably this matches the
        # channel count of the style features used for cross-attention —
        # TODO confirm against StyleEncoder's output.
        cross_attention_dim=args.style_start_channel * 16,
        attention_head_dim=1,
        channel_attn=args.channel_attn,
        content_encoder_downsample_size=args.content_encoder_downsample_size,
        content_start_channel=args.content_start_channel,
        reduction=32,
    )
    return UNet(**unet_kwargs)
def build_style_encoder(args):
    """Create the style-image encoder configured from ``args``.

    ``args.style_start_channel`` is passed as the base channel width
    (``G_ch``). The first element of ``args.style_image_size`` is passed as
    ``resolution``. A confirmation message is printed to stdout before
    returning.

    Args:
        args: Namespace-like object with ``style_start_channel`` and an
            indexable ``style_image_size``.

    Returns:
        The newly constructed ``StyleEncoder``.
    """
    base_channels = args.style_start_channel
    image_resolution = args.style_image_size[0]
    encoder = StyleEncoder(G_ch=base_channels, resolution=image_resolution)
    print("Get CG-GAN Style Encoder!")
    return encoder
def build_content_encoder(args):
    """Create the content-image encoder configured from ``args``.

    ``args.content_start_channel`` is passed as the base channel width
    (``G_ch``). The first element of ``args.content_image_size`` is passed
    as ``resolution``. A confirmation message is printed to stdout before
    returning.

    Args:
        args: Namespace-like object with ``content_start_channel`` and an
            indexable ``content_image_size``.

    Returns:
        The newly constructed ``ContentEncoder``.
    """
    base_channels = args.content_start_channel
    image_resolution = args.content_image_size[0]
    encoder = ContentEncoder(G_ch=base_channels, resolution=image_resolution)
    print("Get CG-GAN Content Encoder!")
    return encoder
def build_ddpm_scheduler(args):
    """Build the DDPM noise scheduler.

    Only the beta schedule name comes from ``args.beta_scheduler``. The
    remaining settings are fixed:

    * 1000 training timesteps, with betas running from 1e-4 to 0.02;
    * ``"fixed_small"`` variance type;
    * ``clip_sample=True`` (sample clipping enabled);
    * no explicit ``trained_betas``.

    Args:
        args: Namespace-like object with a ``beta_scheduler`` attribute.

    Returns:
        A configured ``DDPMScheduler`` instance.
    """
    schedule_name = args.beta_scheduler
    return DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule=schedule_name,
        trained_betas=None,
        variance_type="fixed_small",
        clip_sample=True,
    )