Text-to-Image
Diffusers
Safetensors
recoilme commited on
Commit
ba84478
·
1 Parent(s): 945f661
dataset.py CHANGED
@@ -18,12 +18,12 @@ from tqdm import tqdm
18
  from datetime import timedelta
19
 
20
  # ---------------- 1️⃣ Настройки ----------------
21
- dtype = torch.float32
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  batch_size = 5
24
- min_size = 320 #384 #320 #192 #256 #192
25
- max_size = 640 #768 #640 #384 #256 #384
26
- step = 32 #64
27
  empty_share = 0.0
28
  limit = 0
29
  # Основная процедура обработки
@@ -43,20 +43,9 @@ def clear_cuda_memory():
43
  def load_models():
44
  print("Загрузка моделей...")
45
  #vae = AsymmetricAutoencoderKL.from_pretrained("AiArtLab/sdxs-1b",subfolder="vae",torch_dtype=dtype).to(device).eval()
46
- vae = AutoencoderKLFlux2.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
47
-
48
- #model_name = "Qwen/Qwen3-0.6B"
49
- #tokenizer = AutoTokenizer.from_pretrained(model_name)
50
- #model = AutoModelForCausalLM.from_pretrained(
51
- # model_name,
52
- # torch_dtype=dtype,
53
- # device_map=device
54
- #).eval()
55
- #tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
56
- #model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B').to("cuda")
57
- return vae#, model, tokenizer
58
-
59
- #vae, model, tokenizer = load_models()
60
  vae = load_models()
61
 
62
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
@@ -67,8 +56,11 @@ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
67
  if scaling_factor is None:
68
  scaling_factor = 1.0
69
 
70
- latents_mean = getattr(vae.config, "latents_mean", None)
71
- latents_std = getattr(vae.config, "latents_std", None)
 
 
 
72
 
73
  # ---------------- 3️⃣ Трансформации ----------------
74
  def get_image_transform(min_size=256, max_size=512, step=64):
@@ -126,50 +118,6 @@ def get_image_transform(min_size=256, max_size=512, step=64):
126
  return transform
127
 
128
  # ---------------- 4️⃣ Функции обработки ----------------
129
- def last_token_pool(last_hidden_states: torch.Tensor,
130
- attention_mask: torch.Tensor) -> torch.Tensor:
131
- # Определяем, есть ли left padding
132
- left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
133
- if left_padding:
134
- return last_hidden_states[:, -1]
135
- else:
136
- sequence_lengths = attention_mask.sum(dim=1) - 1
137
- batch_size = last_hidden_states.shape[0]
138
- return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
139
-
140
- def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150, normalize=False):
141
- with torch.inference_mode():
142
- # Токенизация
143
- batch = tokenizer(
144
- texts,
145
- return_tensors="pt",
146
- padding="max_length",
147
- truncation=True,
148
- max_length=max_length
149
- ).to(device)
150
-
151
- # Прогон через модель
152
- #outputs = model(**batch)
153
-
154
- # Пулинг по last token
155
- #embeddings = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])
156
-
157
- # L2-нормализация (опционально, обычно нужна для семантического поиска)
158
- #if normalize:
159
- # embeddings = F.normalize(embeddings, p=2, dim=1)
160
-
161
- # Прогон через базовую модель (внутри CausalLM)
162
- outputs = model.model(**batch, output_hidden_states=True)
163
-
164
- # Берем последний слой (эмбеддинги всех токенов)
165
- hidden_states = outputs.hidden_states[-1] # [B, L, D]
166
-
167
- # Можно применить нормализацию по каждому токену (как в CLIP)
168
- if normalize:
169
- hidden_states = F.normalize(hidden_states, p=2, dim=-1)
170
-
171
- return hidden_states.cpu().numpy() # embeddings.unsqueeze(1).cpu().numpy()
172
-
173
  def clean_label(label):
174
  label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "").replace("The image depicts ","").replace("The image presents ","").replace("The image features ","").replace("The image portrays ","").replace("The image is ","").strip()
175
  if label.startswith("."):
@@ -200,42 +148,6 @@ def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
200
 
201
  return labels_for_model, labels_for_logging
202
 
203
- def _patchify_latents(latents):
204
- batch_size, num_channels_latents, height, width = latents.shape
205
- latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
206
- latents = latents.permute(0, 1, 3, 5, 2, 4)
207
- latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
208
- return latents
209
-
210
- @staticmethod
211
- def _unpatchify_latents(latents):
212
- batch_size, num_channels_latents, height, width = latents.shape
213
- latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
214
- latents = latents.permute(0, 1, 4, 2, 5, 3)
215
- latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
216
- return latents
217
-
218
- def flux_encode(vae,latents):
219
- # patch
220
- image_latents = _patchify_latents(latents)
221
- # norm
222
- latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
223
- latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
224
- latents = (image_latents - latents_bn_mean) / latents_bn_std
225
- # unpatch
226
- latents = _unpatchify_latents(latents)
227
- return latents
228
-
229
- def flux_decode(vae,latents):
230
- # patch
231
- image_latents = _patchify_latents(latents)
232
- # norm
233
- latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
234
- latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
235
- latents = image_latents * latents_bn_std + latents_bn_mean
236
- # unpatch
237
- latents = _unpatchify_latents(latents)
238
- return latents
239
 
240
  def encode_to_latents(images, texts):
241
  transform = get_image_transform(min_size, max_size, step)
@@ -269,20 +181,19 @@ def encode_to_latents(images, texts):
269
  # Кодируем батч
270
  with torch.no_grad():
271
  posteriors = vae.encode(batch_tensor).latent_dist.mode()
272
- latents = (posteriors - shift_factor) / scaling_factor
273
- image_latents = flux_encode(vae, latents)
 
274
 
275
- latents_np = image_latents.to(dtype).cpu().numpy()
276
 
277
  # Обрабатываем тексты
278
  text_labels = [clean_label(text) for text in texts]
279
 
280
  model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
281
- #embeddings = encode_texts_batch(model_prompts, tokenizer, model)
282
 
283
  return {
284
  "vae": latents_np,
285
- #"embeddings": embeddings,
286
  "text": text_labels,
287
  "width": widths,
288
  "height": heights
 
18
  from datetime import timedelta
19
 
20
  # ---------------- 1️⃣ Настройки ----------------
21
+ dtype = torch.float16
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  batch_size = 5
24
+ min_size = 640 #320 #384 #320 #192 #256 #192
25
+ max_size = 1280 #640 #768 #640 #384 #256 #384
26
+ step = 64
27
  empty_share = 0.0
28
  limit = 0
29
  # Основная процедура обработки
 
43
  def load_models():
44
  print("Загрузка моделей...")
45
  #vae = AsymmetricAutoencoderKL.from_pretrained("AiArtLab/sdxs-1b",subfolder="vae",torch_dtype=dtype).to(device).eval()
46
+ vae = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
47
+ return vae
48
+
 
 
 
 
 
 
 
 
 
 
 
49
  vae = load_models()
50
 
51
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
 
56
  if scaling_factor is None:
57
  scaling_factor = 1.0
58
 
59
+ mean = getattr(vae.config, "latents_mean", None)
60
+ std = getattr(vae.config, "latents_std", None)
61
+ if mean is not None and std is not None:
62
+ latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
63
+ latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
64
 
65
  # ---------------- 3️⃣ Трансформации ----------------
66
  def get_image_transform(min_size=256, max_size=512, step=64):
 
118
  return transform
119
 
120
  # ---------------- 4️⃣ Функции обработки ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  def clean_label(label):
122
  label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "").replace("The image depicts ","").replace("The image presents ","").replace("The image features ","").replace("The image portrays ","").replace("The image is ","").strip()
123
  if label.startswith("."):
 
148
 
149
  return labels_for_model, labels_for_logging
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  def encode_to_latents(images, texts):
153
  transform = get_image_transform(min_size, max_size, step)
 
181
  # Кодируем батч
182
  with torch.no_grad():
183
  posteriors = vae.encode(batch_tensor).latent_dist.mode()
184
+ if latents_mean is not None and latents_std is not None:
185
+ posteriors = (posteriors - latents_mean) / latents_std
186
+ posteriors = (posteriors - shift_factor) / scaling_factor
187
 
188
+ latents_np = posteriors.to(dtype).cpu().numpy()
189
 
190
  # Обрабатываем тексты
191
  text_labels = [clean_label(text) for text in texts]
192
 
193
  model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
 
194
 
195
  return {
196
  "vae": latents_np,
 
197
  "text": text_labels,
198
  "width": widths,
199
  "height": heights
samples/unet_320x640_0.jpg CHANGED

Git LFS Details

  • SHA256: bcbc2dac142fee70b8410f32a06848c88c7a2201cea2e4fd7aa0a0f882efde51
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB

Git LFS Details

  • SHA256: 03cc88423078a907c4122d71a793fca84e27203a28982fd7c84fe8cba40db89a
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
samples/unet_352x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 01e2cbd2dc1d22131507e49321786caed78d94e9e89f5ee658369f272e4c5d40
  • Pointer size: 130 Bytes
  • Size of remote file: 81.2 kB

Git LFS Details

  • SHA256: d7865b7970297f724bca7466a2d7135333bda43f8e27cc69d7d179f53c1c0ca3
  • Pointer size: 130 Bytes
  • Size of remote file: 78.6 kB
samples/unet_384x640_0.jpg CHANGED

Git LFS Details

  • SHA256: f6bcb53f5afd25a95dda33aaa5dffd0017252fa63bce9e57c9ec6d2db85143fe
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB

Git LFS Details

  • SHA256: df790034bdf660c3fbf173adc7764aa448f429dd1f3e9193c601dc95d21b8ac5
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
samples/unet_416x640_0.jpg CHANGED

Git LFS Details

  • SHA256: accacb0e6dff38f6e342af99054cf3efd81d12435ad95ce3aefa7c40f065a5d4
  • Pointer size: 130 Bytes
  • Size of remote file: 94.6 kB

Git LFS Details

  • SHA256: 6fc2c668b158c4b69b4fcc7854158a757fb04876f84ac001216e47fa4ec80eb4
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
samples/unet_448x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 73fd8f61cd584597548d14bbdac482a68789e9a90a489a836a73d7a920d1637a
  • Pointer size: 130 Bytes
  • Size of remote file: 90.1 kB

Git LFS Details

  • SHA256: 2e077454ad5299f406ca3c6f889a88f317b06c0ab8c5e2b98216f3ae3f5ed6a1
  • Pointer size: 130 Bytes
  • Size of remote file: 91.9 kB
samples/unet_480x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 13d12501f19aed7ba64fd4d2945efbb11aee147e5f5b3359f1b1e848422716a3
  • Pointer size: 130 Bytes
  • Size of remote file: 90 kB

Git LFS Details

  • SHA256: 6f9f51cf83614b85c7726f2f54373e4affe65327d983c842dc0a3b696d3ccd9a
  • Pointer size: 130 Bytes
  • Size of remote file: 88.4 kB
samples/unet_512x640_0.jpg CHANGED

Git LFS Details

  • SHA256: b3e5f7e8680d144fbcb1846d8d6766d5f18aa85d3e6c3dc63ebf9bcdc3046a9c
  • Pointer size: 130 Bytes
  • Size of remote file: 69.8 kB

Git LFS Details

  • SHA256: db884e6eb0f9b014faee5571e698ebb20089b3e93c508e4e11beaed06bbe6b05
  • Pointer size: 130 Bytes
  • Size of remote file: 65.6 kB
samples/unet_544x640_0.jpg CHANGED

Git LFS Details

  • SHA256: dab77ce8cb5b751c9d7f1e2f77044198f25597feef8eec0d21f0c481c5746956
  • Pointer size: 130 Bytes
  • Size of remote file: 93.7 kB

Git LFS Details

  • SHA256: 161beac686b680f321044aebf6dd2e0f93eab152571ff4ab293c9ee177a07adb
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
samples/unet_576x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 7c728b324e2df2f8418220bcfb31828443060a062cd3d69df757b9e1f562cc7b
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB

Git LFS Details

  • SHA256: 5fa0ae0428ee4cd61c969d4898434f6e84431fec4bfb81d360ab20ee5fd67249
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
samples/unet_608x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 4973b9fa85874b0d2740002be9fea0f69457bb669264c76e8c495a2ba017d61e
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB

Git LFS Details

  • SHA256: 6616203f078db20a026c0e02ef77740a61dffc1b13263ada31ef1ea331ab650f
  • Pointer size: 131 Bytes
  • Size of remote file: 216 kB
samples/unet_640x320_0.jpg CHANGED

Git LFS Details

  • SHA256: c8657be58aa12c6615e71420113bb8c102203ad86e276be8e87170406b83fbd3
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB

Git LFS Details

  • SHA256: 580c7abbbef326e747cf321498828febacc197ef42e29963c20b82152e5f7706
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
samples/unet_640x352_0.jpg CHANGED

Git LFS Details

  • SHA256: 6bc45d6cd5ecd49db8db1ac99dd715607036391aabe213d4f58085566eba4427
  • Pointer size: 130 Bytes
  • Size of remote file: 66.5 kB

Git LFS Details

  • SHA256: ab05ac79672aa109b56bccef3a65560909dc810c5de288ba732f2401dfd48978
  • Pointer size: 130 Bytes
  • Size of remote file: 68.4 kB
samples/unet_640x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 8f9c41a3a2d4d224c6929544a8a337b1b73e47ffd7d796fe888cb8293685a4d3
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB

Git LFS Details

  • SHA256: 5bea77decad66d7b604e929dc205670103a4f4973164baa41e3ab4706a25059a
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
samples/unet_640x416_0.jpg CHANGED

Git LFS Details

  • SHA256: 25cc0268be3d1c02cae62abeaed0355732e6b290916ee11fb3b32812a51bcd5c
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB

Git LFS Details

  • SHA256: 7c2219b094804419481d10f42f5ccc81f3bb5f41a4b4a5dfaa15d0754c8834dc
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB
samples/unet_640x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 4706d4591d9e9c9e01151ca2f25706f672c0f232e06358c1179f3de0b071da6e
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB

Git LFS Details

  • SHA256: 0fa17972fcba861a1a178c78a61a91bd31bfcad8d9aaa702473064ddf6ffe01d
  • Pointer size: 131 Bytes
  • Size of remote file: 100 kB
samples/unet_640x480_0.jpg CHANGED

Git LFS Details

  • SHA256: 5f9c68dcac2b3045e9e5ce90268f0f68e7a763682b794606806c2f95525d386c
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB

Git LFS Details

  • SHA256: bd15fbd036521d12be0082ccfcec180724936d8957c69bf020a1a9b7b3aab084
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
samples/unet_640x512_0.jpg CHANGED

Git LFS Details

  • SHA256: a60dfafcb4fc62ec61be0c8cb486702dc306c111b78fbe6e9bee39b4e7095014
  • Pointer size: 130 Bytes
  • Size of remote file: 94 kB

Git LFS Details

  • SHA256: c619dbf62136f96e77542f78aad7d4aa086395297c638410a985cb50adab8383
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
samples/unet_640x544_0.jpg CHANGED

Git LFS Details

  • SHA256: b7c93a24167f110f9959d5afdc9feaacce67e59dc7134085534b9681b8bdb004
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB

Git LFS Details

  • SHA256: 9e9a8358c5b7a2b31ead9bef8950cf7b149243ff50641a5e3aebf1d193aef335
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
samples/unet_640x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 9756fcbc2e3d4fa786f8456b54c6a9418f63b1c77f01a272c163facb8a577380
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB

Git LFS Details

  • SHA256: a08d6de28b8337b5b5425086462409b13885560958ad2ed7a69fbce5256cb43f
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
samples/unet_640x608_0.jpg CHANGED

Git LFS Details

  • SHA256: 07a4105f73c8c984eb5da2a25fdaeb813052b0c82bcd33dad932b17cdf4dcea4
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB

Git LFS Details

  • SHA256: 469eda64aa4353c24a418a1e910ccc197648dc9295c134e87a784fbe6a474e0f
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
samples/unet_640x640_0.jpg CHANGED

Git LFS Details

  • SHA256: b9a9143bd751c88442514e755a4840ff4a7c5de51d17d9ffe5b787c5987c2a2e
  • Pointer size: 131 Bytes
  • Size of remote file: 234 kB

Git LFS Details

  • SHA256: b41a9b688b4054da5caa507c216e0ab41d11fef71c8a3c03d22a0d47bc15a8e6
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:491d2fb7582ad01f0c930c28f58476cf8f46e0900417c56ce3a06e8f63eb19df
3
  size 5946605448
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12e68807a96fc1f2dd993c400888104240383fd74379a2f083952996c0d60c13
3
  size 5946605448