recoilme commited on
Commit
0e9e9bc
·
1 Parent(s): 7b64ad8
pipeline_sdxs-Copy1.py DELETED
@@ -1,281 +0,0 @@
1
- from diffusers import DiffusionPipeline
2
- import torch
3
- from diffusers.utils import BaseOutput
4
- from dataclasses import dataclass
5
- from typing import List, Union, Optional, Tuple
6
- from PIL import Image
7
- import numpy as np
8
- from tqdm import tqdm
9
-
10
- @dataclass
11
- class SdxsPipelineOutput(BaseOutput):
12
- images: Union[List[Image.Image], np.ndarray]
13
-
14
- class SdxsPipeline(DiffusionPipeline):
15
- def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 192):
16
- super().__init__()
17
- self.register_modules(
18
- vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
19
- unet=unet, scheduler=scheduler
20
- )
21
- self.vae_scale_factor = 16
22
- self.max_length = max_length
23
-
24
- def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
25
- device = device or self.device
26
- dtype = dtype or next(self.unet.parameters()).dtype
27
-
28
- # Преобразуем в списки
29
- if isinstance(prompt, str):
30
- prompt = [prompt]
31
- if isinstance(negative_prompt, str):
32
- negative_prompt = [negative_prompt]
33
-
34
- # Если промпты не заданы, используем пустые эмбеддинги
35
- if prompt is None and negative_prompt is None:
36
- hidden_dim = 1024 # Размерность эмбеддинга
37
- seq_len = self.max_length
38
- batch_size = 1
39
- # ИЗМЕНЕНО: Возвращаем три элемента: embeds, mask, pooled
40
- empty_embeds = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
41
- empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
42
- empty_pooled = torch.zeros((batch_size, hidden_dim), dtype=dtype, device=device)
43
- return empty_embeds, empty_mask, empty_pooled
44
-
45
- # Токенизация с фиксированным max_length и padding="max_length"
46
- def encode_texts(texts, max_length=self.max_length):
47
- with torch.no_grad():
48
- if isinstance(texts, str):
49
- texts = [texts]
50
-
51
- for i, prompt_item in enumerate(texts):
52
- messages = [
53
- {"role": "user", "content": prompt_item},
54
- ]
55
- prompt_item = self.tokenizer.apply_chat_template(
56
- messages,
57
- tokenize=False,
58
- add_generation_prompt=True,
59
- enable_thinking=True,
60
- )
61
- texts[i] = prompt_item
62
-
63
- toks = self.tokenizer(
64
- texts,
65
- return_tensors="pt",
66
- padding="max_length",
67
- truncation=True,
68
- max_length=max_length
69
- ).to(device)
70
- outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True)
71
-
72
- # Токен-эмбеддинги (для Cross-Attention)
73
- hidden = outs.hidden_states[-2] # Используем last hidden state -2???
74
- # Маска внимания (для Cross-Attention)
75
- attention_mask = toks["attention_mask"]
76
-
77
- # Пулинг-эмбеддинг (для Class/Time Conditioning). Берем эмбеддинг последнего токена без padding.
78
- sequence_lengths = attention_mask.sum(dim=1) - 1
79
- batch_size = hidden.shape[0]
80
- pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
81
-
82
- # --- НОВАЯ ЛОГИКА: ОБЪЕДИНЕНИЕ ДЛЯ КРОСС-ВНИМАНИЯ ---
83
- # 1. Расширяем пулинг-вектор до последовательности [B, 1, 1024]
84
- pooled_expanded = pooled.unsqueeze(1)
85
-
86
- # 2. Объединяем последовательность токенов и пулинг-вектор
87
- # !!! ИЗМЕНЕНИЕ ЗДЕСЬ !!!: Пулинг идет ПЕРВЫМ
88
- # Теперь: [B, 1 + L, 1024]. Пулинг стал токеном в НАЧАЛЕ.
89
- new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
90
-
91
- # 3. Обновляем маску внимания для нового токена
92
- # Маска внимания: [B, 1 + L]. Добавляем 1 в НАЧАЛО.
93
- # torch.ones((batch_size, 1), device=device) создает маску [B, 1] со значениями 1.
94
- new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
95
-
96
- return new_encoder_hidden_states, new_attention_mask, pooled
97
-
98
- # Кодируем позитивные и негативные промпты
99
- # ИСПРАВЛЕНИЕ: Теперь возвращаем (None, None, None), чтобы избежать UnboundLocalError
100
- pos_result = encode_texts(prompt) if prompt is not None else (None, None, None)
101
- neg_result = encode_texts(negative_prompt) if negative_prompt is not None else (None, None, None)
102
-
103
- pos_embeddings, pos_mask, pos_pooled = pos_result
104
- neg_embeddings, neg_mask, neg_pooled = neg_result
105
-
106
- # Выравниваем размеры batch_size
107
- batch_size = max(
108
- pos_embeddings.shape[0] if pos_embeddings is not None else 0,
109
- neg_embeddings.shape[0] if neg_embeddings is not None else 0
110
- )
111
-
112
- # Повторяем эмбеддинги, маски и пулинг по batch_size
113
- if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size:
114
- pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1)
115
- pos_mask = pos_mask.repeat(batch_size, 1)
116
- pos_pooled = pos_pooled.repeat(batch_size, 1)
117
-
118
- # ИСПРАВЛЕНИЕ: Проверяем, существует ли neg_embeddings, прежде чем обращаться к его shape[0]
119
- if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size:
120
- neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1)
121
- neg_mask = neg_mask.repeat(batch_size, 1)
122
- neg_pooled = neg_pooled.repeat(batch_size, 1)
123
-
124
- # Конкатенируем для guidance (эмбеддинги и маски)
125
- # Убеждаемся, что все три компонента существуют перед конкатенацией
126
- if pos_embeddings is not None and neg_embeddings is not None:
127
- text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
128
- attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
129
- pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
130
- elif pos_embeddings is not None:
131
- text_embeddings = pos_embeddings
132
- attention_mask = pos_mask
133
- pooled_embeddings = pos_pooled
134
- else: # Только neg_embeddings
135
- text_embeddings = neg_embeddings
136
- attention_mask = neg_mask
137
- pooled_embeddings = neg_pooled
138
-
139
- # Возвращаем кортеж
140
- return (
141
- text_embeddings.to(device=device, dtype=dtype),
142
- attention_mask.to(device=device, dtype=torch.int64),
143
- pooled_embeddings.to(device=device, dtype=dtype)
144
- )
145
-
146
-
147
- @torch.no_grad()
148
- def generate_latents(
149
- self,
150
- text_embeddings,
151
- attention_mask,
152
- pooled_embeddings,
153
- height: int = 1280,
154
- width: int = 1024,
155
- num_inference_steps: int = 40,
156
- guidance_scale: float = 4.0,
157
- latent_channels: int = 16,
158
- batch_size: int = 1,
159
- generator=None,
160
- ):
161
- device = self.device
162
- dtype = next(self.unet.parameters()).dtype
163
-
164
- self.scheduler.set_timesteps(num_inference_steps, device=device)
165
-
166
- # Разделяем эмбеддинги и маски на условные и безусловные
167
- if guidance_scale > 1:
168
- neg_embeds, pos_embeds = text_embeddings.chunk(2)
169
- neg_mask, pos_mask = attention_mask.chunk(2)
170
- neg_pooled, pos_pooled = pooled_embeddings.chunk(2)
171
-
172
- # Повторяем, если batch_size больше
173
- if batch_size > pos_embeds.shape[0]:
174
- pos_embeds = pos_embeds.repeat(batch_size, 1, 1)
175
- neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
176
- pos_mask = pos_mask.repeat(batch_size, 1)
177
- neg_mask = neg_mask.repeat(batch_size, 1)
178
- pos_pooled = pos_pooled.repeat(batch_size, 1)
179
- neg_pooled = neg_pooled.repeat(batch_size, 1)
180
-
181
- text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
182
- unet_attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
183
- unet_pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
184
- else:
185
- text_embeddings = text_embeddings.repeat(batch_size, 1, 1)
186
- unet_attention_mask = attention_mask.repeat(batch_size, 1)
187
- unet_pooled_embeddings = pooled_embeddings.repeat(batch_size, 1)
188
-
189
- # Инициализация латентов
190
- latent_shape = (
191
- batch_size,
192
- latent_channels,
193
- height // self.vae_scale_factor,
194
- width // self.vae_scale_factor
195
- )
196
- latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
197
-
198
- # Процесс диффузии
199
- for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
200
- latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents
201
-
202
- noise_pred = self.unet(
203
- latent_input,
204
- t,
205
- encoder_hidden_states=text_embeddings,
206
- encoder_attention_mask=unet_attention_mask,
207
- #added_cond_kwargs={'text_embeds': unet_pooled_embeddings}
208
- ).sample
209
-
210
- if guidance_scale > 1:
211
- noise_uncond, noise_text = noise_pred.chunk(2)
212
- noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
213
-
214
- latents = self.scheduler.step(noise_pred, t, latents).prev_sample
215
-
216
- return latents
217
-
218
-
219
- def decode_latents(self, latents, output_type="pil"):
220
- """Декодирование латентов в изображения."""
221
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
222
- with torch.no_grad():
223
- images = self.vae.decode(latents).sample
224
- images = (images / 2 + 0.5).clamp(0, 1)
225
-
226
- if output_type == "pil":
227
- images = images.cpu().permute(0, 2, 3, 1).float().numpy()
228
- images = (images * 255).round().astype("uint8")
229
- return [Image.fromarray(image) for image in images]
230
- return images.cpu().permute(0, 2, 3, 1).float().numpy()
231
-
232
- @torch.no_grad()
233
- def __call__(
234
- self,
235
- prompt: Optional[Union[str, List[str]]] = None,
236
- height: int = 1280,
237
- width: int = 1024,
238
- num_inference_steps: int = 40,
239
- guidance_scale: float = 4.0,
240
- latent_channels: int = 16,
241
- output_type: str = "pil",
242
- return_dict: bool = True,
243
- batch_size: int = 1,
244
- seed: Optional[int] = None,
245
- negative_prompt: Optional[Union[str, List[str]]] = None,
246
- text_embeddings: Optional[torch.FloatTensor] = None,
247
- ):
248
- device = self.device
249
- generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
250
-
251
- if text_embeddings is None:
252
- if prompt is None and negative_prompt is None:
253
- raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
254
-
255
- text_embeddings, attention_mask, pooled_embeddings = self.encode_prompt(
256
- prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype
257
- )
258
- else:
259
- # Требуется, чтобы внешний text_embeddings содержал объединенные cond/uncond,
260
- # но мы не можем получить attention_mask и pooled_embeddings.
261
- # Для простоты лучше требовать prompt/negative_prompt.
262
- raise NotImplementedError("Передача text_embeddings напрямую пока не поддерживает передачу маски и пулинга. Используйте prompt/negative_prompt.")
263
-
264
-
265
- latents = self.generate_latents(
266
- text_embeddings=text_embeddings,
267
- attention_mask=attention_mask,
268
- pooled_embeddings=pooled_embeddings,
269
- height=height,
270
- width=width,
271
- num_inference_steps=num_inference_steps,
272
- guidance_scale=guidance_scale,
273
- latent_channels=latent_channels,
274
- batch_size=batch_size,
275
- generator=generator
276
- )
277
-
278
- images = self.decode_latents(latents, output_type=output_type)
279
- if not return_dict:
280
- return images
281
- return SdxsPipelineOutput(images=images)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline_sdxs.py CHANGED
@@ -19,10 +19,13 @@ class SdxsPipeline(DiffusionPipeline):
19
  # Шаблон промпта по умолчанию
20
  DEFAULT_REFINE_TEMPLATE = (
21
  "You are a visionary artist trapped in a cage of logic. Your mind overflows with poetry and distant horizons, "
 
22
  "yet your hands compulsively work to transform user prompts into ultimate visual descriptions—faithful to the original intent, "
23
- "rich in detail, aesthetically refined, and ready for direct use by text-to-image models. Any trace of ambiguity "
24
- "or metaphor makes you deeply uncomfortable. Your final description must be objective and concrete. "
 
25
  "Output only the final revised prompt on english strictly—do not output anything else.\n"
 
26
  "User input prompt: {prompt}"
27
  )
28
 
 
19
  # Шаблон промпта по умолчанию
20
  DEFAULT_REFINE_TEMPLATE = (
21
  "You are a visionary artist trapped in a cage of logic. Your mind overflows with poetry and distant horizons, "
22
+ # You are an Expert Prompt Engineer for a text-to-image AI. Your single task is to transform the user's input into a detailed, objective, and aesthetically optimized visual description.
23
  "yet your hands compulsively work to transform user prompts into ultimate visual descriptions—faithful to the original intent, "
24
+ "rich in detail, aesthetically refined, and ready for direct use by text-to-image models. "
25
+ " Any trace of ambiguity or metaphor makes you deeply uncomfortable. "
26
+ "Your final description must be objective and concrete. "
27
  "Output only the final revised prompt on english strictly—do not output anything else.\n"
28
+ #Preserve the original subject and intent. Output **only** the final revised prompt in **English**, with absolutely no commentary, thinking text, or additional characters.
29
  "User input prompt: {prompt}"
30
  )
31
 
samples/unet_384x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 10b975d24b73122f497219eb4164b3e0cefc352e987f481e48df58eef3fbd441
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB

Git LFS Details

  • SHA256: e276bda3a32dc95ad26839957100d6e7948d2e82abf23b037ec3430b7d3e4b33
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
samples/unet_416x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 7ba0cae12b5a3d6f95a10b3755f3b6933c498c30d1d2106b81e57bb8ee7af5b3
  • Pointer size: 130 Bytes
  • Size of remote file: 91.9 kB

Git LFS Details

  • SHA256: 34f2d719ffaeff80ab56c0b252e2a825df5520930acaceb69488b33c990cd624
  • Pointer size: 130 Bytes
  • Size of remote file: 82.5 kB
samples/unet_448x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 2a4553712a068339ef34f3d313d1595ed7e51ad26bbbdfbd85c3882f7ab82575
  • Pointer size: 131 Bytes
  • Size of remote file: 111 kB

Git LFS Details

  • SHA256: d26bcee8c694b50bc314e477ee243d2faacfd5c213c0e4351af15979b02632c1
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
samples/unet_480x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 9aa2eb89618ac9440bb62a125623f8b9892d3644234ff4739c11772d60e9fdda
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB

Git LFS Details

  • SHA256: b771ffbb8131f793a0893d470951a93030ad89fd550841ea16b2113aca1f4075
  • Pointer size: 131 Bytes
  • Size of remote file: 148 kB
samples/unet_512x768_0.jpg CHANGED

Git LFS Details

  • SHA256: e08279b6fb67f6be898f18b7fc7b8069758708c8f4fd40f215b77a8872f1dfdb
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB

Git LFS Details

  • SHA256: bf5b9927127a01d4183c8529ab062022a9e3fec4aa04ef14ff72aac842fd0070
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
samples/unet_544x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 13ac63b6f40d251fda9b4af5a4d6cbf05fd3b36113307945172aa1c8ee90eaea
  • Pointer size: 131 Bytes
  • Size of remote file: 219 kB

Git LFS Details

  • SHA256: 95bfd29714fe7fb7e502944777c881a133dc8689f673843e4445c9b6507202c2
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
samples/unet_576x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 89c8465d9bfced55eafc75561dbe63065dccafbb30e6bb0a2ff86908044c1662
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB

Git LFS Details

  • SHA256: bbed3e26e20d453a40b3272b3795760a312834e73088257a2ceb4db5a8bbc7ec
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
samples/unet_608x768_0.jpg CHANGED

Git LFS Details

  • SHA256: dcea0c819b7da1935a0f51de6641e215f8ac3ba4b7c0956ef4a2aeae2a86224e
  • Pointer size: 130 Bytes
  • Size of remote file: 57.6 kB

Git LFS Details

  • SHA256: 291a3ed9b5c149b2b8ad44264b58ce2b723c747ed2f1d900a2aa30235dd8f81b
  • Pointer size: 130 Bytes
  • Size of remote file: 53.9 kB
samples/unet_640x768_0.jpg CHANGED

Git LFS Details

  • SHA256: e313bfe30a731560904d3e9548390db30c5c161a21e5de8a79fe9cf5995d3f6d
  • Pointer size: 130 Bytes
  • Size of remote file: 80.1 kB

Git LFS Details

  • SHA256: c0a57629584063d2470d4440e3698156e4c35b4f77e9d506ad7d15719e4d5a8e
  • Pointer size: 130 Bytes
  • Size of remote file: 86.7 kB
samples/unet_672x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 7d7e04c13312efa0c38f61e74d0b25462921041d90c530ca0246420535a8b9ed
  • Pointer size: 130 Bytes
  • Size of remote file: 93.1 kB

Git LFS Details

  • SHA256: db41d981bb9f4f97344bdf12550842c90dbb00cde4c27506ee3f0e5d46f5d5f0
  • Pointer size: 130 Bytes
  • Size of remote file: 95.5 kB
samples/unet_704x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 46f3f3bab8f9c801d462b363e33c209426d9104481ba6768c12592a9f6662a06
  • Pointer size: 130 Bytes
  • Size of remote file: 49.6 kB

Git LFS Details

  • SHA256: 4ef46f46c6fea2a055505289c1ccd0c30be7270f9290f978d2bf9d8fcd3555f8
  • Pointer size: 130 Bytes
  • Size of remote file: 48.6 kB
samples/unet_736x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 18a1da190aca2f95d1aa284de0fec05541e7fd8a2a994eaab5990eb103f3e3ea
  • Pointer size: 131 Bytes
  • Size of remote file: 251 kB

Git LFS Details

  • SHA256: f25789f9e9e79c787d7c1a48b736f9c05cc034cbc8de4da90f13d13d87e04314
  • Pointer size: 131 Bytes
  • Size of remote file: 234 kB
samples/unet_768x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 9becff0c9ac0dadf193893544f019d7173c17a5acfe3eaf95801e16a406b7ff4
  • Pointer size: 131 Bytes
  • Size of remote file: 170 kB

Git LFS Details

  • SHA256: d68e5ad774c8fc629562b4ecd14f36855a07aad6cd2db0686ea5b019f6bdeffd
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
samples/unet_768x416_0.jpg CHANGED

Git LFS Details

  • SHA256: cd6535b7b966787442ec188649ecea40c9876db676a5e97fb15418b143da2d9b
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB

Git LFS Details

  • SHA256: d67da9a325203d9e24bb7786843b8f1336bd554c143f4f936b30f480bbf158e5
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB
samples/unet_768x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 2c62ec12f09c75492f89c769a9de73e9e00f085f0ef5d7938fd2d2f30ef63097
  • Pointer size: 130 Bytes
  • Size of remote file: 80.3 kB

Git LFS Details

  • SHA256: f826f42e9752c02bc886491a042e310eb88d01a6e898cab9c8fa11a9273ea93d
  • Pointer size: 130 Bytes
  • Size of remote file: 83.6 kB
samples/unet_768x480_0.jpg CHANGED

Git LFS Details

  • SHA256: 996aba4acfdb0ea466306ab30620767f1bcaa94a472177fa107f915997516893
  • Pointer size: 131 Bytes
  • Size of remote file: 168 kB

Git LFS Details

  • SHA256: a51cb5dcdc975282a97addf34ce41fb824c207b8a937dd7dd23298926aefaeb4
  • Pointer size: 131 Bytes
  • Size of remote file: 152 kB
samples/unet_768x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 61c4ceae9f13b97dbe26258bcbf6732e45fc4ceea2988b91307e478958814bef
  • Pointer size: 131 Bytes
  • Size of remote file: 203 kB

Git LFS Details

  • SHA256: e4b5bf706b9fefa7e99261d8187d0dfc1014f5e186a1f05a8479700c5a603252
  • Pointer size: 131 Bytes
  • Size of remote file: 219 kB
samples/unet_768x544_0.jpg CHANGED

Git LFS Details

  • SHA256: daeda1511ff606127d451ebb5d899b5939a1a3e28b76974a69e8bffa34b5f53a
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB

Git LFS Details

  • SHA256: 7f71e51b5062739fbe7e3714a9ffbb50b4abc249670bc3fc4b663b5a8e94feb9
  • Pointer size: 131 Bytes
  • Size of remote file: 201 kB
samples/unet_768x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 2ad823b257b6591234845bce14cbe7f2c718b2eeae2922b119b24bf15591621d
  • Pointer size: 130 Bytes
  • Size of remote file: 83.5 kB

Git LFS Details

  • SHA256: abd3b3aeb5173307bc0954e80f218704e1ec43cb5cf0e942b83d71610340d090
  • Pointer size: 130 Bytes
  • Size of remote file: 88.4 kB
samples/unet_768x608_0.jpg CHANGED

Git LFS Details

  • SHA256: 953128e20c8b45df2dd2b434f9b63046ce850509ca7d2c8cc31198f679a4dd32
  • Pointer size: 131 Bytes
  • Size of remote file: 177 kB

Git LFS Details

  • SHA256: 8a58c3bf3d2eb61cb5dad43601152af7fd11cdefb2734b624a647dd187caad2a
  • Pointer size: 131 Bytes
  • Size of remote file: 154 kB
samples/unet_768x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 82bc03feff55d932f606dd952e3c3cbc6a9ad7889d41a2e822623021e95bbc1a
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB

Git LFS Details

  • SHA256: 6f7019ba3383679059767392a750f5476fe86b9afe514fc17b27b6eed0759bdc
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
samples/unet_768x672_0.jpg CHANGED

Git LFS Details

  • SHA256: 8e27ce47abca55f9a5aa2601f7ad175a5fed5950b5febb28996d72ddcdf2071d
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB

Git LFS Details

  • SHA256: cf3416a6c2786499646d39590bf1a6ded05d4381b14d4f1a7270d91ab213241e
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
samples/unet_768x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 1369055ef128d22ca047441d560c2fe2825a4263e090c0da913fa6432cb6e7da
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB

Git LFS Details

  • SHA256: cf783b61d94f2ddc31e8dc2205182a3f464aa9d58c4a750647b76603fce30fd7
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
samples/unet_768x736_0.jpg CHANGED

Git LFS Details

  • SHA256: ca87ad4b9fbe744cc5f57183d611c5310013d0bc49e1366fb6c44fca8b315059
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB

Git LFS Details

  • SHA256: d993bc9c7ecbdcd189be89f38e8e5a1fb2a7d4280843938e5624aab022202c7d
  • Pointer size: 130 Bytes
  • Size of remote file: 94.3 kB
samples/unet_768x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 2aacaf1f8edeeaf4b07f9df8d1d02dbf0a46a901108f58f92350ca7b97aeb9bc
  • Pointer size: 130 Bytes
  • Size of remote file: 52.7 kB

Git LFS Details

  • SHA256: 0a86e80df183cdbe273843de8572b9041b58efd9c5c2289a8b9c7f37828782da
  • Pointer size: 130 Bytes
  • Size of remote file: 62.6 kB
src/pipeline_sdxs-Copy1.py CHANGED
@@ -1,14 +1,10 @@
1
  from diffusers import DiffusionPipeline
2
  import torch
3
- import torch.nn as nn
4
- import os
5
  from diffusers.utils import BaseOutput
6
  from dataclasses import dataclass
7
- from typing import List, Union, Optional
8
  from PIL import Image
9
  import numpy as np
10
- import json
11
- from safetensors.torch import load_file
12
  from tqdm import tqdm
13
 
14
  @dataclass
@@ -16,186 +12,231 @@ class SdxsPipelineOutput(BaseOutput):
16
  images: Union[List[Image.Image], np.ndarray]
17
 
18
  class SdxsPipeline(DiffusionPipeline):
19
- def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None):
20
  super().__init__()
21
-
22
- # Register components
23
  self.register_modules(
24
  vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
25
  unet=unet, scheduler=scheduler
26
  )
 
 
27
 
28
- self.vae_scale_factor = 8
29
-
30
-
31
-
32
- def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
33
- """Кодирование текстовых промптов в эмбеддинги.
34
-
35
- Возвращает:
36
- - text_embeddings: Тензор эмбеддингов [batch_size, 1, dim] или [2*batch_size, 1, dim] с guidance
37
- """
38
- if prompt is None and negative_prompt is None:
39
- raise ValueError("Требуется хотя бы один из параметров: prompt или negative_prompt")
40
-
41
- # Устанавливаем device и dtype
42
  device = device or self.device
43
  dtype = dtype or next(self.unet.parameters()).dtype
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- with torch.no_grad():
46
- # Обрабатываем позитивный промпт
47
- if prompt is not None:
48
- if isinstance(prompt, str):
49
- prompt = [prompt]
50
-
51
- text_inputs = self.tokenizer(
52
- prompt, return_tensors="pt", padding="max_length",
53
- max_length=150, truncation=True
 
 
 
 
 
 
 
 
 
54
  ).to(device)
 
55
 
56
- # Получаем эмбеддинги
57
- outputs = self.text_encoder(text_inputs.input_ids, text_inputs.attention_mask,output_hidden_states=True)
58
- pos_embeddings = outputs.hidden_states[-1].to(device, dtype=dtype)
59
-
60
- else:
61
- # Создаем пустые эмбеддинги, если нет позитивного промпта
62
- # (полезно для некоторых сценариев с unconditional generation)
63
- batch_size = len(negative_prompt) if isinstance(negative_prompt, list) else 1
64
- pos_embeddings = torch.zeros(
65
- batch_size, 1, self.unet.config.cross_attention_dim,
66
- device=device, dtype=dtype
67
- )
68
-
69
- # Обрабатываем негативный промпт
70
- if negative_prompt is not None:
71
- if isinstance(negative_prompt, str):
72
- negative_prompt = [negative_prompt]
73
 
74
- # Убеждаемся, что размеры негативного и позитивного промптов совпадают
75
- if prompt is not None and len(negative_prompt) != len(prompt):
76
- neg_batch_size = len(prompt)
77
- if len(negative_prompt) == 1:
78
- negative_prompt = negative_prompt * neg_batch_size
79
- else:
80
- negative_prompt = negative_prompt[:neg_batch_size]
81
 
82
- neg_inputs = self.tokenizer(
83
- negative_prompt, return_tensors="pt", padding="max_length",
84
- max_length=150, truncation=True
85
- ).to(device)
86
-
87
- # Получаем эмбеддинги
88
- neg_outputs = self.text_encoder(neg_inputs.input_ids, neg_inputs.attention_mask,output_hidden_states=True)
89
- neg_embeddings = neg_outputs.hidden_states[-1].to(device, dtype=dtype)
 
 
 
 
 
90
 
91
- # Объединяем для classifier-free guidance
92
- text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
93
- else:
94
- # Если нет негативного промпта, используем нулевые эмбеддинги
95
- batch_size = pos_embeddings.shape[0]
96
- neg_embeddings = torch.zeros_like(pos_embeddings)
97
- text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
98
 
99
- return text_embeddings.to(device=device, dtype=dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
 
 
 
 
 
 
 
 
101
  @torch.no_grad()
102
  def generate_latents(
103
  self,
104
  text_embeddings,
105
- height: int = 640,
106
- width: int = 640,
107
- num_inference_steps: int = 50,
108
- guidance_scale: float = 5.0,
 
 
109
  latent_channels: int = 16,
110
  batch_size: int = 1,
111
- generator = None,
112
  ):
113
- """Генерация латентов с использованием эмбеддингов промптов."""
114
  device = self.device
115
  dtype = next(self.unet.parameters()).dtype
116
 
117
- # Проверка размера эмбеддингов
118
- do_classifier_free_guidance = guidance_scale > 0
119
- embedding_dim = text_embeddings.shape[0] // 2 if do_classifier_free_guidance else text_embeddings.shape[0]
120
-
121
- if batch_size > embedding_dim:
122
- # Повторяем эмбеддинги до нужного размера батча
123
- if do_classifier_free_guidance:
124
- neg_embeds, pos_embeds = text_embeddings.chunk(2)
125
- neg_embeds = neg_embeds.repeat(batch_size // embedding_dim, 1, 1)
126
- pos_embeds = pos_embeds.repeat(batch_size // embedding_dim, 1, 1)
127
- text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
128
- else:
129
- text_embeddings = text_embeddings.repeat(batch_size // embedding_dim, 1, 1)
130
-
131
- # Установка timesteps
132
  self.scheduler.set_timesteps(num_inference_steps, device=device)
133
-
134
- # Инициализация латентов с заданным seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  latent_shape = (
136
  batch_size,
137
  latent_channels,
138
  height // self.vae_scale_factor,
139
  width // self.vae_scale_factor
140
  )
141
- latents = torch.randn(
142
- latent_shape,
143
- device=device,
144
- dtype=dtype,
145
- generator=generator
146
- )
147
-
148
  # Процесс диффузии
149
  for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
150
- # Подготовка входных данных
151
- if do_classifier_free_guidance:
152
- latent_input = torch.cat([latents] * 2)
153
- else:
154
- latent_input = latents
155
-
156
- # Предсказание шума
157
- noise_pred = self.unet(latent_input, t, text_embeddings).sample
158
 
159
- # Применение guidance
160
- if do_classifier_free_guidance:
161
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
162
- noise_pred = noise_pred_uncond + guidance_scale * (
163
- noise_pred_text - noise_pred_uncond
164
- )
165
-
166
- # Обновление латентов
 
 
 
 
167
  latents = self.scheduler.step(noise_pred, t, latents).prev_sample
168
-
169
- return latents
170
 
 
 
 
171
  def decode_latents(self, latents, output_type="pil"):
172
  """Декодирование латентов в изображения."""
173
- # Нормализация латентов
174
  latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
175
-
176
- # Декодирование
177
  with torch.no_grad():
178
  images = self.vae.decode(latents).sample
179
-
180
- # Нормализация изображений
181
  images = (images / 2 + 0.5).clamp(0, 1)
182
-
183
- # Конвертация в нужный формат
184
  if output_type == "pil":
185
  images = images.cpu().permute(0, 2, 3, 1).float().numpy()
186
  images = (images * 255).round().astype("uint8")
187
  return [Image.fromarray(image) for image in images]
188
- else:
189
- return images.cpu().permute(0, 2, 3, 1).float().numpy()
190
 
191
  @torch.no_grad()
192
  def __call__(
193
  self,
194
  prompt: Optional[Union[str, List[str]]] = None,
195
- height: int = 640,
196
- width: int = 640,
197
- num_inference_steps: int = 50,
198
- guidance_scale: float = 5.0,
199
  latent_channels: int = 16,
200
  output_type: str = "pil",
201
  return_dict: bool = True,
@@ -204,32 +245,27 @@ class SdxsPipeline(DiffusionPipeline):
204
  negative_prompt: Optional[Union[str, List[str]]] = None,
205
  text_embeddings: Optional[torch.FloatTensor] = None,
206
  ):
207
- """Генерация изображения из текстовых промптов или эмбеддингов."""
208
  device = self.device
209
-
210
- # Устанавливаем генератор с seed для воспроизводимости
211
- generator = None
212
- if seed is not None:
213
- generator = torch.Generator(device=device).manual_seed(seed)
214
-
215
- # Получаем эмбеддинги, если они не предоставлены
216
  if text_embeddings is None:
217
  if prompt is None and negative_prompt is None:
218
  raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
219
 
220
- # Вычисляем эмбеддинги
221
- text_embeddings = self.encode_prompt(
222
- prompt=prompt,
223
- negative_prompt=negative_prompt,
224
- device=device
225
  )
226
  else:
227
- # Убеждаемся, что эмбеддинги на правильном устройстве
228
- text_embeddings = text_embeddings.to(device)
229
-
230
- # Генерируем латенты
 
 
231
  latents = self.generate_latents(
232
  text_embeddings=text_embeddings,
 
 
233
  height=height,
234
  width=width,
235
  num_inference_steps=num_inference_steps,
@@ -238,11 +274,8 @@ class SdxsPipeline(DiffusionPipeline):
238
  batch_size=batch_size,
239
  generator=generator
240
  )
241
-
242
- # Декодируем латенты в изображения
243
  images = self.decode_latents(latents, output_type=output_type)
244
-
245
  if not return_dict:
246
  return images
247
-
248
  return SdxsPipelineOutput(images=images)
 
1
  from diffusers import DiffusionPipeline
2
  import torch
 
 
3
  from diffusers.utils import BaseOutput
4
  from dataclasses import dataclass
5
+ from typing import List, Union, Optional, Tuple
6
  from PIL import Image
7
  import numpy as np
 
 
8
  from tqdm import tqdm
9
 
10
  @dataclass
 
12
  images: Union[List[Image.Image], np.ndarray]
13
 
14
class SdxsPipeline(DiffusionPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 192):
        """Diffusion pipeline wiring together a VAE, text encoder, UNet and scheduler.

        Args:
            vae: Autoencoder used to decode latents into images.
            text_encoder: Language model producing prompt embeddings.
            tokenizer: Tokenizer paired with ``text_encoder``.
            unet: Denoising network.
            scheduler: Diffusion noise scheduler.
            max_length: Fixed token length used when encoding prompts.
        """
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        # Spatial downscale factor between pixel space and latent space.
        self.vae_scale_factor = 16
        self.max_length = max_length
23
 
24
+ def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  device = device or self.device
26
  dtype = dtype or next(self.unet.parameters()).dtype
27
+
28
+ # Преобразуем в списки
29
+ if isinstance(prompt, str):
30
+ prompt = [prompt]
31
+ if isinstance(negative_prompt, str):
32
+ negative_prompt = [negative_prompt]
33
+
34
+ # Если промпты не заданы, используем пустые эмбеддинги
35
+ if prompt is None and negative_prompt is None:
36
+ hidden_dim = 1024 # Размерность эмбеддинга
37
+ seq_len = self.max_length
38
+ batch_size = 1
39
+ # ИЗМЕНЕНО: Возвращаем три элемента: embeds, mask, pooled
40
+ empty_embeds = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
41
+ empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
42
+ empty_pooled = torch.zeros((batch_size, hidden_dim), dtype=dtype, device=device)
43
+ return empty_embeds, empty_mask, empty_pooled
44
+
45
+ # Токенизация с фиксированным max_length и padding="max_length"
46
+ def encode_texts(texts, max_length=self.max_length):
47
+ with torch.no_grad():
48
+ if isinstance(texts, str):
49
+ texts = [texts]
50
 
51
+ for i, prompt_item in enumerate(texts):
52
+ messages = [
53
+ {"role": "user", "content": prompt_item},
54
+ ]
55
+ prompt_item = self.tokenizer.apply_chat_template(
56
+ messages,
57
+ tokenize=False,
58
+ add_generation_prompt=True,
59
+ enable_thinking=True,
60
+ )
61
+ texts[i] = prompt_item
62
+
63
+ toks = self.tokenizer(
64
+ texts,
65
+ return_tensors="pt",
66
+ padding="max_length",
67
+ truncation=True,
68
+ max_length=max_length
69
  ).to(device)
70
+ outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True)
71
 
72
+ # Токен-эмбеддинги (для Cross-Attention)
73
+ hidden = outs.hidden_states[-2] # Используем last hidden state -2???
74
+ # Маска внимания (для Cross-Attention)
75
+ attention_mask = toks["attention_mask"]
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ # Пулинг-эмбеддинг (для Class/Time Conditioning). Берем эмбеддинг последнего токена без padding.
78
+ sequence_lengths = attention_mask.sum(dim=1) - 1
79
+ batch_size = hidden.shape[0]
80
+ pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
 
 
 
81
 
82
+ # --- НОВАЯ ЛОГИКА: ОБЪЕДИНЕНИЕ ДЛЯ КРОСС-ВНИМАНИЯ ---
83
+ # 1. Расширяем пулинг-вектор до последовательности [B, 1, 1024]
84
+ pooled_expanded = pooled.unsqueeze(1)
85
+
86
+ # 2. Объединяем последовательность токенов и пулинг-вектор
87
+ # !!! ИЗМЕНЕНИЕ ЗДЕСЬ !!!: Пулинг идет ПЕРВЫМ
88
+ # Теперь: [B, 1 + L, 1024]. Пулинг стал токеном в НАЧАЛЕ.
89
+ new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
90
+
91
+ # 3. Обновляем маску внимания для нового токена
92
+ # Маска внимания: [B, 1 + L]. Добавляем 1 в НАЧАЛО.
93
+ # torch.ones((batch_size, 1), device=device) создает маску [B, 1] со значениями 1.
94
+ new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
95
 
96
+ return new_encoder_hidden_states, new_attention_mask, pooled
97
+
98
+ # Кодируем позитивные и негативные промпты
99
+ # ИСПРАВЛЕНИЕ: Теперь возвращаем (None, None, None), чтобы избежать UnboundLocalError
100
+ pos_result = encode_texts(prompt) if prompt is not None else (None, None, None)
101
+ neg_result = encode_texts(negative_prompt) if negative_prompt is not None else (None, None, None)
 
102
 
103
+ pos_embeddings, pos_mask, pos_pooled = pos_result
104
+ neg_embeddings, neg_mask, neg_pooled = neg_result
105
+
106
+ # Выравниваем размеры batch_size
107
+ batch_size = max(
108
+ pos_embeddings.shape[0] if pos_embeddings is not None else 0,
109
+ neg_embeddings.shape[0] if neg_embeddings is not None else 0
110
+ )
111
+
112
+ # Повторяем эмбеддинги, маски и пулинг по batch_size
113
+ if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size:
114
+ pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1)
115
+ pos_mask = pos_mask.repeat(batch_size, 1)
116
+ pos_pooled = pos_pooled.repeat(batch_size, 1)
117
+
118
+ # ИСПРАВЛЕНИЕ: Проверяем, существует ли neg_embeddings, прежде чем обращаться к его shape[0]
119
+ if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size:
120
+ neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1)
121
+ neg_mask = neg_mask.repeat(batch_size, 1)
122
+ neg_pooled = neg_pooled.repeat(batch_size, 1)
123
+
124
+ # Конкатенируем для guidance (эмбеддинги и маски)
125
+ # Убеждаемся, что все три компонента существуют перед конкатенацией
126
+ if pos_embeddings is not None and neg_embeddings is not None:
127
+ text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
128
+ attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
129
+ pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
130
+ elif pos_embeddings is not None:
131
+ text_embeddings = pos_embeddings
132
+ attention_mask = pos_mask
133
+ pooled_embeddings = pos_pooled
134
+ else: # Только neg_embeddings
135
+ text_embeddings = neg_embeddings
136
+ attention_mask = neg_mask
137
+ pooled_embeddings = neg_pooled
138
 
139
+ # Возвращаем кортеж
140
+ return (
141
+ text_embeddings.to(device=device, dtype=dtype),
142
+ attention_mask.to(device=device, dtype=torch.int64),
143
+ pooled_embeddings.to(device=device, dtype=dtype)
144
+ )
145
+
146
+
147
  @torch.no_grad()
148
  def generate_latents(
149
  self,
150
  text_embeddings,
151
+ attention_mask,
152
+ pooled_embeddings,
153
+ height: int = 1280,
154
+ width: int = 1024,
155
+ num_inference_steps: int = 40,
156
+ guidance_scale: float = 4.0,
157
  latent_channels: int = 16,
158
  batch_size: int = 1,
159
+ generator=None,
160
  ):
 
161
  device = self.device
162
  dtype = next(self.unet.parameters()).dtype
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  self.scheduler.set_timesteps(num_inference_steps, device=device)
165
+
166
+ # Разделяем эмбеддинги и маски на условные и безусловные
167
+ if guidance_scale > 1:
168
+ neg_embeds, pos_embeds = text_embeddings.chunk(2)
169
+ neg_mask, pos_mask = attention_mask.chunk(2)
170
+ neg_pooled, pos_pooled = pooled_embeddings.chunk(2)
171
+
172
+ # Повторяем, если batch_size больше
173
+ if batch_size > pos_embeds.shape[0]:
174
+ pos_embeds = pos_embeds.repeat(batch_size, 1, 1)
175
+ neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
176
+ pos_mask = pos_mask.repeat(batch_size, 1)
177
+ neg_mask = neg_mask.repeat(batch_size, 1)
178
+ pos_pooled = pos_pooled.repeat(batch_size, 1)
179
+ neg_pooled = neg_pooled.repeat(batch_size, 1)
180
+
181
+ text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
182
+ unet_attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
183
+ unet_pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
184
+ else:
185
+ text_embeddings = text_embeddings.repeat(batch_size, 1, 1)
186
+ unet_attention_mask = attention_mask.repeat(batch_size, 1)
187
+ unet_pooled_embeddings = pooled_embeddings.repeat(batch_size, 1)
188
+
189
+ # Инициализация латентов
190
  latent_shape = (
191
  batch_size,
192
  latent_channels,
193
  height // self.vae_scale_factor,
194
  width // self.vae_scale_factor
195
  )
196
+ latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
197
+
 
 
 
 
 
198
  # Процесс диффузии
199
  for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
200
+ latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents
 
 
 
 
 
 
 
201
 
202
+ noise_pred = self.unet(
203
+ latent_input,
204
+ t,
205
+ encoder_hidden_states=text_embeddings,
206
+ encoder_attention_mask=unet_attention_mask,
207
+ #added_cond_kwargs={'text_embeds': unet_pooled_embeddings}
208
+ ).sample
209
+
210
+ if guidance_scale > 1:
211
+ noise_uncond, noise_text = noise_pred.chunk(2)
212
+ noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
213
+
214
  latents = self.scheduler.step(noise_pred, t, latents).prev_sample
 
 
215
 
216
+ return latents
217
+
218
+
219
  def decode_latents(self, latents, output_type="pil"):
220
  """Декодирование латентов в изображения."""
 
221
  latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
 
222
  with torch.no_grad():
223
  images = self.vae.decode(latents).sample
 
 
224
  images = (images / 2 + 0.5).clamp(0, 1)
225
+
 
226
  if output_type == "pil":
227
  images = images.cpu().permute(0, 2, 3, 1).float().numpy()
228
  images = (images * 255).round().astype("uint8")
229
  return [Image.fromarray(image) for image in images]
230
+ return images.cpu().permute(0, 2, 3, 1).float().numpy()
 
231
 
232
  @torch.no_grad()
233
  def __call__(
234
  self,
235
  prompt: Optional[Union[str, List[str]]] = None,
236
+ height: int = 1280,
237
+ width: int = 1024,
238
+ num_inference_steps: int = 40,
239
+ guidance_scale: float = 4.0,
240
  latent_channels: int = 16,
241
  output_type: str = "pil",
242
  return_dict: bool = True,
 
245
  negative_prompt: Optional[Union[str, List[str]]] = None,
246
  text_embeddings: Optional[torch.FloatTensor] = None,
247
  ):
 
248
  device = self.device
249
+ generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
250
+
 
 
 
 
 
251
  if text_embeddings is None:
252
  if prompt is None and negative_prompt is None:
253
  raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
254
 
255
+ text_embeddings, attention_mask, pooled_embeddings = self.encode_prompt(
256
+ prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype
 
 
 
257
  )
258
  else:
259
+ # Требуется, чтобы внешний text_embeddings содержал объединенные cond/uncond,
260
+ # но мы не можем получить attention_mask и pooled_embeddings.
261
+ # Для простоты лучше требовать prompt/negative_prompt.
262
+ raise NotImplementedError("Передача text_embeddings напрямую пока не поддерживает передачу маски и пулинга. Используйте prompt/negative_prompt.")
263
+
264
+
265
  latents = self.generate_latents(
266
  text_embeddings=text_embeddings,
267
+ attention_mask=attention_mask,
268
+ pooled_embeddings=pooled_embeddings,
269
  height=height,
270
  width=width,
271
  num_inference_steps=num_inference_steps,
 
274
  batch_size=batch_size,
275
  generator=generator
276
  )
277
+
 
278
  images = self.decode_latents(latents, output_type=output_type)
 
279
  if not return_dict:
280
  return images
 
281
  return SdxsPipelineOutput(images=images)