ByteDream / debug_unet.py
Enzo8930302's picture
Upload folder using huggingface_hub
0e3999b verified
raw
history blame contribute delete
889 Bytes
"""Debug UNet channel dimensions"""
from bytedream.model import UNet2DConditionModel
# Build the model with its default configuration and dump channel widths so a
# mismatch between the expected wiring and the actual resnet inputs is visible.
unet = UNet2DConditionModel()
print("Block out channels:", unet.block_out_channels)

# Down path: only report how many resnets each down block holds.
print("\nDown blocks:")
for i, block in enumerate(unet.down_blocks):
    print(f" Down {i}: {len(block.resnets)} resnets")

# Up path: the in/out/skip numbers below are *derived* from block_out_channels
# alone; the "ResNets expect" line is *read* from the instantiated model, so
# comparing the two shows whether the model was built with the same wiring.
# NOTE(review): the derivation assumes a diffusers-style up path (first up block
# fed from the deepest level, each later block fed from the previous block's
# output, skip width taken from the next-shallower level) — confirm against
# bytedream.model.
print("\nUp blocks:")
reversed_block_out_channels = list(reversed(unet.block_out_channels))
last_index = len(reversed_block_out_channels) - 1
for i, block in enumerate(unet.up_blocks):
    # Block 0 takes the deepest width (reversed[0]); block i > 0 takes the
    # previous up block's output width (reversed[i - 1]).
    in_channels = reversed_block_out_channels[max(i - 1, 0)]
    output_channel = reversed_block_out_channels[i]
    # Clamp so the shallowest up block reuses the last level's width.
    skip_channels = reversed_block_out_channels[min(i + 1, last_index)]
    print(f" Up {i}: in={in_channels}, out={output_channel}, skips={skip_channels}")
    conv1_inputs = [resnet.conv1.in_channels for resnet in block.resnets]
    print(f" ResNets expect: {conv1_inputs}")