Spaces:
Runtime error
Runtime error
| from generative.inferers import ControlNetDiffusionInferer, DiffusionInferer | |
| from generative.networks.nets import DiffusionModelUNet, ControlNet | |
| from generative.networks.schedulers import DDPMScheduler | |
| import torch | |
| from synthrad_conversion.networks.ddpm.diffusion_unet_modality import DiffusionModelUNet_modality | |
# Pick the device for this smoke test. The script was written for the second
# GPU ("cuda:1"); keep that choice when it exists, otherwise fall back to the
# first GPU or to the CPU so that `model.to(device)` does not fail on machines
# with fewer than two CUDA devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the 2D, single-channel, modality-conditioned diffusion UNet under test.
# NOTE(review): the exact meaning of each keyword is defined by
# DiffusionModelUNet_modality, which is not visible in this file -- confirm there.
model = DiffusionModelUNet_modality(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 256),
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=32,
    norm_num_groups=8,
    num_modalities=6,  # the random modality labels below are drawn from [0, 6)
    modality_emb_dim=32,
    with_conditioning=True,
    cross_attention_dim=32,  # presumably has to agree with modality_emb_dim -- TODO confirm
)
model.to(device)
# Smoke-test the model's input/output contract with random dummy data.
model.eval()

# Batch size and 2D image size.
batch_size = 2
H = 128
W = 128

# Input image batch, laid out as (N, C=1, H, W).
image_shape = (batch_size, 1, H, W)
x = torch.randn(image_shape)
x = x.to(device)

# One diffusion timestep per sample, drawn from [0, 1000).
timesteps = torch.randint(low=0, high=1000, size=(batch_size,), dtype=torch.long)
timesteps = timesteps.to(device)

# Source and target modality labels: one integer in [0, 6) per sample,
# each tensor shaped (batch_size, 1).
label_shape = (batch_size, 1)
source_modality_labels = torch.randint(
    low=0, high=6, size=label_shape, dtype=torch.long
).to(device)
target_modality_labels = torch.randint(
    low=0, high=6, size=label_shape, dtype=torch.long
).to(device)

print('timesteps shape', timesteps.shape)
print('input label shape', source_modality_labels.shape)
print('input label', source_modality_labels)
print('target label', target_modality_labels)

# context may be None: per the original author's note, forward() builds it
# from the modality labels itself (not verifiable from this file).
context = None

# class_labels only has to be supplied when the model was built with
# num_class_embeds enabled; it is not set in this script, so pass None.
class_labels = None
# Gather the forward-pass arguments in one place so the call itself stays short.
forward_kwargs = {
    "x": x,
    "timesteps": timesteps,
    "context": context,
    "class_labels": class_labels,
    "source_modality_labels": source_modality_labels,
    "target_modality_labels": target_modality_labels,
}

# Run the forward pass with gradient tracking disabled.
with torch.no_grad():
    output = model(**forward_kwargs)

# Report both shapes so they can be compared by eye.
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")