# frankenstein / tutorial5_new_net.py
# zy7_oldserver
# 1
# fd601de
from generative.inferers import ControlNetDiffusionInferer, DiffusionInferer
from generative.networks.nets import DiffusionModelUNet, ControlNet
from generative.networks.schedulers import DDPMScheduler
import torch
from synthrad_conversion.networks.ddpm.diffusion_unet_modality import DiffusionModelUNet_modality
# --- Smoke test for DiffusionModelUNet_modality ---
# Builds a small 2D diffusion U-Net with source/target modality conditioning
# and runs one forward pass on random data to verify input/output shapes.

# Prefer the second GPU (original behavior), but fall back gracefully so the
# script also runs on single-GPU or CPU-only machines instead of crashing.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = DiffusionModelUNet_modality(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 256),
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=32,
    norm_num_groups=8,
    num_modalities=6,
    modality_emb_dim=32,
    with_conditioning=True,
    cross_attention_dim=32,
)
model.to(device)
# Inference-only test: put the model in eval mode.
model.eval()

# Batch size
batch_size = 2
H, W = 128, 128  # 2D image size

# Random input image batch, shape (N, C=1, H, W)
x = torch.randn(batch_size, 1, H, W).to(device)
# Random diffusion timesteps, one per batch element
timesteps = torch.randint(0, 1000, (batch_size,), dtype=torch.long).to(device)
# Source and target modality labels, shape (N, 1), values in [0, 6)
source_modality_labels = torch.randint(0, 6, (batch_size, 1), dtype=torch.long).to(device)
target_modality_labels = torch.randint(0, 6, (batch_size, 1), dtype=torch.long).to(device)
print('timesteps shape', timesteps.shape)
print('input label shape', source_modality_labels.shape)
print('input label', source_modality_labels)
print('target label', target_modality_labels)
# context may be None: the model builds its own context from the modality
# labels inside forward()
context = None
# class_labels is only needed when the model was built with num_class_embeds;
# it was not, so pass None
class_labels = None

with torch.no_grad():
    output = model(
        x=x,
        timesteps=timesteps,
        context=context,
        class_labels=class_labels,
        source_modality_labels=source_modality_labels,
        target_modality_labels=target_modality_labels,
    )
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")