{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a0b1c2d3-0000-4000-8000-unet-title-md",
   "metadata": {},
   "source": [
    "# SDXS UNet smoke test\n",
    "\n",
    "Builds a custom `UNet2DConditionModel` (16 latent channels, 1024-dim cross-attention),\n",
    "counts its parameters, runs a single fp16 forward pass on CUDA to verify the output\n",
    "shape matches the input latent, and saves the initialized weights to `CHECKPOINT_PATH`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c2d3e4-0000-4000-8000-unet-imports",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import UNet2DConditionModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5212f806-14b4-4b5f-bcb4-09e36df3b7d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where the freshly initialized UNet is saved.\n",
    "# TODO: make this configurable / relative instead of an absolute container path.\n",
    "CHECKPOINT_PATH = \"/workspace/sdxs3d/butterfly\"\n",
    "\n",
    "config_sdxs = {\n",
    "    # === Main dimensions and channels ===\n",
    "    \"in_channels\": 16,   # number of input channels (matches the VAE latent channels)\n",
    "    \"out_channels\": 16,  # number of output channels (symmetric with in_channels)\n",
    "\n",
    "    # === Cross-attention ===\n",
    "    \"cross_attention_dim\": 1024,  # dimensionality of the text encoder embeddings\n",
    "    \"use_linear_projection\": True,\n",
    "    \"norm_num_groups\": 32,\n",
    "\n",
    "    # === Block architecture ===\n",
    "    \"down_block_types\": [  # encoder\n",
    "        \"DownBlock2D\",\n",
    "        \"CrossAttnDownBlock2D\",\n",
    "        \"CrossAttnDownBlock2D\",\n",
    "        \"CrossAttnDownBlock2D\",\n",
    "    ],\n",
    "    \"up_block_types\": [  # decoder\n",
    "        \"CrossAttnUpBlock2D\",\n",
    "        \"CrossAttnUpBlock2D\",\n",
    "        \"CrossAttnUpBlock2D\",\n",
    "        \"UpBlock2D\",\n",
    "    ],\n",
    "\n",
    "    # === Channel configuration ===\n",
    "    \"block_out_channels\": [256, 512, 1024, 1024],\n",
    "\n",
    "    # One transformer layer per block except the deepest level, which gets 8.\n",
    "    \"transformer_layers_per_block\": [1, 1, 1, 8],\n",
    "    \"attention_head_dim\": [4, 8, 16, 16],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2d3e4f5-0000-4000-8000-unet-helpers",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_initialization(model):\n",
    "    \"\"\"Print mean/std of every trainable parameter (sanity check of weight init).\"\"\"\n",
    "    for name, param in model.named_parameters():\n",
    "        if param.requires_grad:\n",
    "            print(f\"{name}: mean={param.data.mean():.3f}, std={param.data.std():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3e4f5a6-0000-4000-8000-unet-run-test",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"test unet\")\n",
    "new_unet = UNet2DConditionModel(**config_sdxs).to(\"cuda\", dtype=torch.float16)\n",
    "\n",
    "num_params = sum(p.numel() for p in new_unet.parameters())\n",
    "print(f\"Количество параметров: {num_params}\")\n",
    "\n",
    "# Test latent: 60x48 in latent space -> 480x384 px with the standard 8x VAE downscale.\n",
    "test_latent = torch.randn(1, 16, 60, 48).to(\"cuda\", dtype=torch.float16)\n",
    "timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
    "# 77 tokens x 1024 dims, matching cross_attention_dim in the config above.\n",
    "encoder_hidden_states = torch.randn(1, 77, 1024).to(\"cuda\", dtype=torch.float16)\n",
    "\n",
    "with torch.no_grad():\n",
    "    output = new_unet(\n",
    "        test_latent,\n",
    "        timesteps,\n",
    "        encoder_hidden_states,\n",
    "    ).sample\n",
    "\n",
    "print(f\"Output shape: {output.shape}\")\n",
    "# The UNet is shape-preserving: out_channels == in_channels, spatial dims unchanged.\n",
    "assert output.shape == test_latent.shape, \"unexpected output shape\"\n",
    "\n",
    "new_unet.save_pretrained(CHECKPOINT_PATH)\n",
    "print(new_unet)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}