Text-to-Image
Diffusers
Safetensors
recoilme commited on
Commit
845c08d
·
1 Parent(s): a4e4f02
girl.jpg CHANGED

Git LFS Details

  • SHA256: 19b70dfc2cf31fd200c9653e218849c26ab02daa8bbfd3f9ed2e31d0b20fce83
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB

Git LFS Details

  • SHA256: 01587e489b357ac5bcfd46afaae609153b358a7626d03e70189c47e25330e733
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB
media/result_grid.jpg CHANGED

Git LFS Details

  • SHA256: 11ea40ccb6db120c1ff74757e311c346d7f9343ddf7665c647ebe02238c843a3
  • Pointer size: 132 Bytes
  • Size of remote file: 2.63 MB

Git LFS Details

  • SHA256: 9abb8ef26aaa20fd3756a0c03bc64c58810aa15dc855da7265d240fcbbfbf359
  • Pointer size: 132 Bytes
  • Size of remote file: 2.69 MB
model_index.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b7b617daaaefac820f302b5222a5cbc9aeb7b926c676b01e67fa3924ee95bdc6
3
- size 557
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1b6ce208190165975458f502eb3e7ad6d4a5dd54a507f8dd727e636e363ae93
3
+ size 428
pipeline_sdxs-Copy1.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import List, Union, Optional, Tuple
5
+ from dataclasses import dataclass
6
+
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.utils import BaseOutput
9
+ from tqdm import tqdm
10
+ from transformers import Qwen3ForCausalLM, Qwen2Tokenizer
11
+
12
@dataclass
class SdxsPipelineOutput(BaseOutput):
    """Output of :class:`SdxsPipeline`.

    Attributes:
        images: Generated images, either a list of PIL images or a numpy array.
        prompt: The (possibly LLM-refined) prompt actually used for generation,
            returned so callers can inspect what was encoded.
    """

    images: Union[List[Image.Image], np.ndarray]
    prompt: Optional[Union[str, List[str]]] = None
16
+
17
class SdxsPipeline(DiffusionPipeline):
    """SDXS diffusion pipeline supporting txt2img and img2img sampling."""

    def __init__(self, vae, text_encoder, text_encoder2, tokenizer, tokenizer2, unet, scheduler):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder2=text_encoder2,
            tokenizer=tokenizer,
            tokenizer2=tokenizer2,
            unet=unet,
            scheduler=scheduler
        )
        # Spatial downscale factor of the VAE: one factor of 2 per down block.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # Optional per-channel latent normalization stats from the VAE config.
        # NOTE(review): these attributes only exist when the config provides
        # both values; img2img/decode paths assume they are present — confirm.
        mean = getattr(self.vae.config, "latents_mean", None)
        std = getattr(self.vae.config, "latents_std", None)
        if mean is not None and std is not None:
            self.vae_latents_std = torch.tensor(std, device=self.unet.device, dtype=self.unet.dtype).view(1, len(std), 1, 1)
            self.vae_latents_mean = torch.tensor(mean, device=self.unet.device, dtype=self.unet.dtype).view(1, len(mean), 1, 1)
35
+
36
def preprocess_image(self, image: Image.Image, width: int, height: int):
    """Center-crop *image* to the target aspect ratio, resize, and scale to [-1, 1].

    The requested width/height are snapped down to multiples of the VAE
    downscale factor. Returns a float32 tensor of shape [1, C, H, W].
    """
    target_h = (height // self.vae_scale_factor) * self.vae_scale_factor
    target_w = (width // self.vae_scale_factor) * self.vae_scale_factor

    src_w, src_h = image.size
    target_ratio = target_w / target_h

    # Crop the longer dimension (centered) so the aspect ratio matches.
    if src_w / src_h > target_ratio:
        crop_w = int(src_h * target_ratio)
        x0 = (src_w - crop_w) // 2
        image = image.crop((x0, 0, x0 + crop_w, src_h))
    else:
        crop_h = int(src_w / target_ratio)
        y0 = (src_h - crop_h) // 2
        image = image.crop((0, y0, src_w, y0 + crop_h))

    image = image.resize((target_w, target_h), resample=Image.LANCZOS)
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))  # [1, C, H, W]
    return 2.0 * tensor - 1.0  # map [0, 1] -> [-1, 1]
59
+
60
+
61
def encode_prompt(self, prompt, negative_prompt, device, dtype):
    """Encode positive and negative prompts with the Qwen text encoder for CFG.

    Args:
        prompt: Positive prompt(s); str, list of str, or None.
        negative_prompt: Negative prompt(s); str, list of str, or None.
        device: Device to run tokenization and encoding on.
        dtype: Target dtype for the returned embedding tensors.

    Returns:
        Tuple ``(text_embeddings, attention_mask, pooled_embeds)`` where each
        tensor stacks ``[negative, positive]`` along the batch dimension
        (classifier-free-guidance layout).

    Note: the dead nested helpers ``get_single_encode`` / ``get_pooled_encode``
    (never called; their call sites were commented out) have been removed.
    """
    def get_encode(texts):
        # Normalize input to a list of strings; None becomes one empty prompt.
        if texts is None:
            texts = ""
        if isinstance(texts, str):
            texts = [texts]

        with torch.no_grad():
            # 1. Wrap each prompt in the model's chat template.
            formatted_prompts = []
            for t in texts:
                messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
                res_text = self.tokenizer2.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=False
                )
                formatted_prompts.append(res_text)

            # 2. Tokenize with padding and truncation in a single pass.
            # NOTE(review): max_length is taken from text_encoder (not
            # text_encoder2) — presumably to keep both encoders aligned; confirm.
            toks = self.tokenizer2(
                formatted_prompts,
                padding="max_length",
                max_length=self.text_encoder.config.max_position_embeddings,
                truncation=True,
                return_tensors="pt"
            ).to(device)

            # 3. Run the encoder; use the second-to-last hidden layer.
            outputs = self.text_encoder2(
                input_ids=toks.input_ids,
                attention_mask=toks.attention_mask,
                output_hidden_states=True
            )

            last_hidden = outputs.hidden_states[-2]
            # Pooled vector = hidden state of the last non-padding token.
            seq_len = toks.attention_mask.sum(dim=1) - 1
            pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]

        return last_hidden, toks.attention_mask, pooled

    pos_embeds, pos_mask, pos_pooled = get_encode(prompt)
    neg_embeds, neg_mask, neg_pooled = get_encode(negative_prompt)

    # Broadcast the (usually single) negative prompt to the positive batch size.
    batch_size = pos_embeds.shape[0]
    if neg_embeds.shape[0] != batch_size:
        neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
        neg_mask = neg_mask.repeat(batch_size, 1)
        neg_pooled = neg_pooled.repeat(batch_size, 1)

    if pos_pooled.shape[0] != batch_size:
        pos_pooled = pos_pooled.repeat(batch_size, 1)

    # CFG layout: negative half first, positive half second.
    text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
    final_mask = torch.cat([neg_mask, pos_mask], dim=0)
    pooled_embeds = torch.cat([neg_pooled, pos_pooled], dim=0)

    return text_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64), pooled_embeds.to(dtype=dtype)
210
+
211
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Optional[Union[Image.Image, List[Image.Image]]] = None,
    coef: float = 0.97,  # img2img strength: 0.0 = keep original, 1.0 = pure noise
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 40,
    guidance_scale: float = 4.0,
    generator: Optional[torch.Generator] = None,
    seed: Optional[int] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    refine_prompt: bool = False,  # when True, rewrite the prompt with the LLM first
    # structure_preservation kept for API compatibility; 0.0 = standard linear path
    structure_preservation: float = 0.0,
    **kwargs,
):
    """Run text-to-image sampling, or img2img when *image* is given.

    Args:
        prompt: Positive prompt(s).
        image: Optional init image (img2img); only the first image of a list is used.
        coef: img2img strength / flow-matching sigma for the initial noising.
        negative_prompt: Negative prompt(s) for classifier-free guidance.
        height, width: Output resolution in pixels.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: CFG scale; <= 1.0 disables guidance.
        generator, seed: Randomness control (seed is ignored when generator is given).
        output_type: "pil", "latent", or anything else for a numpy array.
        return_dict: Return SdxsPipelineOutput instead of a tuple.
        refine_prompt: If True, have the LLM rewrite the prompt before encoding.

    Returns:
        SdxsPipelineOutput (or tuple) with images and the prompt actually used.
    """
    device = self.device
    dtype = self.unet.dtype

    if generator is None and seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    # ==================== REFINE PROMPT (INLINE) ====================
    if refine_prompt and prompt:
        sys_msg = (
            "You are a skilled text-to-image prompt engineer whose sole function is to transform the user's input into an aesthetically optimized, detailed, and visually descriptive three-sentence output. "
            "**The primary subject (e.g., 'girl', 'dog', 'house') MUST be the main focus of the revised prompt and MUST be described in rich detail within the first sentence or two.** "
            "Output **only** the final revised prompt in **English**, with absolutely no commentary.\n Don't use cliches like warm,soft,vibrant, wildflowers. Be creative "
            "User input prompt: "
        )
        prompts_list = [prompt] if isinstance(prompt, str) else prompt
        refined_list = []

        for p in prompts_list:
            messages = [{"role": "user", "content": [{"type": "text", "text": sys_msg + p}]}]

            # apply_chat_template inserts the Qwen-Instruct role tokens itself.
            inputs = self.tokenizer2.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt"
            ).to(device)

            generated_ids = self.text_encoder2.generate(
                **inputs, max_new_tokens=self.text_encoder.config.max_position_embeddings, do_sample=True, temperature=0.7
            )

            # Strip the prompt tokens from the generated sequence.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.tokenizer2.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            # BUGFIX: batch_decode returns a list of strings; the original code
            # appended the whole list, producing nested lists that break both
            # multi-prompt re-encoding and the returned `prompt` value.
            refined_list.append(output_text[0])

        prompt = refined_list[0] if isinstance(prompt, str) else refined_list

    # ==================== ENCODE PROMPTS ====================
    text_embeddings, attention_mask, pooled_embeds = self.encode_prompt(
        prompt, negative_prompt, device, dtype
    )
    batch_size = 1 if isinstance(prompt, str) else len(prompt)

    # Scheduler timesteps.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # ==================== TIME IDS ====================
    # time_ids must share the batch size of pooled_embeds / text_embeddings
    # (encode_prompt always doubles the batch for CFG).
    time_ids = torch.zeros(
        pooled_embeds.shape[0],
        6,
        device=device,
        dtype=torch.long
    )

    # BUGFIX: encode_prompt always returns [negative, positive] stacked, but
    # without CFG the latents are not duplicated, so the UNet batch sizes
    # would mismatch (the original code crashed here). Keep only the
    # conditional half when guidance is disabled.
    if guidance_scale <= 1.0:
        half = text_embeddings.shape[0] // 2
        text_embeddings = text_embeddings[half:]
        attention_mask = attention_mask[half:]
        pooled_embeds = pooled_embeds[half:]
        time_ids = time_ids[half:]

    # ==================== IMG2IMG BRANCH ====================
    if image is not None:
        # Prepare the init image (only the first image of a list is used).
        if isinstance(image, Image.Image):
            image_tensor = self.preprocess_image(image, width, height).to(device, self.vae.dtype)
        else:
            image_tensor = self.preprocess_image(image[0], width, height).to(device, self.vae.dtype)

        # Encode to latents and normalize with the VAE latent statistics.
        # NOTE(review): vae_latents_mean/std only exist when the VAE config
        # provides them (see __init__) — confirm for img2img use.
        latents_clean = self.vae.encode(image_tensor).latent_dist.sample(generator=generator)
        latents_clean = (latents_clean - self.vae_latents_mean.to(device, self.vae.dtype)) / self.vae_latents_std.to(device, self.vae.dtype)
        latents_clean = latents_clean.to(dtype)

        # Add noise via the rectified-flow interpolation.
        noise = torch.randn_like(latents_clean)

        # coef acts as the flow-matching sigma (0.0 -> original, 1.0 -> noise).
        sigma = coef
        if hasattr(self.scheduler, "sigma_shift"):  # Flux-style sigma shift, if present
            sigma = self.scheduler.sigma_shift(sigma)

        latents = (1.0 - sigma) * latents_clean + sigma * noise

        # Skip the timesteps already covered by the initial noising.
        init_timestep = int(num_inference_steps * coef)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = timesteps[t_start:]

    else:
        # txt2img: start from pure noise at the latent resolution.
        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor

        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, latent_h, latent_w),
            generator=generator, device=device, dtype=dtype
        )

    # ==================== DENOISING LOOP (shared by txt2img and img2img) ====================
    for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
        latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents

        model_out = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            encoder_attention_mask=attention_mask,
            added_cond_kwargs={"text_embeds": pooled_embeds, "time_ids": time_ids},
            return_dict=False,
        )[0]

        if guidance_scale > 1.0:
            flow_uncond, flow_cond = model_out.chunk(2)
            model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

        # scheduler.step knows how to integrate the predicted velocity.
        latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]

    # ==================== DECODE ====================
    if output_type == "latent":
        if not return_dict:
            return (latents, prompt)
        return SdxsPipelineOutput(images=latents, prompt=prompt)

    latents = latents * self.vae_latents_std.to(device, self.vae.dtype) + self.vae_latents_mean.to(device, self.vae.dtype)
    image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]

    image_output = (image_output.clamp(-1, 1) + 1) / 2
    image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()

    if output_type == "pil":
        images = [(Image.fromarray((img * 255).round().astype("uint8"))) for img in image_np]
    else:
        images = image_np

    if not return_dict:
        return (images, prompt)
    return SdxsPipelineOutput(images=images, prompt=prompt)
pipeline_sdxs.py CHANGED
@@ -7,22 +7,20 @@ from dataclasses import dataclass
7
  from diffusers import DiffusionPipeline
8
  from diffusers.utils import BaseOutput
9
  from tqdm import tqdm
10
- from transformers import Qwen3ForCausalLM, Qwen2Tokenizer
11
 
12
  @dataclass
13
  class SdxsPipelineOutput(BaseOutput):
14
  images: Union[List[Image.Image], np.ndarray]
15
- prompt: Optional[Union[str, List[str]]] = None # Возврат улучшенного промпта
16
 
17
  class SdxsPipeline(DiffusionPipeline):
18
- def __init__(self, vae, text_encoder, text_encoder2, tokenizer, tokenizer2, unet, scheduler):
19
  super().__init__()
20
  self.register_modules(
21
  vae=vae,
22
  text_encoder=text_encoder,
23
- text_encoder2=text_encoder2,
24
  tokenizer=tokenizer,
25
- tokenizer2=tokenizer2,
26
  unet=unet,
27
  scheduler=scheduler
28
  )
@@ -59,90 +57,6 @@ class SdxsPipeline(DiffusionPipeline):
59
 
60
 
61
  def encode_prompt(self, prompt, negative_prompt, device, dtype):
62
- def get_single_encode(texts):
63
- if not texts:
64
- texts = [""]
65
- elif isinstance(texts, str):
66
- texts = [texts]
67
-
68
- with torch.no_grad():
69
- toks = self.tokenizer(
70
- texts,
71
- padding="max_length",
72
- max_length=self.text_encoder.config.max_position_embeddings,
73
- truncation=True,
74
- return_tensors="pt"
75
- ).to(device)
76
-
77
- outputs = self.text_encoder(
78
- input_ids=toks.input_ids,
79
- attention_mask=toks.attention_mask,
80
- output_hidden_states=True
81
- )
82
-
83
- # 1. Берем -2 слой [Batch, Seq, Dim]
84
- hidden = outputs.hidden_states[-2]
85
-
86
- # 2. Достаем pooled вектор (последний токен) [Batch, Dim]
87
- seq_lens = toks.attention_mask.sum(dim=1) - 1
88
- pooled = hidden[torch.arange(hidden.shape[0]), seq_lens.clamp(min=0)]
89
-
90
- # 3. Нормализация
91
- norm = self.text_encoder.text_model.final_layer_norm
92
- hidden = norm(hidden)
93
- pooled = norm(pooled)
94
-
95
- # 4. Объединяем в матрицу: Пулед (как 1-й токен) + остальные токены
96
- # pooled.unsqueeze(1) делает [Batch, 1, Dim]
97
- embeds = torch.cat([pooled.unsqueeze(1), hidden], dim=1)
98
-
99
- # 5. Расширяем маску для нового токена (добавляем единицы спереди)
100
- ones = torch.ones((toks.attention_mask.shape[0], 1), dtype=toks.attention_mask.dtype, device=device)
101
- mask = torch.cat([ones, toks.attention_mask], dim=1)
102
-
103
- return embeds, mask, pooled
104
-
105
- def get_pooled_encode(texts):
106
- if texts is None:
107
- texts = ""
108
-
109
- if isinstance(texts, str):
110
- texts = [texts]
111
-
112
- with torch.no_grad():
113
- # 1. Собираем текстовые промпты оборачивая их в Chat Template
114
- formatted_prompts = []
115
- for t in texts:
116
- messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
117
- res_text = self.tokenizer2.apply_chat_template(
118
- messages,
119
- add_generation_prompt=True,
120
- tokenize=False
121
- )
122
- formatted_prompts.append(res_text)
123
-
124
- # 2. Токенизируем, режем и добавляем паддинг за один раз
125
- toks = self.tokenizer2(
126
- formatted_prompts,
127
- padding="max_length",
128
- max_length=self.text_encoder.config.max_position_embeddings,
129
- truncation=True, # Не забываем обрезать, если вдруг длиннее
130
- return_tensors="pt"
131
- ).to(device)
132
-
133
- # 3. Прогоняем через модель
134
- outputs = self.text_encoder2(
135
- input_ids=toks.input_ids,
136
- attention_mask=toks.attention_mask,
137
- output_hidden_states=True
138
- )
139
-
140
- layer_index = -2
141
- last_hidden = outputs.hidden_states[layer_index]
142
- seq_len = toks.attention_mask.sum(dim=1) - 1
143
- pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]
144
-
145
- return pooled
146
  def get_encode(texts):
147
  if texts is None:
148
  texts = ""
@@ -155,7 +69,7 @@ class SdxsPipeline(DiffusionPipeline):
155
  formatted_prompts = []
156
  for t in texts:
157
  messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
158
- res_text = self.tokenizer2.apply_chat_template(
159
  messages,
160
  add_generation_prompt=True,
161
  tokenize=False
@@ -163,16 +77,16 @@ class SdxsPipeline(DiffusionPipeline):
163
  formatted_prompts.append(res_text)
164
 
165
  # 2. Токенизируем, режем и добавляем паддинг за один раз
166
- toks = self.tokenizer2(
167
  formatted_prompts,
168
  padding="max_length",
169
- max_length=self.text_encoder.config.max_position_embeddings,
170
  truncation=True, # Не забываем обрезать, если вдруг длиннее
171
  return_tensors="pt"
172
  ).to(device)
173
 
174
  # 3. Прогоняем через модель
175
- outputs = self.text_encoder2(
176
  input_ids=toks.input_ids,
177
  attention_mask=toks.attention_mask,
178
  output_hidden_states=True
@@ -185,11 +99,6 @@ class SdxsPipeline(DiffusionPipeline):
185
 
186
  return last_hidden, toks.attention_mask, pooled
187
 
188
- #pos_embeds, pos_mask, pooled_pos = get_single_encode(prompt)
189
- #neg_embeds, neg_mask, pooled_neg = get_single_encode(negative_prompt)
190
- # 768 + 2048
191
- #pos_pooled = get_pooled_encode(prompt) #torch.cat([pooled_pos, get_pooled_encode(prompt)], dim=1)
192
- #neg_pooled = get_pooled_encode(negative_prompt) #torch.cat([pooled_neg, get_pooled_encode(negative_prompt)], dim=1)
193
  pos_embeds, pos_mask, pos_pooled = get_encode(prompt)
194
  neg_embeds, neg_mask, neg_pooled = get_encode(negative_prompt)
195
 
@@ -223,7 +132,7 @@ class SdxsPipeline(DiffusionPipeline):
223
  seed: Optional[int] = None,
224
  output_type: str = "pil",
225
  return_dict: bool = True,
226
- refine_prompt: bool = False, # Флаг рефайна!
227
  # structure_preservation оставляем для совместимости, но теперь он почти не нужен
228
  structure_preservation: float = 0.0, # 0.0 = стандартный линейный путь (лучше всего)
229
  **kwargs,
 
7
  from diffusers import DiffusionPipeline
8
  from diffusers.utils import BaseOutput
9
  from tqdm import tqdm
10
+ from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Tokenizer
11
 
12
  @dataclass
13
  class SdxsPipelineOutput(BaseOutput):
14
  images: Union[List[Image.Image], np.ndarray]
15
+ prompt: Optional[Union[str, List[str]]] = None
16
 
17
  class SdxsPipeline(DiffusionPipeline):
18
+ def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
19
  super().__init__()
20
  self.register_modules(
21
  vae=vae,
22
  text_encoder=text_encoder,
 
23
  tokenizer=tokenizer,
 
24
  unet=unet,
25
  scheduler=scheduler
26
  )
 
57
 
58
 
59
  def encode_prompt(self, prompt, negative_prompt, device, dtype):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def get_encode(texts):
61
  if texts is None:
62
  texts = ""
 
69
  formatted_prompts = []
70
  for t in texts:
71
  messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
72
+ res_text = self.tokenizer.apply_chat_template(
73
  messages,
74
  add_generation_prompt=True,
75
  tokenize=False
 
77
  formatted_prompts.append(res_text)
78
 
79
  # 2. Токенизируем, режем и добавляем паддинг за один раз
80
+ toks = self.tokenizer(
81
  formatted_prompts,
82
  padding="max_length",
83
+ max_length=255,
84
  truncation=True, # Не забываем обрезать, если вдруг длиннее
85
  return_tensors="pt"
86
  ).to(device)
87
 
88
  # 3. Прогоняем через модель
89
+ outputs = self.text_encoder(
90
  input_ids=toks.input_ids,
91
  attention_mask=toks.attention_mask,
92
  output_hidden_states=True
 
99
 
100
  return last_hidden, toks.attention_mask, pooled
101
 
 
 
 
 
 
102
  pos_embeds, pos_mask, pos_pooled = get_encode(prompt)
103
  neg_embeds, neg_mask, neg_pooled = get_encode(negative_prompt)
104
 
 
132
  seed: Optional[int] = None,
133
  output_type: str = "pil",
134
  return_dict: bool = True,
135
+ refine_prompt: bool = False,
136
  # structure_preservation оставляем для совместимости, но теперь он почти не нужен
137
  structure_preservation: float = 0.0, # 0.0 = стандартный линейный путь (лучше всего)
138
  **kwargs,
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ae07d83aedfac9040cb0169a0aa14f1b2185254882f7862b78e52aabd0a95821
3
- size 5492219
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b1dc22c65c672008be6159c53028670acace729df6e26c3d1199a8929a07060
3
+ size 6130032
text_encoder/config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7c06cbeeddf5d93f5c7abc16b17a251b2a9ba6a6f08d7114fd8a269efeab1975
3
- size 563
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:199bacf59248a05c934c618cacd62b6cc2f60e1637563f037206c2b09330a4ff
3
+ size 2613
{text_encoder2 → text_encoder}/generation_config.json RENAMED
File without changes
text_encoder/model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1b38361773a76fe3d7e717ed36206b112be1aa110567a465e377896114394fdb
3
- size 246406816
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:020f49a22fe87e482485f27e57a01782f1faf7d0f312cc740d83faac36babcba
3
+ size 4426558248
text_encoder2/config.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:199bacf59248a05c934c618cacd62b6cc2f60e1637563f037206c2b09330a4ff
3
- size 2613
 
 
 
 
text_encoder2/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:020f49a22fe87e482485f27e57a01782f1faf7d0f312cc740d83faac36babcba
3
- size 4426558248
 
 
 
 
tmp/config.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7d9d2485cfbdfe8c4a3d350907422ade854ffa1c8b06bd39b896ff45fc444cda
3
- size 1858
 
 
 
 
tmp/diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ef8fd48bd3075eae7202915c3b05b3ed8e90c70c41625f3f70a2d868a1013c9
3
- size 3159550424
 
 
 
 
{tokenizer2 → tokenizer}/chat_template.jinja RENAMED
File without changes
tokenizer/merges.txt CHANGED
The diff for this file is too large to render. See raw diff
 
{tokenizer2 → tokenizer}/preprocessor_config.json RENAMED
File without changes
{tokenizer2 → tokenizer}/processor_config.json RENAMED
File without changes
tokenizer/special_tokens_map.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2cdb3b8331a60c92fc1e55a13e9fd61fd2293c5a51275fdcccd62b780052530e
3
- size 588
 
 
 
 
{tokenizer2 → tokenizer}/tokenizer.json RENAMED
File without changes
tokenizer/tokenizer_config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0ae623b1013846a4edb4e2206a09ea3f7a4a92b3215f6f840f3706ccfedcef2d
3
- size 737
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4f52c2103a685c3d9f3022b3954246fbab865abd709ac69d8b8a98f79580564
3
+ size 1140
tokenizer/vocab.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e089ad92ba36837a0d31433e555c8f45fe601ab5c221d4f607ded32d9f7a4349
3
- size 1059962
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce99b4cb2983d118806ce0a8b777a35b093e2000a503ebde25853284c9dfa003
3
+ size 6722759
tokenizer2/merges.txt DELETED
The diff for this file is too large to render. See raw diff
 
tokenizer2/tokenizer_config.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a4f52c2103a685c3d9f3022b3954246fbab865abd709ac69d8b8a98f79580564
3
- size 1140
 
 
 
 
tokenizer2/vocab.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce99b4cb2983d118806ce0a8b777a35b093e2000a503ebde25853284c9dfa003
3
- size 6722759
 
 
 
 
train.py CHANGED
@@ -149,10 +149,8 @@ if accelerator.is_main_process:
149
  # --------------------------- Загрузка моделей ---------------------------
150
  #vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
151
  vae = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
152
- tokenizer = CLIPTokenizer.from_pretrained("tokenizer")
153
- text_encoder = CLIPTextModel.from_pretrained("text_encoder", torch_dtype=torch.float16).to(device).eval()
154
- tokenizer2 = Qwen3_5Tokenizer.from_pretrained("tokenizer2")
155
- text_encoder2 = Qwen3_5ForConditionalGeneration.from_pretrained("text_encoder2", torch_dtype=torch.float16).to(device).eval()
156
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
157
 
158
  def encode_texts(texts, max_length=max_length):
@@ -168,7 +166,7 @@ def encode_texts(texts, max_length=max_length):
168
  formatted_prompts = []
169
  for t in texts:
170
  messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
171
- res_text = tokenizer2.apply_chat_template(
172
  messages,
173
  add_generation_prompt=True,
174
  tokenize=False
@@ -176,7 +174,7 @@ def encode_texts(texts, max_length=max_length):
176
  formatted_prompts.append(res_text)
177
 
178
  # 2. Токенизируем, режем и добавляем паддинг за один раз
179
- toks = tokenizer2(
180
  formatted_prompts,
181
  padding="max_length",
182
  max_length=max_length,
@@ -185,7 +183,7 @@ def encode_texts(texts, max_length=max_length):
185
  ).to(device)
186
 
187
  # 3. Прогоняем через модель
188
- outputs = text_encoder2(
189
  input_ids=toks.input_ids,
190
  attention_mask=toks.attention_mask,
191
  output_hidden_states=True
 
149
  # --------------------------- Загрузка моделей ---------------------------
150
  #vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
151
  vae = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
152
+ tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer2")
153
+ text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained("text_encoder2", torch_dtype=torch.float16).to(device).eval()
 
 
154
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
155
 
156
  def encode_texts(texts, max_length=max_length):
 
166
  formatted_prompts = []
167
  for t in texts:
168
  messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
169
+ res_text = tokenizer.apply_chat_template(
170
  messages,
171
  add_generation_prompt=True,
172
  tokenize=False
 
174
  formatted_prompts.append(res_text)
175
 
176
  # 2. Токенизируем, режем и добавляем паддинг за один раз
177
+ toks = tokenizer(
178
  formatted_prompts,
179
  padding="max_length",
180
  max_length=max_length,
 
183
  ).to(device)
184
 
185
  # 3. Прогоняем через модель
186
+ outputs = text_encoder(
187
  input_ids=toks.input_ids,
188
  attention_mask=toks.attention_mask,
189
  output_hidden_states=True
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:57210292c170ad4efb88279bd06e8aa3c12d4c9a8582c5f4f1b90f7285d4ab59
3
  size 6318956752
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1a8e5563a08991becd74bb41292f800b383fb757cafe97ce5b87b6004dfb87a
3
  size 6318956752
unet_old/config.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e0c1f603d4dfb8d759010daf7d83df29ff148ba48693400fb14c0f4dbd9b7d2f
3
- size 1884
 
 
 
 
unet_old/diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fc036a58605f5073498525bf4f8af44552ad1f57fcb7f545863d45bf95d5ce2
3
- size 5960474736