"""Debug UNet channel dimensions.

Instantiates a default ``UNet2DConditionModel`` and prints, for each down
block, its resnet count. For each up block, it prints the in/out/skip channel
counts recomputed from ``block_out_channels`` next to the ``conv1.in_channels``
actually used by that block's resnets, so the two can be compared by eye.
"""


def expected_up_channels(block_out_channels, index):
    """Return the ``(in, out, skip)`` channel counts for up block ``index``.

    The counts are recomputed purely from ``block_out_channels``:

    * up blocks walk ``block_out_channels`` in reverse order;
    * ``out`` is the reversed entry at ``index``;
    * ``in`` is the previous up block's ``out`` (for the first block, the
      last entry of ``block_out_channels``);
    * ``skip`` is the next reversed entry, clamped to the final one.

    Raises:
        IndexError: if ``index`` has no corresponding entry in
            ``block_out_channels`` (e.g. more up blocks than channel levels).
    """
    reversed_channels = list(reversed(block_out_channels))
    # ``max(index - 1, 0)`` picks reversed_channels[0] for the first block,
    # i.e. block_out_channels[-1], matching the original special case.
    in_channels = reversed_channels[max(index - 1, 0)]
    output_channel = reversed_channels[index]
    skip_channels = reversed_channels[min(index + 1, len(reversed_channels) - 1)]
    return in_channels, output_channel, skip_channels


def main():
    """Build a default UNet and print its per-block channel layout."""
    # Imported here rather than at module top so ``expected_up_channels`` can
    # be imported (e.g. from tests) without the model package being available.
    from bytedream.model import UNet2DConditionModel

    unet = UNet2DConditionModel()
    print("Block out channels:", unet.block_out_channels)

    print("\nDown blocks:")
    for i, block in enumerate(unet.down_blocks):
        print(f" Down {i}: {len(block.resnets)} resnets")

    print("\nUp blocks:")
    for i, block in enumerate(unet.up_blocks):
        in_channels, output_channel, skip_channels = expected_up_channels(
            unet.block_out_channels, i
        )
        print(f" Up {i}: in={in_channels}, out={output_channel}, skips={skip_channels}")
        # Channels each resnet's first conv was actually built to accept.
        resnet_in_channels = [resnet.conv1.in_channels for resnet in block.resnets]
        print(f" ResNets expect: {resnet_in_channels}")


if __name__ == "__main__":
    main()