In [1]:
# Keyword arguments unpacked into UNet2DConditionModel(**config_sdxs) further down.
config_sdxs = dict(
    # --- Latent input/output sizes ---
    in_channels=16,    # number of input channels (original note: chosen for VAE compatibility)
    out_channels=16,   # number of output channels (kept symmetric with in_channels)

    # --- Cross-attention ---
    cross_attention_dim=1024,  # dimensionality of the text embeddings
    use_linear_projection=True,
    norm_num_groups=32,

    # --- Block layout ---
    # Encoder: one plain block followed by three cross-attention blocks.
    down_block_types=["DownBlock2D"] + ["CrossAttnDownBlock2D"] * 3,
    # Decoder: three cross-attention blocks followed by one plain block.
    up_block_types=["CrossAttnUpBlock2D"] * 3 + ["UpBlock2D"],

    # --- Per-level settings (one entry per block listed above) ---
    block_out_channels=[256, 512, 1024, 1024],
    transformer_layers_per_block=[1, 1, 1, 8],
    attention_head_dim=[4, 8, 16, 16],
)

def check_initialization(model):
    """Print the mean and standard deviation of each trainable parameter.

    Walks ``model.named_parameters()`` and, for every parameter whose
    ``requires_grad`` flag is truthy, prints one line formatted as
    ``<name>: mean=<value>, std=<value>`` with three decimal places.
    Parameters that do not require gradients are skipped.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        None. The statistics are only written to standard output.
    """
    trainable = (
        (label, weight)
        for label, weight in model.named_parameters()
        if weight.requires_grad
    )
    for label, weight in trainable:
        average = weight.data.mean()
        spread = weight.data.std()
        print(f"{label}: mean={average:.3f}, std={spread:.3f}")


if True:  # notebook-style switch: the smoke test below always runs
    import torch
    from diffusers import UNet2DConditionModel

    # Directory the freshly initialised model is written to
    # (the original cell also noted "sdxs" as an alternative value).
    checkpoint_path = "/workspace/sdxs3d/butterfly"

    print("test unet")

    # Build the UNet from config_sdxs, then move it to the GPU in float16.
    # (To try the library defaults instead, construct UNet2DConditionModel()
    # with no arguments.)
    new_unet = UNet2DConditionModel(**config_sdxs).to(device="cuda", dtype=torch.float16)

    # Optional diagnostics, disabled in the original cell:
    #   - check_initialization(new_unet) prints per-parameter mean/std;
    #   - an assertion that every block_out_channels entry is a multiple of 32.

    # Total number of scalar weights in the model.
    num_params = sum(map(torch.Tensor.numel, new_unet.parameters()))
    print(f"Количество параметров: {num_params}")

    # Dummy inputs for one forward pass. The latent is (1, 16, 60, 48); the
    # original note says this corresponds to roughly a 512 px image, which
    # depends on the VAE downscale factor — TODO confirm.
    test_latent = torch.randn(1, 16, 60, 48).to(device="cuda", dtype=torch.float16)
    timesteps = torch.tensor([1]).to(device="cuda", dtype=torch.float16)
    # Random stand-in for text embeddings; the last dim matches cross_attention_dim (1024).
    encoder_hidden_states = torch.randn(1, 77, 1024).to(device="cuda", dtype=torch.float16)

    # Inference only: no gradients are needed for the shape check.
    with torch.no_grad():
        output = new_unet(
            sample=test_latent,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
        ).sample

    print(f"Output shape: {output.shape}")

    # Persist the randomly initialised model, then dump its module tree.
    new_unet.save_pretrained(checkpoint_path)
    print(new_unet)

test unet
Количество параметров: 1546186256
Output shape: torch.Size([1, 16, 60, 48])
UNet2DConditionModel(
  (conv_in): Conv2d(16, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=256, out_features=1024, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=1024, out_features=256, bias=True)
          (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
      