File size: 889 Bytes
0e3999b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Debug UNet channel dimensions"""
from bytedream.model import UNet2DConditionModel

unet = UNet2DConditionModel()

print("Block out channels:", unet.block_out_channels)

# Down path: report only how many ResNet layers each down block holds.
print("\nDown blocks:")
for i, block in enumerate(unet.down_blocks):
    print(f"  Down {i}: {len(block.resnets)} resnets")

# Up path: recompute the channel counts each up block is expected to see and
# print them next to the input width its ResNets were actually built with, so
# a mismatch in the skip-connection wiring is visible at a glance.
# NOTE(review): these formulas look like diffusers' up-block channel
# bookkeeping; assumes bytedream's UNet follows the same scheme — confirm.
print("\nUp blocks:")
reversed_block_out_channels = list(reversed(unet.block_out_channels))
last_level = len(reversed_block_out_channels) - 1
for i, block in enumerate(unet.up_blocks):
    # Input is the previous up block's output width; the first up block has
    # no predecessor and takes the deepest level's width (reversed[0]).
    in_channels = reversed_block_out_channels[max(i - 1, 0)]
    output_channel = reversed_block_out_channels[i]
    # Skip width: one entry further along the reversed list, clamped so the
    # final up block reuses the last entry instead of indexing past the end.
    skip_channels = reversed_block_out_channels[min(i + 1, last_level)]
    print(f"  Up {i}: in={in_channels}, out={output_channel}, skips={skip_channels}")
    # Actual first-conv input width of every ResNet in this up block.
    resnet_in_channels = [resnet.conv1.in_channels for resnet in block.resnets]
    print(f"    ResNets expect: {resnet_in_channels}")