recoilme committed on
Commit
5bbef16
·
1 Parent(s): b9a43be
.gitignore CHANGED
@@ -7,7 +7,7 @@ __pycache__/
7
  src/samples
8
  # cache
9
  cache
10
- datasets
11
  test
12
  wandb
13
  nohup.out
 
7
  src/samples
8
  # cache
9
  cache
10
+ # datasets
11
  test
12
  wandb
13
  nohup.out
butterfly.zip → datasets/butterfly/data-00000-of-00001.arrow RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5b923bef9a5d1fe7103e960c943c110ec46155fc71d7f45e0070f3ef072bbdcb
3
- size 237918081
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8479e8b4cf0c3505189c608cedf8b35ab073f14c6b7db0a9e66b75925e1c519
3
+ size 53255512
datasets/butterfly/dataset_info.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "image_path": {
6
+ "dtype": "string",
7
+ "_type": "Value"
8
+ },
9
+ "text": {
10
+ "dtype": "string",
11
+ "_type": "Value"
12
+ },
13
+ "vae": {
14
+ "feature": {
15
+ "feature": {
16
+ "feature": {
17
+ "dtype": "float16",
18
+ "_type": "Value"
19
+ },
20
+ "_type": "List"
21
+ },
22
+ "_type": "List"
23
+ },
24
+ "_type": "List"
25
+ },
26
+ "embeddings": {
27
+ "feature": {
28
+ "feature": {
29
+ "dtype": "float32",
30
+ "_type": "Value"
31
+ },
32
+ "_type": "List"
33
+ },
34
+ "_type": "List"
35
+ },
36
+ "width": {
37
+ "dtype": "int64",
38
+ "_type": "Value"
39
+ },
40
+ "height": {
41
+ "dtype": "int64",
42
+ "_type": "Value"
43
+ }
44
+ },
45
+ "homepage": "",
46
+ "license": ""
47
+ }
datasets/butterfly/state.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "23217366db2250df",
8
+ "_format_columns": [
9
+ "image_path",
10
+ "text",
11
+ "vae",
12
+ "embeddings",
13
+ "width",
14
+ "height"
15
+ ],
16
+ "_format_kwargs": {},
17
+ "_format_type": null,
18
+ "_output_all_columns": false,
19
+ "_split": null
20
+ }
requirements.txt CHANGED
@@ -1,6 +1,3 @@
1
- # torch>=2.6.0
2
- # torchvision>=0.21.0
3
- # torchaudio>=2.6.0
4
  diffusers>=0.32.2
5
  accelerate>=1.5.2
6
  datasets>=3.5.0
 
 
 
 
1
  diffusers>=0.32.2
2
  accelerate>=1.5.2
3
  datasets>=3.5.0
samples/unet_192x384_0.jpg ADDED

Git LFS Details

  • SHA256: f8d593e4370ad0523c154177853b0a0494792995654d3d7d1fc8baf43e519d35
  • Pointer size: 130 Bytes
  • Size of remote file: 26 kB
samples/unet_256x384_0.jpg ADDED

Git LFS Details

  • SHA256: 3d972079e6277012a49509f8d529cbb5540acdaa1e32267c3879db2ecce1c7a3
  • Pointer size: 130 Bytes
  • Size of remote file: 46.7 kB
samples/unet_320x384_0.jpg ADDED

Git LFS Details

  • SHA256: 7a9ea056f95fdc5041d89e342a6ba83f37de3280771b4d3907a51d558d75bf83
  • Pointer size: 130 Bytes
  • Size of remote file: 59.7 kB
samples/unet_384x192_0.jpg ADDED

Git LFS Details

  • SHA256: 55b102073dcf0bc45d7b7ca96d0b151d638f7875f1da86c5f2ea44fddf1e2e72
  • Pointer size: 130 Bytes
  • Size of remote file: 26.4 kB
samples/unet_384x256_0.jpg ADDED

Git LFS Details

  • SHA256: 3f9ab248800ebf6c8c52b88ed8ee3bd27ec6f104ceae35d04944f988f6b99c33
  • Pointer size: 130 Bytes
  • Size of remote file: 25.1 kB
samples/unet_384x320_0.jpg ADDED

Git LFS Details

  • SHA256: acc0960f81837a0e161f9faa57916819f0325015720173cc1a13b997e0aa0631
  • Pointer size: 130 Bytes
  • Size of remote file: 55.3 kB
src/dataset_from_folder.py CHANGED
@@ -24,10 +24,8 @@ batch_size = 5
24
  min_size = 192 #256 #192
25
  max_size = 384 #256 #384
26
  step = 64
27
- img_share = 1.0
28
  empty_share = 0.05
29
  limit = 0
30
- textemb_full = False
31
  # Основная процедура обработки
32
  folder_path = "/workspace/butterfly" #alchemist"
33
  save_path = "/workspace/sdxs3d/datasets/butterfly" #"alchemist"
@@ -44,18 +42,13 @@ def clear_cuda_memory():
44
  # ---------------- 2️⃣ Загрузка моделей ----------------
45
  def load_models():
46
  print("Загрузка моделей...")
47
- #vae = AutoencoderKLWan.from_pretrained("AiArtLab/simplevae",subfolder="wan16x_vae_nightly",torch_dtype=dtype).to(device).eval()
48
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", subfolder=None,torch_dtype=dtype).to(device).eval()
49
-
50
- #vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to(device).eval()
51
- #vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell",subfolder="vae",torch_dtype=dtype).to(device).eval()
52
- #vae = AutoencoderKL.from_pretrained("/home/recoilme/sdxs/vae", variant="fp16",torch_dtype=dtype).to(device).eval()
53
- model = AutoModel.from_pretrained("visheratin/mexma-siglip2", dtype=dtype, trust_remote_code=True, optimized=True).to(device).eval()
54
- processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip2", use_fast=True)
55
- tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip2")
56
- return vae, model, processor, tokenizer
57
 
58
- vae, model, processor, tokenizer = load_models()
59
 
60
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
61
  if shift_factor is None:
@@ -124,57 +117,39 @@ def get_image_transform(min_size=256, max_size=512, step=64):
124
  return transform
125
 
126
  # ---------------- 4️⃣ Функции обработки ----------------
127
- def encode_images_batch(images, processor, model, empty_share=0.0):
128
- """
129
- images: список PIL.Image
130
- processor: трансформер для препроцессинга изображений
131
- model: vision encoder (например, CLIP или подобный)
132
- empty_share: доля эмбеддингов, которые нужно обнулить
133
- """
134
- # Преобразуем весь батч сразу (вместо обхода по каждому изображению)
135
- processed = processor(images=images, return_tensors="pt")
136
- pixel_values = processed["pixel_values"].to(device, dtype)
137
-
 
138
  with torch.inference_mode():
139
- outputs = model.vision_model(pixel_values)
140
- #hidden_states = outputs.last_hidden_state # [B, seq_len, dim]
141
- pooled = outputs.pooler_output # [B, dim]
142
-
143
- # Добавляем pooled embedding в конец sequence
144
- #context = torch.cat([hidden_states, pooled.unsqueeze(1)], dim=1) # [B, seq_len+1, dim]
145
- context = pooled.unsqueeze(1)
 
146
 
147
- # Добавляем нулевые эмбеддинги с вероятностью empty_share
148
- if empty_share > 0:
149
- batch_size = context.shape[0]
150
- num_empty = int(batch_size * empty_share)
151
- if num_empty > 0:
152
- zero_embeddings = torch.zeros_like(context[:num_empty])
153
- context[:num_empty] = zero_embeddings
154
 
155
- # Преобразуем bfloat16 в float32 если нужно
156
- if context.dtype == torch.bfloat16:
157
- context = context.to(torch.float32)
158
 
159
- return context.cpu().numpy() # [B, seq_len+1, dim]
 
 
160
 
161
-
162
- def encode_texts_batch(texts, tokenizer, model):
163
- with torch.inference_mode():
164
- text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length",
165
- max_length=512,
166
- truncation=True).to(device)
167
- text_embeddings = model.encode_texts(text_tokenized.input_ids, text_tokenized.attention_mask)
168
- return text_embeddings.unsqueeze(1).cpu().numpy()
169
-
170
- def encode_texts_batch_full(texts, tokenizer, model):
171
- with torch.inference_mode():
172
- text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length",max_length=512,truncation=True).to(device)
173
- features = model.text_model(
174
- input_ids=text_tokenized.input_ids, attention_mask=text_tokenized.attention_mask
175
- ).last_hidden_state
176
- features_proj = model.text_projector(features)
177
- return features_proj.cpu().numpy()
178
 
179
  def clean_label(label):
180
  label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "")
@@ -236,28 +211,15 @@ def encode_to_latents(images, texts):
236
  # Кодируем батч
237
  with torch.no_grad():
238
  posteriors = vae.encode(batch_tensor).latent_dist.mode()
239
-
240
  latents = (posteriors - shift_factor) / scaling_factor
241
-
242
- if latents_mean!=None and latents_std!=None:
243
- latents = (latents - torch.tensor(latents_mean, device=device, dtype=dtype).view(1, -1, 1, 1, 1)) / torch.tensor(latents_std, device=device, dtype=dtype).view(1, -1, 1, 1, 1)
244
- #print(latents.ndim, latents.shape)
245
- if latents.ndim==5:
246
- latents = latents[:, :, 0, :, :] # Убираем временную ось [B, C, H, W]
247
 
248
  latents_np = latents.to(dtype).cpu().numpy()
249
 
250
  # Обрабатываем тексты
251
  text_labels = [clean_label(text) for text in texts]
252
- if random.random() < img_share:
253
- embeddings = encode_images_batch(pil_images, processor, model)
254
- text_labels = [f"img: {label}" for label in text_labels]
255
- else:
256
- model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
257
- if textemb_full:
258
- embeddings = encode_texts_batch_full(model_prompts, tokenizer, model)
259
- else:
260
- embeddings = encode_texts_batch(model_prompts, tokenizer, model)
261
 
262
  return {
263
  "vae": latents_np,
 
24
  min_size = 192 #256 #192
25
  max_size = 384 #256 #384
26
  step = 64
 
27
  empty_share = 0.05
28
  limit = 0
 
29
  # Основная процедура обработки
30
  folder_path = "/workspace/butterfly" #alchemist"
31
  save_path = "/workspace/sdxs3d/datasets/butterfly" #"alchemist"
 
42
  # ---------------- 2️⃣ Загрузка моделей ----------------
43
  def load_models():
44
  print("Загрузка моделей...")
45
+ vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to(device).eval()
46
+
47
+ tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
48
+ model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B').to("cuda")
49
+ return vae, model, tokenizer
 
 
 
 
 
50
 
51
+ vae, model, tokenizer = load_models()
52
 
53
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
54
  if shift_factor is None:
 
117
  return transform
118
 
119
  # ---------------- 4️⃣ Функции обработки ----------------
120
+ def last_token_pool(last_hidden_states: torch.Tensor,
121
+ attention_mask: torch.Tensor) -> torch.Tensor:
122
+ # Определяем, есть ли left padding
123
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
124
+ if left_padding:
125
+ return last_hidden_states[:, -1]
126
+ else:
127
+ sequence_lengths = attention_mask.sum(dim=1) - 1
128
+ batch_size = last_hidden_states.shape[0]
129
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
130
+
131
+ def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=512, normalize=False):
132
  with torch.inference_mode():
133
+ # Токенизация
134
+ batch = tokenizer(
135
+ texts,
136
+ return_tensors="pt",
137
+ padding="max_length",
138
+ truncation=True,
139
+ max_length=max_length
140
+ ).to(device)
141
 
142
+ # Прогон через модель
143
+ outputs = model(**batch)
 
 
 
 
 
144
 
145
+ # Пулинг по last token
146
+ embeddings = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])
 
147
 
148
+ # L2-нормализация (опционально, обычно нужна для семантического поиска)
149
+ if normalize:
150
+ embeddings = F.normalize(embeddings, p=2, dim=1)
151
 
152
+ return embeddings.unsqueeze(1).cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  def clean_label(label):
155
  label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "")
 
211
  # Кодируем батч
212
  with torch.no_grad():
213
  posteriors = vae.encode(batch_tensor).latent_dist.mode()
 
214
  latents = (posteriors - shift_factor) / scaling_factor
 
 
 
 
 
 
215
 
216
  latents_np = latents.to(dtype).cpu().numpy()
217
 
218
  # Обрабатываем тексты
219
  text_labels = [clean_label(text) for text in texts]
220
+
221
+ model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
222
+ embeddings = encode_texts_batch(model_prompts, tokenizer, model)
 
 
 
 
 
 
223
 
224
  return {
225
  "vae": latents_np,
src/dataset_sample.ipynb CHANGED
@@ -202,12 +202,8 @@
202
  " \n",
203
  " # Загрузка VAE модели\n",
204
  " print(\"Загрузка VAE модели...\")\n",
205
- " #vae = AutoencoderKLWan.from_pretrained(\n",
206
- " # \"AiArtLab/simplevae\", subfolder=\"wan16x_vae_nightly\",\n",
207
- " # torch_dtype=dtype\n",
208
- " #).to(device).eval()\n",
209
- " vae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", subfolder=None,torch_dtype=dtype).to(device).eval()\n",
210
- "\n",
211
  " shift_factor = getattr(vae.config, \"shift_factor\", 0.0)\n",
212
  " if shift_factor is None:\n",
213
  " shift_factor = 0.0\n",
@@ -248,8 +244,6 @@
248
  " print(f\"\\n--- Батч {width}x{height}: {count} примеров ---\")\n",
249
  " \n",
250
  " latent = torch.tensor(example[\"vae\"], dtype=dtype).to(device)\n",
251
- " #if latent.ndim == 3:\n",
252
- " # latent = latent.unsqueeze(1)\n",
253
  " # Латент в форме [C, T, H, W]\n",
254
  " print(latent.ndim, latent.shape)\n",
255
  " with torch.no_grad():\n",
@@ -331,7 +325,7 @@
331
  "name": "python",
332
  "nbconvert_exporter": "python",
333
  "pygments_lexer": "ipython3",
334
- "version": "3.11.10"
335
  }
336
  },
337
  "nbformat": 4,
 
202
  " \n",
203
  " # Загрузка VAE модели\n",
204
  " print(\"Загрузка VAE модели...\")\n",
205
+ " vae = AutoencoderKL.from_pretrained(\"AiArtLab/simplevae\",subfolder=\"simple_vae_nightly\",torch_dtype=dtype).to(device).eval()\n",
206
+ " \n",
 
 
 
 
207
  " shift_factor = getattr(vae.config, \"shift_factor\", 0.0)\n",
208
  " if shift_factor is None:\n",
209
  " shift_factor = 0.0\n",
 
244
  " print(f\"\\n--- Батч {width}x{height}: {count} примеров ---\")\n",
245
  " \n",
246
  " latent = torch.tensor(example[\"vae\"], dtype=dtype).to(device)\n",
 
 
247
  " # Латент в форме [C, T, H, W]\n",
248
  " print(latent.ndim, latent.shape)\n",
249
  " with torch.no_grad():\n",
 
325
  "name": "python",
326
  "nbconvert_exporter": "python",
327
  "pygments_lexer": "ipython3",
328
+ "version": "3.11.11"
329
  }
330
  },
331
  "nbformat": 4,
src/model_create.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 2,
6
  "id": "5212f806-14b4-4b5f-bcb4-09e36df3b7d9",
7
  "metadata": {},
8
  "outputs": [
@@ -11,164 +11,223 @@
11
  "output_type": "stream",
12
  "text": [
13
  "test unet\n",
14
- "Количество параметров: 1616742724\n",
15
- "Output shape: torch.Size([1, 4, 60, 48])\n",
16
  "UNet2DConditionModel(\n",
17
- " (conv_in): Conv2d(4, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
18
  " (time_proj): Timesteps()\n",
19
  " (time_embedding): TimestepEmbedding(\n",
20
- " (linear_1): Linear(in_features=288, out_features=1152, bias=True)\n",
21
  " (act): SiLU()\n",
22
- " (linear_2): Linear(in_features=1152, out_features=1152, bias=True)\n",
23
  " )\n",
24
  " (down_blocks): ModuleList(\n",
25
  " (0): DownBlock2D(\n",
26
  " (resnets): ModuleList(\n",
27
  " (0-1): 2 x ResnetBlock2D(\n",
28
- " (norm1): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
29
- " (conv1): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
30
- " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
31
- " (norm2): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
32
  " (dropout): Dropout(p=0.0, inplace=False)\n",
33
- " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
34
  " (nonlinearity): SiLU()\n",
35
  " )\n",
36
  " )\n",
37
  " (downsamplers): ModuleList(\n",
38
  " (0): Downsample2D(\n",
39
- " (conv): Conv2d(288, 288, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
40
  " )\n",
41
  " )\n",
42
  " )\n",
43
  " (1): CrossAttnDownBlock2D(\n",
44
  " (attentions): ModuleList(\n",
45
  " (0-1): 2 x Transformer2DModel(\n",
46
- " (norm): GroupNorm(32, 576, eps=1e-06, affine=True)\n",
47
- " (proj_in): Linear(in_features=576, out_features=576, bias=True)\n",
48
  " (transformer_blocks): ModuleList(\n",
49
  " (0): BasicTransformerBlock(\n",
50
- " (norm1): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
51
  " (attn1): Attention(\n",
52
- " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
53
- " (to_k): Linear(in_features=576, out_features=576, bias=False)\n",
54
- " (to_v): Linear(in_features=576, out_features=576, bias=False)\n",
55
  " (to_out): ModuleList(\n",
56
- " (0): Linear(in_features=576, out_features=576, bias=True)\n",
57
  " (1): Dropout(p=0.0, inplace=False)\n",
58
  " )\n",
59
  " )\n",
60
- " (norm2): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
61
  " (attn2): Attention(\n",
62
- " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
63
- " (to_k): Linear(in_features=1152, out_features=576, bias=False)\n",
64
- " (to_v): Linear(in_features=1152, out_features=576, bias=False)\n",
65
  " (to_out): ModuleList(\n",
66
- " (0): Linear(in_features=576, out_features=576, bias=True)\n",
67
  " (1): Dropout(p=0.0, inplace=False)\n",
68
  " )\n",
69
  " )\n",
70
- " (norm3): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
71
  " (ff): FeedForward(\n",
72
  " (net): ModuleList(\n",
73
  " (0): GEGLU(\n",
74
- " (proj): Linear(in_features=576, out_features=4608, bias=True)\n",
75
  " )\n",
76
  " (1): Dropout(p=0.0, inplace=False)\n",
77
- " (2): Linear(in_features=2304, out_features=576, bias=True)\n",
78
  " )\n",
79
  " )\n",
80
  " )\n",
81
  " )\n",
82
- " (proj_out): Linear(in_features=576, out_features=576, bias=True)\n",
83
  " )\n",
84
  " )\n",
85
  " (resnets): ModuleList(\n",
86
  " (0): ResnetBlock2D(\n",
87
- " (norm1): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
88
- " (conv1): Conv2d(288, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
89
- " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
90
- " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
91
  " (dropout): Dropout(p=0.0, inplace=False)\n",
92
- " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
93
  " (nonlinearity): SiLU()\n",
94
- " (conv_shortcut): Conv2d(288, 576, kernel_size=(1, 1), stride=(1, 1))\n",
95
  " )\n",
96
  " (1): ResnetBlock2D(\n",
97
- " (norm1): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
98
- " (conv1): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
99
- " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
100
- " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
101
  " (dropout): Dropout(p=0.0, inplace=False)\n",
102
- " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
103
  " (nonlinearity): SiLU()\n",
104
  " )\n",
105
  " )\n",
106
  " (downsamplers): ModuleList(\n",
107
  " (0): Downsample2D(\n",
108
- " (conv): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
109
  " )\n",
110
  " )\n",
111
  " )\n",
112
  " (2): CrossAttnDownBlock2D(\n",
113
  " (attentions): ModuleList(\n",
114
  " (0-1): 2 x Transformer2DModel(\n",
115
- " (norm): GroupNorm(32, 1152, eps=1e-06, affine=True)\n",
116
- " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
117
  " (transformer_blocks): ModuleList(\n",
118
- " (0-7): 8 x BasicTransformerBlock(\n",
119
- " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
120
  " (attn1): Attention(\n",
121
- " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
122
- " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
123
- " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
124
  " (to_out): ModuleList(\n",
125
- " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
126
  " (1): Dropout(p=0.0, inplace=False)\n",
127
  " )\n",
128
  " )\n",
129
- " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
130
  " (attn2): Attention(\n",
131
- " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
132
- " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
133
- " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
134
  " (to_out): ModuleList(\n",
135
- " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
136
  " (1): Dropout(p=0.0, inplace=False)\n",
137
  " )\n",
138
  " )\n",
139
- " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
140
  " (ff): FeedForward(\n",
141
  " (net): ModuleList(\n",
142
  " (0): GEGLU(\n",
143
- " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
144
  " )\n",
145
  " (1): Dropout(p=0.0, inplace=False)\n",
146
- " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
147
  " )\n",
148
  " )\n",
149
  " )\n",
150
  " )\n",
151
- " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
152
  " )\n",
153
  " )\n",
154
  " (resnets): ModuleList(\n",
155
  " (0): ResnetBlock2D(\n",
156
- " (norm1): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
157
- " (conv1): Conv2d(576, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
158
- " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
159
- " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
160
  " (dropout): Dropout(p=0.0, inplace=False)\n",
161
- " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
162
  " (nonlinearity): SiLU()\n",
163
- " (conv_shortcut): Conv2d(576, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
164
  " )\n",
165
  " (1): ResnetBlock2D(\n",
166
- " (norm1): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
167
- " (conv1): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
168
- " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
169
- " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
170
  " (dropout): Dropout(p=0.0, inplace=False)\n",
171
- " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  " (nonlinearity): SiLU()\n",
173
  " )\n",
174
  " )\n",
@@ -178,174 +237,234 @@
178
  " (0): CrossAttnUpBlock2D(\n",
179
  " (attentions): ModuleList(\n",
180
  " (0-2): 3 x Transformer2DModel(\n",
181
- " (norm): GroupNorm(32, 1152, eps=1e-06, affine=True)\n",
182
- " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
183
  " (transformer_blocks): ModuleList(\n",
184
  " (0-7): 8 x BasicTransformerBlock(\n",
185
- " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
186
  " (attn1): Attention(\n",
187
- " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
188
- " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
189
- " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
190
  " (to_out): ModuleList(\n",
191
- " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
192
  " (1): Dropout(p=0.0, inplace=False)\n",
193
  " )\n",
194
  " )\n",
195
- " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
196
  " (attn2): Attention(\n",
197
- " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
198
- " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
199
- " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
200
  " (to_out): ModuleList(\n",
201
- " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
202
  " (1): Dropout(p=0.0, inplace=False)\n",
203
  " )\n",
204
  " )\n",
205
- " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
206
  " (ff): FeedForward(\n",
207
  " (net): ModuleList(\n",
208
  " (0): GEGLU(\n",
209
- " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
210
  " )\n",
211
  " (1): Dropout(p=0.0, inplace=False)\n",
212
- " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
213
  " )\n",
214
  " )\n",
215
  " )\n",
216
  " )\n",
217
- " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  " )\n",
219
  " )\n",
220
  " (resnets): ModuleList(\n",
221
  " (0-1): 2 x ResnetBlock2D(\n",
222
- " (norm1): GroupNorm(32, 2304, eps=1e-05, affine=True)\n",
223
- " (conv1): Conv2d(2304, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
224
- " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
225
- " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
226
  " (dropout): Dropout(p=0.0, inplace=False)\n",
227
- " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
228
  " (nonlinearity): SiLU()\n",
229
- " (conv_shortcut): Conv2d(2304, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
230
  " )\n",
231
  " (2): ResnetBlock2D(\n",
232
- " (norm1): GroupNorm(32, 1728, eps=1e-05, affine=True)\n",
233
- " (conv1): Conv2d(1728, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
234
- " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
235
- " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
236
  " (dropout): Dropout(p=0.0, inplace=False)\n",
237
- " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
238
  " (nonlinearity): SiLU()\n",
239
- " (conv_shortcut): Conv2d(1728, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
240
  " )\n",
241
  " )\n",
242
  " (upsamplers): ModuleList(\n",
243
  " (0): Upsample2D(\n",
244
- " (conv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
245
  " )\n",
246
  " )\n",
247
  " )\n",
248
- " (1): CrossAttnUpBlock2D(\n",
249
  " (attentions): ModuleList(\n",
250
  " (0-2): 3 x Transformer2DModel(\n",
251
- " (norm): GroupNorm(32, 576, eps=1e-06, affine=True)\n",
252
- " (proj_in): Linear(in_features=576, out_features=576, bias=True)\n",
253
  " (transformer_blocks): ModuleList(\n",
254
  " (0): BasicTransformerBlock(\n",
255
- " (norm1): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
256
  " (attn1): Attention(\n",
257
- " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
258
- " (to_k): Linear(in_features=576, out_features=576, bias=False)\n",
259
- " (to_v): Linear(in_features=576, out_features=576, bias=False)\n",
260
  " (to_out): ModuleList(\n",
261
- " (0): Linear(in_features=576, out_features=576, bias=True)\n",
262
  " (1): Dropout(p=0.0, inplace=False)\n",
263
  " )\n",
264
  " )\n",
265
- " (norm2): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
266
  " (attn2): Attention(\n",
267
- " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
268
- " (to_k): Linear(in_features=1152, out_features=576, bias=False)\n",
269
- " (to_v): Linear(in_features=1152, out_features=576, bias=False)\n",
270
  " (to_out): ModuleList(\n",
271
- " (0): Linear(in_features=576, out_features=576, bias=True)\n",
272
  " (1): Dropout(p=0.0, inplace=False)\n",
273
  " )\n",
274
  " )\n",
275
- " (norm3): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
276
  " (ff): FeedForward(\n",
277
  " (net): ModuleList(\n",
278
  " (0): GEGLU(\n",
279
- " (proj): Linear(in_features=576, out_features=4608, bias=True)\n",
280
  " )\n",
281
  " (1): Dropout(p=0.0, inplace=False)\n",
282
- " (2): Linear(in_features=2304, out_features=576, bias=True)\n",
283
  " )\n",
284
  " )\n",
285
  " )\n",
286
  " )\n",
287
- " (proj_out): Linear(in_features=576, out_features=576, bias=True)\n",
288
  " )\n",
289
  " )\n",
290
  " (resnets): ModuleList(\n",
291
  " (0): ResnetBlock2D(\n",
292
- " (norm1): GroupNorm(32, 1728, eps=1e-05, affine=True)\n",
293
- " (conv1): Conv2d(1728, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
294
- " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
295
- " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
296
  " (dropout): Dropout(p=0.0, inplace=False)\n",
297
- " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
298
  " (nonlinearity): SiLU()\n",
299
- " (conv_shortcut): Conv2d(1728, 576, kernel_size=(1, 1), stride=(1, 1))\n",
300
  " )\n",
301
  " (1): ResnetBlock2D(\n",
302
- " (norm1): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
303
- " (conv1): Conv2d(1152, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
304
- " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
305
- " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
306
  " (dropout): Dropout(p=0.0, inplace=False)\n",
307
- " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
308
  " (nonlinearity): SiLU()\n",
309
- " (conv_shortcut): Conv2d(1152, 576, kernel_size=(1, 1), stride=(1, 1))\n",
310
  " )\n",
311
  " (2): ResnetBlock2D(\n",
312
- " (norm1): GroupNorm(32, 864, eps=1e-05, affine=True)\n",
313
- " (conv1): Conv2d(864, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
314
- " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
315
- " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
316
  " (dropout): Dropout(p=0.0, inplace=False)\n",
317
- " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
318
  " (nonlinearity): SiLU()\n",
319
- " (conv_shortcut): Conv2d(864, 576, kernel_size=(1, 1), stride=(1, 1))\n",
320
  " )\n",
321
  " )\n",
322
  " (upsamplers): ModuleList(\n",
323
  " (0): Upsample2D(\n",
324
- " (conv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
325
  " )\n",
326
  " )\n",
327
  " )\n",
328
- " (2): UpBlock2D(\n",
329
  " (resnets): ModuleList(\n",
330
  " (0): ResnetBlock2D(\n",
331
- " (norm1): GroupNorm(32, 864, eps=1e-05, affine=True)\n",
332
- " (conv1): Conv2d(864, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
333
- " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
334
- " (norm2): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
335
  " (dropout): Dropout(p=0.0, inplace=False)\n",
336
- " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
337
  " (nonlinearity): SiLU()\n",
338
- " (conv_shortcut): Conv2d(864, 288, kernel_size=(1, 1), stride=(1, 1))\n",
339
  " )\n",
340
  " (1-2): 2 x ResnetBlock2D(\n",
341
- " (norm1): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
342
- " (conv1): Conv2d(576, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
343
- " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
344
- " (norm2): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
345
  " (dropout): Dropout(p=0.0, inplace=False)\n",
346
- " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
347
  " (nonlinearity): SiLU()\n",
348
- " (conv_shortcut): Conv2d(576, 288, kernel_size=(1, 1), stride=(1, 1))\n",
349
  " )\n",
350
  " )\n",
351
  " )\n",
@@ -353,60 +472,60 @@
353
  " (mid_block): UNetMidBlock2DCrossAttn(\n",
354
  " (attentions): ModuleList(\n",
355
  " (0): Transformer2DModel(\n",
356
- " (norm): GroupNorm(32, 1152, eps=1e-06, affine=True)\n",
357
- " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
358
  " (transformer_blocks): ModuleList(\n",
359
  " (0-7): 8 x BasicTransformerBlock(\n",
360
- " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
361
  " (attn1): Attention(\n",
362
- " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
363
- " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
364
- " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
365
  " (to_out): ModuleList(\n",
366
- " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
367
  " (1): Dropout(p=0.0, inplace=False)\n",
368
  " )\n",
369
  " )\n",
370
- " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
371
  " (attn2): Attention(\n",
372
- " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
373
- " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
374
- " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
375
  " (to_out): ModuleList(\n",
376
- " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
377
  " (1): Dropout(p=0.0, inplace=False)\n",
378
  " )\n",
379
  " )\n",
380
- " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
381
  " (ff): FeedForward(\n",
382
  " (net): ModuleList(\n",
383
  " (0): GEGLU(\n",
384
- " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
385
  " )\n",
386
  " (1): Dropout(p=0.0, inplace=False)\n",
387
- " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
388
  " )\n",
389
  " )\n",
390
  " )\n",
391
  " )\n",
392
- " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
393
  " )\n",
394
  " )\n",
395
  " (resnets): ModuleList(\n",
396
  " (0-1): 2 x ResnetBlock2D(\n",
397
- " (norm1): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
398
- " (conv1): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
399
- " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
400
- " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
401
  " (dropout): Dropout(p=0.0, inplace=False)\n",
402
- " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
403
  " (nonlinearity): SiLU()\n",
404
  " )\n",
405
  " )\n",
406
  " )\n",
407
- " (conv_norm_out): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
408
  " (conv_act): SiLU()\n",
409
- " (conv_out): Conv2d(288, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
410
  ")\n"
411
  ]
412
  }
@@ -414,11 +533,11 @@
414
  "source": [
415
  "config_sdxs = {\n",
416
  " # === Основные размеры и каналы ===\n",
417
- " \"in_channels\": 4, # Количество входных каналов (совместимость с VAE)\n",
418
- " \"out_channels\": 4, # Количество выходных каналов (симметрично in_channels) \n",
419
  "\n",
420
  " # === Cross-Attention ===\n",
421
- " \"cross_attention_dim\": 1152, # Размерность текстовых эмбеддингов\n",
422
  " \"use_linear_projection\": True,\n",
423
  " \"norm_num_groups\": 32,\n",
424
  " \n",
@@ -427,20 +546,20 @@
427
  " \"DownBlock2D\",\n",
428
  " \"CrossAttnDownBlock2D\",\n",
429
  " \"CrossAttnDownBlock2D\",\n",
430
- " #\"CrossAttnDownBlock2D\",\n",
431
  " ],\n",
432
  " \"up_block_types\": [ # декодер\n",
433
- " #\"CrossAttnUpBlock2D\",\n",
434
  " \"CrossAttnUpBlock2D\",\n",
435
  " \"CrossAttnUpBlock2D\",\n",
436
  " \"UpBlock2D\",\n",
437
  " ],\n",
438
  "\n",
439
  " # === Конфигурация каналов ===\n",
440
- " \"block_out_channels\": [288, 576, 1152],\n",
441
  "\n",
442
- " \"transformer_layers_per_block\": [1, 1, 8],\n",
443
- " \"attention_head_dim\": [6, 9, 18],\n",
444
  "}\n",
445
  "\n",
446
  "def check_initialization(model):\n",
@@ -465,9 +584,9 @@
465
  " print(f\"Количество параметров: {num_params}\")\n",
466
  "\n",
467
  " # Генерация тестового латента (640x512 в latent space)\n",
468
- " test_latent = torch.randn(1,4, 60, 48).to(\"cuda\", dtype=torch.float16) # 60x48 ≈ 512px\n",
469
  " timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
470
- " encoder_hidden_states = torch.randn(1, 77, 1152).to(\"cuda\", dtype=torch.float16)\n",
471
  " \n",
472
  " with torch.no_grad():\n",
473
  " output = new_unet(\n",
@@ -506,7 +625,7 @@
506
  "name": "python",
507
  "nbconvert_exporter": "python",
508
  "pygments_lexer": "ipython3",
509
- "version": "3.11.10"
510
  }
511
  },
512
  "nbformat": 4,
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "5212f806-14b4-4b5f-bcb4-09e36df3b7d9",
7
  "metadata": {},
8
  "outputs": [
 
11
  "output_type": "stream",
12
  "text": [
13
  "test unet\n",
14
+ "Количество параметров: 1546186256\n",
15
+ "Output shape: torch.Size([1, 16, 60, 48])\n",
16
  "UNet2DConditionModel(\n",
17
+ " (conv_in): Conv2d(16, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
18
  " (time_proj): Timesteps()\n",
19
  " (time_embedding): TimestepEmbedding(\n",
20
+ " (linear_1): Linear(in_features=256, out_features=1024, bias=True)\n",
21
  " (act): SiLU()\n",
22
+ " (linear_2): Linear(in_features=1024, out_features=1024, bias=True)\n",
23
  " )\n",
24
  " (down_blocks): ModuleList(\n",
25
  " (0): DownBlock2D(\n",
26
  " (resnets): ModuleList(\n",
27
  " (0-1): 2 x ResnetBlock2D(\n",
28
+ " (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
29
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
30
+ " (time_emb_proj): Linear(in_features=1024, out_features=256, bias=True)\n",
31
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
32
  " (dropout): Dropout(p=0.0, inplace=False)\n",
33
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
34
  " (nonlinearity): SiLU()\n",
35
  " )\n",
36
  " )\n",
37
  " (downsamplers): ModuleList(\n",
38
  " (0): Downsample2D(\n",
39
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
40
  " )\n",
41
  " )\n",
42
  " )\n",
43
  " (1): CrossAttnDownBlock2D(\n",
44
  " (attentions): ModuleList(\n",
45
  " (0-1): 2 x Transformer2DModel(\n",
46
+ " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
47
+ " (proj_in): Linear(in_features=512, out_features=512, bias=True)\n",
48
  " (transformer_blocks): ModuleList(\n",
49
  " (0): BasicTransformerBlock(\n",
50
+ " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
51
  " (attn1): Attention(\n",
52
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
53
+ " (to_k): Linear(in_features=512, out_features=512, bias=False)\n",
54
+ " (to_v): Linear(in_features=512, out_features=512, bias=False)\n",
55
  " (to_out): ModuleList(\n",
56
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
57
  " (1): Dropout(p=0.0, inplace=False)\n",
58
  " )\n",
59
  " )\n",
60
+ " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
61
  " (attn2): Attention(\n",
62
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
63
+ " (to_k): Linear(in_features=1024, out_features=512, bias=False)\n",
64
+ " (to_v): Linear(in_features=1024, out_features=512, bias=False)\n",
65
  " (to_out): ModuleList(\n",
66
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
67
  " (1): Dropout(p=0.0, inplace=False)\n",
68
  " )\n",
69
  " )\n",
70
+ " (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
71
  " (ff): FeedForward(\n",
72
  " (net): ModuleList(\n",
73
  " (0): GEGLU(\n",
74
+ " (proj): Linear(in_features=512, out_features=4096, bias=True)\n",
75
  " )\n",
76
  " (1): Dropout(p=0.0, inplace=False)\n",
77
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
78
  " )\n",
79
  " )\n",
80
  " )\n",
81
  " )\n",
82
+ " (proj_out): Linear(in_features=512, out_features=512, bias=True)\n",
83
  " )\n",
84
  " )\n",
85
  " (resnets): ModuleList(\n",
86
  " (0): ResnetBlock2D(\n",
87
+ " (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
88
+ " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
89
+ " (time_emb_proj): Linear(in_features=1024, out_features=512, bias=True)\n",
90
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
91
  " (dropout): Dropout(p=0.0, inplace=False)\n",
92
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
93
  " (nonlinearity): SiLU()\n",
94
+ " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
95
  " )\n",
96
  " (1): ResnetBlock2D(\n",
97
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
98
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
99
+ " (time_emb_proj): Linear(in_features=1024, out_features=512, bias=True)\n",
100
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
101
  " (dropout): Dropout(p=0.0, inplace=False)\n",
102
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
103
  " (nonlinearity): SiLU()\n",
104
  " )\n",
105
  " )\n",
106
  " (downsamplers): ModuleList(\n",
107
  " (0): Downsample2D(\n",
108
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
109
  " )\n",
110
  " )\n",
111
  " )\n",
112
  " (2): CrossAttnDownBlock2D(\n",
113
  " (attentions): ModuleList(\n",
114
  " (0-1): 2 x Transformer2DModel(\n",
115
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
116
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
117
  " (transformer_blocks): ModuleList(\n",
118
+ " (0): BasicTransformerBlock(\n",
119
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
120
  " (attn1): Attention(\n",
121
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
122
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
123
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
124
  " (to_out): ModuleList(\n",
125
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
126
  " (1): Dropout(p=0.0, inplace=False)\n",
127
  " )\n",
128
  " )\n",
129
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
130
  " (attn2): Attention(\n",
131
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
132
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
133
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
134
  " (to_out): ModuleList(\n",
135
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
136
  " (1): Dropout(p=0.0, inplace=False)\n",
137
  " )\n",
138
  " )\n",
139
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
140
  " (ff): FeedForward(\n",
141
  " (net): ModuleList(\n",
142
  " (0): GEGLU(\n",
143
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
144
  " )\n",
145
  " (1): Dropout(p=0.0, inplace=False)\n",
146
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
147
  " )\n",
148
  " )\n",
149
  " )\n",
150
  " )\n",
151
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
152
  " )\n",
153
  " )\n",
154
  " (resnets): ModuleList(\n",
155
  " (0): ResnetBlock2D(\n",
156
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
157
+ " (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
158
+ " (time_emb_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
159
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
160
  " (dropout): Dropout(p=0.0, inplace=False)\n",
161
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
162
  " (nonlinearity): SiLU()\n",
163
+ " (conv_shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
164
  " )\n",
165
  " (1): ResnetBlock2D(\n",
166
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
167
+ " (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
168
+ " (time_emb_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
169
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
170
  " (dropout): Dropout(p=0.0, inplace=False)\n",
171
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
172
+ " (nonlinearity): SiLU()\n",
173
+ " )\n",
174
+ " )\n",
175
+ " (downsamplers): ModuleList(\n",
176
+ " (0): Downsample2D(\n",
177
+ " (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
178
+ " )\n",
179
+ " )\n",
180
+ " )\n",
181
+ " (3): CrossAttnDownBlock2D(\n",
182
+ " (attentions): ModuleList(\n",
183
+ " (0-1): 2 x Transformer2DModel(\n",
184
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
185
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
186
+ " (transformer_blocks): ModuleList(\n",
187
+ " (0-7): 8 x BasicTransformerBlock(\n",
188
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
189
+ " (attn1): Attention(\n",
190
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
191
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
192
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
193
+ " (to_out): ModuleList(\n",
194
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
195
+ " (1): Dropout(p=0.0, inplace=False)\n",
196
+ " )\n",
197
+ " )\n",
198
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
199
+ " (attn2): Attention(\n",
200
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
201
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
202
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
203
+ " (to_out): ModuleList(\n",
204
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
205
+ " (1): Dropout(p=0.0, inplace=False)\n",
206
+ " )\n",
207
+ " )\n",
208
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
209
+ " (ff): FeedForward(\n",
210
+ " (net): ModuleList(\n",
211
+ " (0): GEGLU(\n",
212
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
213
+ " )\n",
214
+ " (1): Dropout(p=0.0, inplace=False)\n",
215
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
216
+ " )\n",
217
+ " )\n",
218
+ " )\n",
219
+ " )\n",
220
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
221
+ " )\n",
222
+ " )\n",
223
+ " (resnets): ModuleList(\n",
224
+ " (0-1): 2 x ResnetBlock2D(\n",
225
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
226
+ " (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
227
+ " (time_emb_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
228
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
229
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
230
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
231
  " (nonlinearity): SiLU()\n",
232
  " )\n",
233
  " )\n",
 
237
  " (0): CrossAttnUpBlock2D(\n",
238
  " (attentions): ModuleList(\n",
239
  " (0-2): 3 x Transformer2DModel(\n",
240
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
241
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
242
  " (transformer_blocks): ModuleList(\n",
243
  " (0-7): 8 x BasicTransformerBlock(\n",
244
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
245
  " (attn1): Attention(\n",
246
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
247
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
248
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
249
  " (to_out): ModuleList(\n",
250
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
251
  " (1): Dropout(p=0.0, inplace=False)\n",
252
  " )\n",
253
  " )\n",
254
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
255
  " (attn2): Attention(\n",
256
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
257
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
258
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
259
  " (to_out): ModuleList(\n",
260
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
261
  " (1): Dropout(p=0.0, inplace=False)\n",
262
  " )\n",
263
  " )\n",
264
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
265
  " (ff): FeedForward(\n",
266
  " (net): ModuleList(\n",
267
  " (0): GEGLU(\n",
268
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
269
  " )\n",
270
  " (1): Dropout(p=0.0, inplace=False)\n",
271
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
272
  " )\n",
273
  " )\n",
274
  " )\n",
275
  " )\n",
276
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
277
+ " )\n",
278
+ " )\n",
279
+ " (resnets): ModuleList(\n",
280
+ " (0-2): 3 x ResnetBlock2D(\n",
281
+ " (norm1): GroupNorm(32, 2048, eps=1e-05, affine=True)\n",
282
+ " (conv1): Conv2d(2048, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
283
+ " (time_emb_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
284
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
285
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
286
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
287
+ " (nonlinearity): SiLU()\n",
288
+ " (conv_shortcut): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
289
+ " )\n",
290
+ " )\n",
291
+ " (upsamplers): ModuleList(\n",
292
+ " (0): Upsample2D(\n",
293
+ " (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
294
+ " )\n",
295
+ " )\n",
296
+ " )\n",
297
+ " (1): CrossAttnUpBlock2D(\n",
298
+ " (attentions): ModuleList(\n",
299
+ " (0-2): 3 x Transformer2DModel(\n",
300
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
301
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
302
+ " (transformer_blocks): ModuleList(\n",
303
+ " (0): BasicTransformerBlock(\n",
304
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
305
+ " (attn1): Attention(\n",
306
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
307
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
308
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
309
+ " (to_out): ModuleList(\n",
310
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
311
+ " (1): Dropout(p=0.0, inplace=False)\n",
312
+ " )\n",
313
+ " )\n",
314
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
315
+ " (attn2): Attention(\n",
316
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
317
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
318
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
319
+ " (to_out): ModuleList(\n",
320
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
321
+ " (1): Dropout(p=0.0, inplace=False)\n",
322
+ " )\n",
323
+ " )\n",
324
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
325
+ " (ff): FeedForward(\n",
326
+ " (net): ModuleList(\n",
327
+ " (0): GEGLU(\n",
328
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
329
+ " )\n",
330
+ " (1): Dropout(p=0.0, inplace=False)\n",
331
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
332
+ " )\n",
333
+ " )\n",
334
+ " )\n",
335
+ " )\n",
336
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
337
  " )\n",
338
  " )\n",
339
  " (resnets): ModuleList(\n",
340
  " (0-1): 2 x ResnetBlock2D(\n",
341
+ " (norm1): GroupNorm(32, 2048, eps=1e-05, affine=True)\n",
342
+ " (conv1): Conv2d(2048, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
343
+ " (time_emb_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
344
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
345
  " (dropout): Dropout(p=0.0, inplace=False)\n",
346
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
347
  " (nonlinearity): SiLU()\n",
348
+ " (conv_shortcut): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
349
  " )\n",
350
  " (2): ResnetBlock2D(\n",
351
+ " (norm1): GroupNorm(32, 1536, eps=1e-05, affine=True)\n",
352
+ " (conv1): Conv2d(1536, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
353
+ " (time_emb_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
354
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
355
  " (dropout): Dropout(p=0.0, inplace=False)\n",
356
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
357
  " (nonlinearity): SiLU()\n",
358
+ " (conv_shortcut): Conv2d(1536, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
359
  " )\n",
360
  " )\n",
361
  " (upsamplers): ModuleList(\n",
362
  " (0): Upsample2D(\n",
363
+ " (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
364
  " )\n",
365
  " )\n",
366
  " )\n",
367
+ " (2): CrossAttnUpBlock2D(\n",
368
  " (attentions): ModuleList(\n",
369
  " (0-2): 3 x Transformer2DModel(\n",
370
+ " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
371
+ " (proj_in): Linear(in_features=512, out_features=512, bias=True)\n",
372
  " (transformer_blocks): ModuleList(\n",
373
  " (0): BasicTransformerBlock(\n",
374
+ " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
375
  " (attn1): Attention(\n",
376
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
377
+ " (to_k): Linear(in_features=512, out_features=512, bias=False)\n",
378
+ " (to_v): Linear(in_features=512, out_features=512, bias=False)\n",
379
  " (to_out): ModuleList(\n",
380
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
381
  " (1): Dropout(p=0.0, inplace=False)\n",
382
  " )\n",
383
  " )\n",
384
+ " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
385
  " (attn2): Attention(\n",
386
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
387
+ " (to_k): Linear(in_features=1024, out_features=512, bias=False)\n",
388
+ " (to_v): Linear(in_features=1024, out_features=512, bias=False)\n",
389
  " (to_out): ModuleList(\n",
390
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
391
  " (1): Dropout(p=0.0, inplace=False)\n",
392
  " )\n",
393
  " )\n",
394
+ " (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
395
  " (ff): FeedForward(\n",
396
  " (net): ModuleList(\n",
397
  " (0): GEGLU(\n",
398
+ " (proj): Linear(in_features=512, out_features=4096, bias=True)\n",
399
  " )\n",
400
  " (1): Dropout(p=0.0, inplace=False)\n",
401
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
402
  " )\n",
403
  " )\n",
404
  " )\n",
405
  " )\n",
406
+ " (proj_out): Linear(in_features=512, out_features=512, bias=True)\n",
407
  " )\n",
408
  " )\n",
409
  " (resnets): ModuleList(\n",
410
  " (0): ResnetBlock2D(\n",
411
+ " (norm1): GroupNorm(32, 1536, eps=1e-05, affine=True)\n",
412
+ " (conv1): Conv2d(1536, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
413
+ " (time_emb_proj): Linear(in_features=1024, out_features=512, bias=True)\n",
414
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
415
  " (dropout): Dropout(p=0.0, inplace=False)\n",
416
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
417
  " (nonlinearity): SiLU()\n",
418
+ " (conv_shortcut): Conv2d(1536, 512, kernel_size=(1, 1), stride=(1, 1))\n",
419
  " )\n",
420
  " (1): ResnetBlock2D(\n",
421
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
422
+ " (conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
423
+ " (time_emb_proj): Linear(in_features=1024, out_features=512, bias=True)\n",
424
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
425
  " (dropout): Dropout(p=0.0, inplace=False)\n",
426
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
427
  " (nonlinearity): SiLU()\n",
428
+ " (conv_shortcut): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))\n",
429
  " )\n",
430
  " (2): ResnetBlock2D(\n",
431
+ " (norm1): GroupNorm(32, 768, eps=1e-05, affine=True)\n",
432
+ " (conv1): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
433
+ " (time_emb_proj): Linear(in_features=1024, out_features=512, bias=True)\n",
434
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
435
  " (dropout): Dropout(p=0.0, inplace=False)\n",
436
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
437
  " (nonlinearity): SiLU()\n",
438
+ " (conv_shortcut): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1))\n",
439
  " )\n",
440
  " )\n",
441
  " (upsamplers): ModuleList(\n",
442
  " (0): Upsample2D(\n",
443
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
444
  " )\n",
445
  " )\n",
446
  " )\n",
447
+ " (3): UpBlock2D(\n",
448
  " (resnets): ModuleList(\n",
449
  " (0): ResnetBlock2D(\n",
450
+ " (norm1): GroupNorm(32, 768, eps=1e-05, affine=True)\n",
451
+ " (conv1): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
452
+ " (time_emb_proj): Linear(in_features=1024, out_features=256, bias=True)\n",
453
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
454
  " (dropout): Dropout(p=0.0, inplace=False)\n",
455
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
456
  " (nonlinearity): SiLU()\n",
457
+ " (conv_shortcut): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1))\n",
458
  " )\n",
459
  " (1-2): 2 x ResnetBlock2D(\n",
460
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
461
+ " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
462
+ " (time_emb_proj): Linear(in_features=1024, out_features=256, bias=True)\n",
463
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
464
  " (dropout): Dropout(p=0.0, inplace=False)\n",
465
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
466
  " (nonlinearity): SiLU()\n",
467
+ " (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
468
  " )\n",
469
  " )\n",
470
  " )\n",
 
472
  " (mid_block): UNetMidBlock2DCrossAttn(\n",
473
  " (attentions): ModuleList(\n",
474
  " (0): Transformer2DModel(\n",
475
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
476
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
477
  " (transformer_blocks): ModuleList(\n",
478
  " (0-7): 8 x BasicTransformerBlock(\n",
479
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
480
  " (attn1): Attention(\n",
481
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
482
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
483
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
484
  " (to_out): ModuleList(\n",
485
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
486
  " (1): Dropout(p=0.0, inplace=False)\n",
487
  " )\n",
488
  " )\n",
489
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
490
  " (attn2): Attention(\n",
491
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
492
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
493
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
494
  " (to_out): ModuleList(\n",
495
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
496
  " (1): Dropout(p=0.0, inplace=False)\n",
497
  " )\n",
498
  " )\n",
499
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
500
  " (ff): FeedForward(\n",
501
  " (net): ModuleList(\n",
502
  " (0): GEGLU(\n",
503
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
504
  " )\n",
505
  " (1): Dropout(p=0.0, inplace=False)\n",
506
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
507
  " )\n",
508
  " )\n",
509
  " )\n",
510
  " )\n",
511
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
512
  " )\n",
513
  " )\n",
514
  " (resnets): ModuleList(\n",
515
  " (0-1): 2 x ResnetBlock2D(\n",
516
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
517
+ " (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
518
+ " (time_emb_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
519
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
520
  " (dropout): Dropout(p=0.0, inplace=False)\n",
521
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
522
  " (nonlinearity): SiLU()\n",
523
  " )\n",
524
  " )\n",
525
  " )\n",
526
+ " (conv_norm_out): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
527
  " (conv_act): SiLU()\n",
528
+ " (conv_out): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
529
  ")\n"
530
  ]
531
  }
 
533
  "source": [
534
  "config_sdxs = {\n",
535
  " # === Основные размеры и каналы ===\n",
536
+ " \"in_channels\": 16, # Количество входных каналов (совместимость с VAE)\n",
537
+ " \"out_channels\": 16, # Количество выходных каналов (симметрично in_channels) \n",
538
  "\n",
539
  " # === Cross-Attention ===\n",
540
+ " \"cross_attention_dim\": 1024, # Размерность текстовых эмбеддингов\n",
541
  " \"use_linear_projection\": True,\n",
542
  " \"norm_num_groups\": 32,\n",
543
  " \n",
 
546
  " \"DownBlock2D\",\n",
547
  " \"CrossAttnDownBlock2D\",\n",
548
  " \"CrossAttnDownBlock2D\",\n",
549
+ " \"CrossAttnDownBlock2D\",\n",
550
  " ],\n",
551
  " \"up_block_types\": [ # декодер\n",
552
+ " \"CrossAttnUpBlock2D\",\n",
553
  " \"CrossAttnUpBlock2D\",\n",
554
  " \"CrossAttnUpBlock2D\",\n",
555
  " \"UpBlock2D\",\n",
556
  " ],\n",
557
  "\n",
558
  " # === Конфигурация каналов ===\n",
559
+ " \"block_out_channels\": [256, 512, 1024, 1024],\n",
560
  "\n",
561
+ " \"transformer_layers_per_block\": [1, 1, 1, 8],\n",
562
+ " \"attention_head_dim\": [4, 8, 16, 16],\n",
563
  "}\n",
564
  "\n",
565
  "def check_initialization(model):\n",
 
584
  " print(f\"Количество параметров: {num_params}\")\n",
585
  "\n",
586
  " # Генерация тестового латента (640x512 в latent space)\n",
587
+ " test_latent = torch.randn(1, 16, 60, 48).to(\"cuda\", dtype=torch.float16) # 60x48 ≈ 512px\n",
588
  " timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
589
+ " encoder_hidden_states = torch.randn(1, 77, 1024).to(\"cuda\", dtype=torch.float16)\n",
590
  " \n",
591
  " with torch.no_grad():\n",
592
  " output = new_unet(\n",
 
625
  "name": "python",
626
  "nbconvert_exporter": "python",
627
  "pygments_lexer": "ipython3",
628
+ "version": "3.11.11"
629
  }
630
  },
631
  "nbformat": 4,
train.py CHANGED
@@ -31,10 +31,10 @@ project = "unet"
31
  batch_size = 16
32
  base_learning_rate = 9e-5
33
  min_learning_rate = 1e-5
34
- num_epochs = 30
35
  # samples/save per epoch
36
  sample_interval_share = 1
37
- use_wandb = False
38
  save_model = True
39
  use_decay = True
40
  fbp = False # fused backward pass
@@ -89,10 +89,10 @@ if fixed_seed:
89
  # --- Пропорции лоссов и окно медианного нормирования (КОЭФ., не значения) ---
90
  # CHANGED: добавлен huber и dispersive в пропорции, суммы = 1.0
91
  loss_ratios = {
92
- "mse": 0.60,
93
- "mae": 0.35,
94
  "huber": 0.0,
95
- "dispersive": 0.05,
96
  }
97
  median_coeff_steps = 128 # за сколько шагов считать медианные коэффициенты
98
 
@@ -110,7 +110,7 @@ def sample_timesteps_bias(
110
  num_train_timesteps: int, # обычно 1000
111
  steps_offset: int = 0,
112
  device=None,
113
- mode: str = "beta", # "beta", "uniform"
114
  ) -> torch.Tensor:
115
  """
116
  Возвращает timesteps с разным bias:
@@ -241,7 +241,7 @@ gen.manual_seed(seed)
241
  # "AiArtLab/simplevae", subfolder="wan16x_vae_nightly",
242
  # torch_dtype=dtype
243
  # ).to(device="cpu").eval()
244
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", subfolder=None,torch_dtype=dtype).to(device).eval()
245
 
246
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
247
  if shift_factor is None:
 
31
  batch_size = 16
32
  base_learning_rate = 9e-5
33
  min_learning_rate = 1e-5
34
+ num_epochs = 300
35
  # samples/save per epoch
36
  sample_interval_share = 1
37
+ use_wandb = True
38
  save_model = True
39
  use_decay = True
40
  fbp = False # fused backward pass
 
89
  # --- Пропорции лоссов и окно медианного нормирования (КОЭФ., не значения) ---
90
  # CHANGED: добавлен huber и dispersive в пропорции, суммы = 1.0
91
  loss_ratios = {
92
+ "mse": 1.0,
93
+ "mae": 0.0,
94
  "huber": 0.0,
95
+ "dispersive": 0.0,
96
  }
97
  median_coeff_steps = 128 # за сколько шагов считать медианные коэффициенты
98
 
 
110
  num_train_timesteps: int, # обычно 1000
111
  steps_offset: int = 0,
112
  device=None,
113
+ mode: str = "uniform", # "beta", "uniform"
114
  ) -> torch.Tensor:
115
  """
116
  Возвращает timesteps с разным bias:
 
241
  # "AiArtLab/simplevae", subfolder="wan16x_vae_nightly",
242
  # torch_dtype=dtype
243
  # ).to(device="cpu").eval()
244
+ vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to(device).eval()
245
 
246
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
247
  if shift_factor is None:
unet/config.json ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.35.1",
4
+ "_name_or_path": "unet",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": [
10
+ 4,
11
+ 8,
12
+ 16,
13
+ 16
14
+ ],
15
+ "attention_type": "default",
16
+ "block_out_channels": [
17
+ 256,
18
+ 512,
19
+ 1024,
20
+ 1024
21
+ ],
22
+ "center_input_sample": false,
23
+ "class_embed_type": null,
24
+ "class_embeddings_concat": false,
25
+ "conv_in_kernel": 3,
26
+ "conv_out_kernel": 3,
27
+ "cross_attention_dim": 1024,
28
+ "cross_attention_norm": null,
29
+ "down_block_types": [
30
+ "DownBlock2D",
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "CrossAttnDownBlock2D"
34
+ ],
35
+ "downsample_padding": 1,
36
+ "dropout": 0.0,
37
+ "dual_cross_attention": false,
38
+ "encoder_hid_dim": null,
39
+ "encoder_hid_dim_type": null,
40
+ "flip_sin_to_cos": true,
41
+ "freq_shift": 0,
42
+ "in_channels": 16,
43
+ "layers_per_block": 2,
44
+ "mid_block_only_cross_attention": null,
45
+ "mid_block_scale_factor": 1,
46
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
47
+ "norm_eps": 1e-05,
48
+ "norm_num_groups": 32,
49
+ "num_attention_heads": null,
50
+ "num_class_embeds": null,
51
+ "only_cross_attention": false,
52
+ "out_channels": 16,
53
+ "projection_class_embeddings_input_dim": null,
54
+ "resnet_out_scale_factor": 1.0,
55
+ "resnet_skip_time_act": false,
56
+ "resnet_time_scale_shift": "default",
57
+ "reverse_transformer_layers_per_block": null,
58
+ "sample_size": null,
59
+ "time_cond_proj_dim": null,
60
+ "time_embedding_act_fn": null,
61
+ "time_embedding_dim": null,
62
+ "time_embedding_type": "positional",
63
+ "timestep_post_act": null,
64
+ "transformer_layers_per_block": [
65
+ 1,
66
+ 1,
67
+ 1,
68
+ 8
69
+ ],
70
+ "up_block_types": [
71
+ "CrossAttnUpBlock2D",
72
+ "CrossAttnUpBlock2D",
73
+ "CrossAttnUpBlock2D",
74
+ "UpBlock2D"
75
+ ],
76
+ "upcast_attention": false,
77
+ "use_linear_projection": true
78
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74a909271318a4d576b1519be1697f2d7989534c89fa4b5ae0f7a7fdd04a9245
3
+ size 6184944280