File size: 1,981 Bytes
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from generative.inferers import ControlNetDiffusionInferer, DiffusionInferer
from generative.networks.nets import DiffusionModelUNet, ControlNet
from generative.networks.schedulers import DDPMScheduler
import torch
from synthrad_conversion.networks.ddpm.diffusion_unet_modality import DiffusionModelUNet_modality
# Smoke test for DiffusionModelUNet_modality: build a small 2D model, run a
# single no-grad forward pass on random inputs, and print input/output shapes.

# Number of modalities the model is built for. The random modality labels
# below are drawn from [0, NUM_MODALITIES) so the two can never drift apart.
# (Presumably this sizes a modality embedding table inside the model -- TODO confirm.)
NUM_MODALITIES = 6
# Exclusive upper bound for the randomly sampled diffusion timesteps.
NUM_TRAIN_TIMESTEPS = 1000

# Prefer the second GPU (the original choice), but fall back to the first GPU
# or the CPU so the script still runs on single-GPU / CPU-only machines
# instead of failing on a hard-coded "cuda:1".
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = DiffusionModelUNet_modality(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 256),
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=32,
    norm_num_groups=8,
    num_modalities=NUM_MODALITIES,
    modality_emb_dim=32,
    with_conditioning=True,
    cross_attention_dim=32,
)
model.to(device)

# Check the input and output of this model in inference mode.
model.eval()
# Batch size
batch_size = 2
H, W = 128, 128  # 2D image size

# Input image, shape (N, C=1, H, W); allocated directly on the target device.
x = torch.randn(batch_size, 1, H, W, device=device)

# One random diffusion timestep per sample.
timesteps = torch.randint(
    0, NUM_TRAIN_TIMESTEPS, (batch_size,), dtype=torch.long, device=device
)

# Source and target modality labels, shape (batch_size, 1).
source_modality_labels = torch.randint(
    0, NUM_MODALITIES, (batch_size, 1), dtype=torch.long, device=device
)
target_modality_labels = torch.randint(
    0, NUM_MODALITIES, (batch_size, 1), dtype=torch.long, device=device
)

print('timesteps shape', timesteps.shape)
print('input label shape', source_modality_labels.shape)
print('input label', source_modality_labels)
print('target label', target_modality_labels)

# Context may be None: per the original author's note, the model's forward
# builds its conditioning from the modality labels (NOTE(review): confirm).
context = None

# Class labels are only needed if the model was built with num_class_embeds;
# it is not configured here, so pass None.
class_labels = None


with torch.no_grad():
    output = model(
        x=x,
        timesteps=timesteps,
        context=context,
        class_labels=class_labels,
        source_modality_labels=source_modality_labels,
        target_modality_labels=target_modality_labels,
    )

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")