"""Debug UNet channel dimensions""" from bytedream.model import UNet2DConditionModel unet = UNet2DConditionModel() print("Block out channels:", unet.block_out_channels) print("\nDown blocks:") for i, block in enumerate(unet.down_blocks): print(f" Down {i}: {len(block.resnets)} resnets") print("\nUp blocks:") reversed_block_out_channels = list(reversed(unet.block_out_channels)) for i, block in enumerate(unet.up_blocks): in_channels = unet.block_out_channels[-1] if i == 0 else reversed_block_out_channels[i - 1] output_channel = reversed_block_out_channels[i] skip_channels = reversed_block_out_channels[min(i + 1, len(unet.block_out_channels) - 1)] print(f" Up {i}: in={in_channels}, out={output_channel}, skips={skip_channels}") print(f" ResNets expect: {[block.resnets[j].conv1.in_channels for j in range(len(block.resnets))]}")