recoilme commited on
Commit
cdc391b
·
1 Parent(s): 56d73d2
Files changed (1) hide show
  1. pipeline_sdxs.py +65 -116
pipeline_sdxs.py CHANGED
@@ -23,99 +23,61 @@ class SdxsPipeline(DiffusionPipeline):
23
  self.max_length = max_length
24
 
25
  def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
26
- """
27
- Кодирование промптов в эмбеддинги.
28
- Поведение приближено к ручному инференсу:
29
- - padding="max_length", truncation=True, max_length=self.max_length
30
- - если negative_prompt отсутствует, возвращаем нулевой uncond с нужной формой
31
- - возврат: tensor [batch_uncond + batch_cond, seq_len, hidden_dim]
32
- где сначала идут uncond, потом cond (чтобы совпадать с concat для guidance)
33
- """
34
- if prompt is None and negative_prompt is None:
35
- raise ValueError("Требуется хотя бы один из параметров: prompt или negative_prompt")
36
-
37
  device = device or self.device
38
- # приводим к dtype unet (важно для совместимости)
39
- dtype = dtype or next(self.unet.parameters()).dtype
40
-
41
- # нормализуем входы в списки
42
  if isinstance(prompt, str):
43
  prompt = [prompt]
44
  if isinstance(negative_prompt, str):
45
  negative_prompt = [negative_prompt]
46
-
47
- # equalize list lengths: если один из них длины 1, расширяем — как в вашем ручном коде
48
- if prompt is not None and negative_prompt is not None:
49
- if len(prompt) != len(negative_prompt):
50
- if len(negative_prompt) == 1:
51
- negative_prompt = negative_prompt * len(prompt)
52
- elif len(prompt) == 1:
53
- prompt = prompt * len(negative_prompt)
54
- else:
55
- n = min(len(prompt), len(negative_prompt))
56
- prompt = prompt[:n]
57
- negative_prompt = negative_prompt[:n]
58
-
59
- with torch.no_grad():
60
- # --- Позитивные эмбеддинги ---
61
- if prompt is not None:
62
- pos_inputs = self.tokenizer(
63
- prompt,
64
- return_tensors="pt",
65
- padding="max_length", # фиксируем длину
66
- truncation=True,
67
- max_length=self.max_length
68
- ).to(device)
69
- pos_out = self.text_encoder(
70
- pos_inputs.input_ids,
71
- attention_mask=pos_inputs.attention_mask,
72
- output_hidden_states=True
73
- )
74
- pos_embeddings = pos_out.hidden_states[-1] # [B, seq_len, dim]
75
- else:
76
- pos_embeddings = None
77
-
78
- # --- Негативные эмбеддинги ---
79
- if negative_prompt is not None:
80
- neg_inputs = self.tokenizer(
81
- negative_prompt,
82
  return_tensors="pt",
83
  padding="max_length",
84
  truncation=True,
85
- max_length=self.max_length
86
  ).to(device)
87
- neg_out = self.text_encoder(
88
- neg_inputs.input_ids,
89
- attention_mask=neg_inputs.attention_mask,
90
- output_hidden_states=True
91
- )
92
- neg_embeddings = neg_out.hidden_states[-1] # [B, seq_len, dim]
93
- else:
94
- neg_embeddings = None
95
-
96
- # Если отсутствует neg_embeddings, создаём нулевой uncond эмбеддинг
97
- if neg_embeddings is None and pos_embeddings is not None:
98
- b = pos_embeddings.shape[0]
99
- seq_len = pos_embeddings.shape[1]
100
- hid = pos_embeddings.shape[2]
101
- neg_embeddings = torch.zeros((b, seq_len, hid), device=pos_embeddings.device, dtype=pos_embeddings.dtype)
102
-
103
- # Если отсутствует pos_embeddings (маловероятно), создаём нулевой cond
104
- if pos_embeddings is None and neg_embeddings is not None:
105
- b = neg_embeddings.shape[0]
106
- seq_len = neg_embeddings.shape[1]
107
- hid = neg_embeddings.shape[2]
108
- pos_embeddings = torch.zeros((b, seq_len, hid), device=neg_embeddings.device, dtype=neg_embeddings.dtype)
109
-
110
- # Приводим dtype к нужному (например float16), чтобы совпадало с unet
111
- pos_embeddings = pos_embeddings.to(dtype=dtype, device=device)
112
- neg_embeddings = neg_embeddings.to(dtype=dtype, device=device)
113
-
114
- # Теперь формируем итоговый тензор: сначала uncond, затем cond
115
- # -- если батч >1 и один из них длиной 1, расширим до нужного размера в __call__ / generate_latents
116
- text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0) # -> [B_uncond + B_cond, seq_len, hid]
117
 
118
- return text_embeddings # уже на device и dtype правильные
119
 
120
  @torch.no_grad()
121
  def generate_latents(
@@ -129,34 +91,20 @@ class SdxsPipeline(DiffusionPipeline):
129
  batch_size: int = 1,
130
  generator=None,
131
  ):
132
- """Генерация латентов. Поведение guidance согласовано с encode_prompt (uncond перед cond)."""
133
  device = self.device
134
- dtype = next(self.unet.parameters()).dtype
135
- do_cfg = guidance_scale > 1e-5 # true если используется guidance
136
-
137
- # text_embeddings: [B_uncond + B_cond, seq_len, hid]
138
- # ожидаем, что B_uncond == B_cond == base_batch (или оба равны 1)
139
- # разделим пополам по батчу: сначала uncond, затем cond
140
- half = text_embeddings.shape[0] // 2
141
- neg_embeds = text_embeddings[:half] # uncond
142
- pos_embeds = text_embeddings[half:] # cond
143
-
144
- # повторяем эмбеддинги, если нужно увеличить batch_size
145
- if batch_size > pos_embeds.shape[0]:
146
- reps = (batch_size + pos_embeds.shape[0] - 1) // pos_embeds.shape[0]
147
- pos_embeds = pos_embeds.repeat(reps, 1, 1)[:batch_size]
148
- neg_embeds = neg_embeds.repeat(reps, 1, 1)[:batch_size]
149
-
150
- # для guidance мы собираем [neg, pos] по батчам (concatenate)
151
- if do_cfg:
152
- text_embeddings_for_unet = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype)
153
  else:
154
- # если без guidance, просто используем pos
155
- text_embeddings_for_unet = pos_embeds.to(device=device, dtype=dtype)
156
-
157
- # Установка timesteps
158
- self.scheduler.set_timesteps(num_inference_steps, device=device)
159
-
160
  # Инициализация латентов
161
  latent_shape = (
162
  batch_size,
@@ -165,20 +113,21 @@ class SdxsPipeline(DiffusionPipeline):
165
  width // self.vae_scale_factor
166
  )
167
  latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
168
-
169
  # Процесс диффузии
170
  for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
171
- latent_input = torch.cat([latents, latents], dim=0) if do_cfg else latents
172
- noise_pred = self.unet(latent_input, t, encoder_hidden_states=text_embeddings_for_unet).sample
173
-
174
- if do_cfg:
175
  noise_uncond, noise_text = noise_pred.chunk(2)
176
  noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
177
-
178
  latents = self.scheduler.step(noise_pred, t, latents).prev_sample
179
-
180
  return latents
181
 
 
182
  def decode_latents(self, latents, output_type="pil"):
183
  """Декодирование латентов в изображения."""
184
  latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
23
  self.max_length = max_length
24
 
25
  def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
 
 
 
 
 
 
 
 
 
 
 
26
  device = device or self.device
27
+ dtype = dtype or torch.float16 # Явно указываем float16
28
+
29
+ # Преобразуем в списки
 
30
  if isinstance(prompt, str):
31
  prompt = [prompt]
32
  if isinstance(negative_prompt, str):
33
  negative_prompt = [negative_prompt]
34
+
35
+ # Если промпты не заданы, используем пустые эмбеддинги
36
+ if prompt is None and negative_prompt is None:
37
+ hidden_dim = 1024 # Размерность эмбеддинга Qwen3-0.6B
38
+ seq_len = 150
39
+ batch_size = 1
40
+ return torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
41
+
42
+ # Токенизация с фиксированным max_length=150 и padding="max_length"
43
+ def encode_texts(texts, max_length=150):
44
+ with torch.no_grad():
45
+ toks = self.tokenizer(
46
+ texts,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  return_tensors="pt",
48
  padding="max_length",
49
  truncation=True,
50
+ max_length=max_length
51
  ).to(device)
52
+ outs = self.text_encoder(**toks, output_hidden_states=True)
53
+ return outs.hidden_states[-1]
54
+
55
+ # Кодируем позитивные и негативные промпты
56
+ pos_embeddings = encode_texts(prompt) if prompt is not None else None
57
+ neg_embeddings = encode_texts(negative_prompt) if negative_prompt is not None else None
58
+
59
+ # Выравниваем размеры batch_size
60
+ batch_size = max(
61
+ pos_embeddings.shape[0] if pos_embeddings is not None else 0,
62
+ neg_embeddings.shape[0] if neg_embeddings is not None else 0
63
+ )
64
+
65
+ # Повторяем эмбеддинги по batch_size
66
+ if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size:
67
+ pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1)
68
+ if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size:
69
+ neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1)
70
+
71
+ # Конкатенируем для guidance
72
+ if pos_embeddings is not None and neg_embeddings is not None:
73
+ text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
74
+ elif pos_embeddings is not None:
75
+ text_embeddings = pos_embeddings
76
+ else:
77
+ text_embeddings = neg_embeddings
78
+
79
+ return text_embeddings.to(device=device, dtype=dtype)
 
 
80
 
 
81
 
82
  @torch.no_grad()
83
  def generate_latents(
 
91
  batch_size: int = 1,
92
  generator=None,
93
  ):
 
94
  device = self.device
95
+ dtype = torch.float16 # Явно указываем float16
96
+
97
+ # Разделяем эмбеддинги на условные и безусловные
98
+ if guidance_scale > 1:
99
+ neg_embeds, pos_embeds = text_embeddings.chunk(2)
100
+ # Повторяем, если batch_size больше
101
+ if batch_size > pos_embeds.shape[0]:
102
+ pos_embeds = pos_embeds.repeat(batch_size, 1, 1)
103
+ neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
104
+ text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
 
 
 
 
 
 
 
 
 
105
  else:
106
+ text_embeddings = text_embeddings.repeat(batch_size, 1, 1)
107
+
 
 
 
 
108
  # Инициализация латентов
109
  latent_shape = (
110
  batch_size,
 
113
  width // self.vae_scale_factor
114
  )
115
  latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
116
+
117
  # Процесс диффузии
118
  for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
119
+ latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents
120
+ noise_pred = self.unet(latent_input, t, text_embeddings).sample
121
+
122
+ if guidance_scale > 1:
123
  noise_uncond, noise_text = noise_pred.chunk(2)
124
  noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
125
+
126
  latents = self.scheduler.step(noise_pred, t, latents).prev_sample
127
+
128
  return latents
129
 
130
+
131
  def decode_latents(self, latents, output_type="pil"):
132
  """Декодирование латентов в изображения."""
133
  latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor