recoilme commited on
Commit
56d73d2
·
1 Parent(s): fbe61c6
pipeline_sdxs-Copy1.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ import torch
3
+ from diffusers.utils import BaseOutput
4
+ from dataclasses import dataclass
5
+ from typing import List, Union, Optional
6
+ from PIL import Image
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+ @dataclass
11
+ class SdxsPipelineOutput(BaseOutput):
12
+ images: Union[List[Image.Image], np.ndarray]
13
+
14
+ class SdxsPipeline(DiffusionPipeline):
15
+ def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None):
16
+ super().__init__()
17
+ self.register_modules(
18
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
19
+ unet=unet, scheduler=scheduler
20
+ )
21
+ self.vae_scale_factor = 8
22
+
23
+ def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
24
+ """Кодирование текстовых промптов в эмбеддинги с выравниванием seq_len."""
25
+ if prompt is None and negative_prompt is None:
26
+ raise ValueError("Требуется хотя бы один из параметров: prompt или negative_prompt")
27
+
28
+ device = device or self.device
29
+ dtype = dtype or next(self.unet.parameters()).dtype
30
+
31
+ # Преобразуем в списки
32
+ if isinstance(prompt, str):
33
+ prompt = [prompt]
34
+ if isinstance(negative_prompt, str):
35
+ negative_prompt = [negative_prompt]
36
+
37
+ # Выравнивание размеров позитивных/негативных списков
38
+ if prompt is not None and negative_prompt is not None:
39
+ if len(prompt) != len(negative_prompt):
40
+ if len(negative_prompt) == 1:
41
+ negative_prompt = negative_prompt * len(prompt)
42
+ elif len(prompt) == 1:
43
+ prompt = prompt * len(negative_prompt)
44
+ else:
45
+ n = min(len(prompt), len(negative_prompt))
46
+ prompt = prompt[:n]
47
+ negative_prompt = negative_prompt[:n]
48
+
49
+ with torch.no_grad():
50
+ # --- Позитивные эмбеддинги ---
51
+ if prompt is not None:
52
+ text_inputs = self.tokenizer(
53
+ prompt,
54
+ return_tensors="pt",
55
+ padding=True, # динамический паддинг
56
+ truncation=True,
57
+ max_length=512
58
+ ).to(device)
59
+ pos_embeddings = self.text_encoder(
60
+ text_inputs.input_ids,
61
+ attention_mask=text_inputs.attention_mask,
62
+ output_hidden_states=True
63
+ ).hidden_states[-1] # [batch, seq_len, dim]
64
+ else:
65
+ pos_embeddings = None
66
+
67
+ # --- Негативные эмбеддинги ---
68
+ if negative_prompt is not None:
69
+ neg_inputs = self.tokenizer(
70
+ negative_prompt,
71
+ return_tensors="pt",
72
+ padding=True,
73
+ truncation=True,
74
+ max_length=512
75
+ ).to(device)
76
+ neg_embeddings = self.text_encoder(
77
+ neg_inputs.input_ids,
78
+ attention_mask=neg_inputs.attention_mask,
79
+ output_hidden_states=True
80
+ ).hidden_states[-1] # [batch, seq_len, dim]
81
+ else:
82
+ neg_embeddings = None
83
+
84
+ # --- Выравниваем seq_len ---
85
+ if pos_embeddings is not None and neg_embeddings is not None:
86
+ max_len = max(pos_embeddings.shape[1], neg_embeddings.shape[1])
87
+ if pos_embeddings.shape[1] < max_len:
88
+ pad = torch.zeros(pos_embeddings.shape[0], max_len - pos_embeddings.shape[1], pos_embeddings.shape[2], device=pos_embeddings.device, dtype=pos_embeddings.dtype)
89
+ pos_embeddings = torch.cat([pos_embeddings, pad], dim=1)
90
+ if neg_embeddings.shape[1] < max_len:
91
+ pad = torch.zeros(neg_embeddings.shape[0], max_len - neg_embeddings.shape[1], neg_embeddings.shape[2], device=neg_embeddings.device, dtype=neg_embeddings.dtype)
92
+ neg_embeddings = torch.cat([neg_embeddings, pad], dim=1)
93
+ text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
94
+ elif pos_embeddings is not None:
95
+ text_embeddings = pos_embeddings
96
+ else:
97
+ text_embeddings = neg_embeddings
98
+
99
+ return text_embeddings.to(device=device, dtype=dtype)
100
+
101
+
102
+ @torch.no_grad()
103
+ def generate_latents(
104
+ self,
105
+ text_embeddings,
106
+ height: int = 640,
107
+ width: int = 640,
108
+ num_inference_steps: int = 50,
109
+ guidance_scale: float = 5.0,
110
+ latent_channels: int = 16,
111
+ batch_size: int = 1,
112
+ generator=None,
113
+ ):
114
+ """Генерация латентов с уч��том любого batch_size и guidance."""
115
+ device = self.device
116
+ dtype = next(self.unet.parameters()).dtype
117
+ do_cfg = guidance_scale > 0
118
+
119
+ # Разделяем эмбеддинги на условные и безусловные для guidance
120
+ if do_cfg:
121
+ neg_embeds, pos_embeds = text_embeddings.chunk(2)
122
+ # Повторяем, если batch_size больше эмбеддингов
123
+ if batch_size > pos_embeds.shape[0]:
124
+ reps = (batch_size + pos_embeds.shape[0] - 1) // pos_embeds.shape[0]
125
+ pos_embeds = pos_embeds.repeat(reps, 1, 1)[:batch_size]
126
+ neg_embeds = neg_embeds.repeat(reps, 1, 1)[:batch_size]
127
+ text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
128
+ else:
129
+ if batch_size > text_embeddings.shape[0]:
130
+ reps = (batch_size + text_embeddings.shape[0] - 1) // text_embeddings.shape[0]
131
+ text_embeddings = text_embeddings.repeat(reps, 1, 1)[:batch_size]
132
+
133
+ # Установка timesteps
134
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
135
+
136
+ # Инициализация латентов
137
+ latent_shape = (
138
+ batch_size,
139
+ latent_channels,
140
+ height // self.vae_scale_factor,
141
+ width // self.vae_scale_factor
142
+ )
143
+ latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
144
+
145
+ # Процесс диффузии
146
+ for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
147
+ latent_input = torch.cat([latents, latents], dim=0) if do_cfg else latents
148
+ noise_pred = self.unet(latent_input, t, text_embeddings).sample
149
+
150
+ if do_cfg:
151
+ noise_uncond, noise_text = noise_pred.chunk(2)
152
+ noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
153
+
154
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
155
+
156
+ return latents
157
+
158
+ def decode_latents(self, latents, output_type="pil"):
159
+ """Декодирование латентов в изображения."""
160
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
161
+ with torch.no_grad():
162
+ images = self.vae.decode(latents).sample
163
+ images = (images / 2 + 0.5).clamp(0, 1)
164
+
165
+ if output_type == "pil":
166
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
167
+ images = (images * 255).round().astype("uint8")
168
+ return [Image.fromarray(image) for image in images]
169
+ return images.cpu().permute(0, 2, 3, 1).float().numpy()
170
+
171
+ @torch.no_grad()
172
+ def __call__(
173
+ self,
174
+ prompt: Optional[Union[str, List[str]]] = None,
175
+ height: int = 640,
176
+ width: int = 512,
177
+ num_inference_steps: int = 40,
178
+ guidance_scale: float = 4.0,
179
+ latent_channels: int = 16,
180
+ output_type: str = "pil",
181
+ return_dict: bool = True,
182
+ batch_size: int = 1,
183
+ seed: Optional[int] = None,
184
+ negative_prompt: Optional[Union[str, List[str]]] = None,
185
+ text_embeddings: Optional[torch.FloatTensor] = None,
186
+ ):
187
+ device = self.device
188
+ generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
189
+
190
+ if text_embeddings is None:
191
+ if prompt is None and negative_prompt is None:
192
+ raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
193
+ text_embeddings = self.encode_prompt(prompt, negative_prompt, device=device)
194
+
195
+ text_embeddings = text_embeddings.to(device)
196
+ latents = self.generate_latents(
197
+ text_embeddings=text_embeddings,
198
+ height=height,
199
+ width=width,
200
+ num_inference_steps=num_inference_steps,
201
+ guidance_scale=guidance_scale,
202
+ latent_channels=latent_channels,
203
+ batch_size=batch_size,
204
+ generator=generator
205
+ )
206
+
207
+ images = self.decode_latents(latents, output_type=output_type)
208
+ if not return_dict:
209
+ return images
210
+ return SdxsPipelineOutput(images=images)
pipeline_sdxs.py CHANGED
@@ -12,29 +12,39 @@ class SdxsPipelineOutput(BaseOutput):
12
  images: Union[List[Image.Image], np.ndarray]
13
 
14
  class SdxsPipeline(DiffusionPipeline):
15
- def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None):
16
  super().__init__()
17
  self.register_modules(
18
  vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
19
  unet=unet, scheduler=scheduler
20
  )
 
21
  self.vae_scale_factor = 8
 
22
 
23
  def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
24
- """Кодирование текстовых промптов в эмбеддинги с выравниванием seq_len."""
 
 
 
 
 
 
 
25
  if prompt is None and negative_prompt is None:
26
  raise ValueError("Требуется хотя бы один из параметров: prompt или negative_prompt")
27
-
28
  device = device or self.device
 
29
  dtype = dtype or next(self.unet.parameters()).dtype
30
-
31
- # Преобразуем в списки
32
  if isinstance(prompt, str):
33
  prompt = [prompt]
34
  if isinstance(negative_prompt, str):
35
  negative_prompt = [negative_prompt]
36
-
37
- # Выравнивание размеров позитивных/негативных списков
38
  if prompt is not None and negative_prompt is not None:
39
  if len(prompt) != len(negative_prompt):
40
  if len(negative_prompt) == 1:
@@ -45,59 +55,67 @@ class SdxsPipeline(DiffusionPipeline):
45
  n = min(len(prompt), len(negative_prompt))
46
  prompt = prompt[:n]
47
  negative_prompt = negative_prompt[:n]
48
-
49
  with torch.no_grad():
50
  # --- Позитивные эмбеддинги ---
51
  if prompt is not None:
52
- text_inputs = self.tokenizer(
53
  prompt,
54
  return_tensors="pt",
55
- padding=True, # динамический паддинг
56
  truncation=True,
57
- max_length=512
58
  ).to(device)
59
- pos_embeddings = self.text_encoder(
60
- text_inputs.input_ids,
61
- attention_mask=text_inputs.attention_mask,
62
  output_hidden_states=True
63
- ).hidden_states[-1] # [batch, seq_len, dim]
 
64
  else:
65
  pos_embeddings = None
66
-
67
  # --- Негативные эмбеддинги ---
68
  if negative_prompt is not None:
69
  neg_inputs = self.tokenizer(
70
  negative_prompt,
71
  return_tensors="pt",
72
- padding=True,
73
  truncation=True,
74
- max_length=512
75
  ).to(device)
76
- neg_embeddings = self.text_encoder(
77
  neg_inputs.input_ids,
78
  attention_mask=neg_inputs.attention_mask,
79
  output_hidden_states=True
80
- ).hidden_states[-1] # [batch, seq_len, dim]
 
81
  else:
82
  neg_embeddings = None
83
-
84
- # --- Выравниваем seq_len ---
85
- if pos_embeddings is not None and neg_embeddings is not None:
86
- max_len = max(pos_embeddings.shape[1], neg_embeddings.shape[1])
87
- if pos_embeddings.shape[1] < max_len:
88
- pad = torch.zeros(pos_embeddings.shape[0], max_len - pos_embeddings.shape[1], pos_embeddings.shape[2], device=pos_embeddings.device, dtype=pos_embeddings.dtype)
89
- pos_embeddings = torch.cat([pos_embeddings, pad], dim=1)
90
- if neg_embeddings.shape[1] < max_len:
91
- pad = torch.zeros(neg_embeddings.shape[0], max_len - neg_embeddings.shape[1], neg_embeddings.shape[2], device=neg_embeddings.device, dtype=neg_embeddings.dtype)
92
- neg_embeddings = torch.cat([neg_embeddings, pad], dim=1)
93
- text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
94
- elif pos_embeddings is not None:
95
- text_embeddings = pos_embeddings
96
- else:
97
- text_embeddings = neg_embeddings
98
-
99
- return text_embeddings.to(device=device, dtype=dtype)
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  @torch.no_grad()
103
  def generate_latents(
@@ -111,24 +129,30 @@ class SdxsPipeline(DiffusionPipeline):
111
  batch_size: int = 1,
112
  generator=None,
113
  ):
114
- """Генерация латентов с учетом любого batch_size и guidance."""
115
  device = self.device
116
  dtype = next(self.unet.parameters()).dtype
117
- do_cfg = guidance_scale > 0
118
-
119
- # Разделяем эмбеддинги на условные и безусловные для guidance
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  if do_cfg:
121
- neg_embeds, pos_embeds = text_embeddings.chunk(2)
122
- # Повторяем, если batch_size больше эмбеддингов
123
- if batch_size > pos_embeds.shape[0]:
124
- reps = (batch_size + pos_embeds.shape[0] - 1) // pos_embeds.shape[0]
125
- pos_embeds = pos_embeds.repeat(reps, 1, 1)[:batch_size]
126
- neg_embeds = neg_embeds.repeat(reps, 1, 1)[:batch_size]
127
- text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
128
  else:
129
- if batch_size > text_embeddings.shape[0]:
130
- reps = (batch_size + text_embeddings.shape[0] - 1) // text_embeddings.shape[0]
131
- text_embeddings = text_embeddings.repeat(reps, 1, 1)[:batch_size]
132
 
133
  # Установка timesteps
134
  self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -145,7 +169,7 @@ class SdxsPipeline(DiffusionPipeline):
145
  # Процесс диффузии
146
  for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
147
  latent_input = torch.cat([latents, latents], dim=0) if do_cfg else latents
148
- noise_pred = self.unet(latent_input, t, text_embeddings).sample
149
 
150
  if do_cfg:
151
  noise_uncond, noise_text = noise_pred.chunk(2)
@@ -190,9 +214,9 @@ class SdxsPipeline(DiffusionPipeline):
190
  if text_embeddings is None:
191
  if prompt is None and negative_prompt is None:
192
  raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
193
- text_embeddings = self.encode_prompt(prompt, negative_prompt, device=device)
194
 
195
- text_embeddings = text_embeddings.to(device)
196
  latents = self.generate_latents(
197
  text_embeddings=text_embeddings,
198
  height=height,
@@ -207,4 +231,4 @@ class SdxsPipeline(DiffusionPipeline):
207
  images = self.decode_latents(latents, output_type=output_type)
208
  if not return_dict:
209
  return images
210
- return SdxsPipelineOutput(images=images)
 
12
  images: Union[List[Image.Image], np.ndarray]
13
 
14
  class SdxsPipeline(DiffusionPipeline):
15
+ def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None, max_length: int = 150):
16
  super().__init__()
17
  self.register_modules(
18
  vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
19
  unet=unet, scheduler=scheduler
20
  )
21
+ # совпадает с тем, что вы используете при ручном инференсе
22
  self.vae_scale_factor = 8
23
+ self.max_length = max_length
24
 
25
  def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
26
+ """
27
+ Кодирование промптов в эмбеддинги.
28
+ Поведение приближено к ручному инференсу:
29
+ - padding="max_length", truncation=True, max_length=self.max_length
30
+ - если negative_prompt отсутствует, возвращаем нулевой uncond с нужной формой
31
+ - возврат: tensor [batch_uncond + batch_cond, seq_len, hidden_dim]
32
+ где сначала идут uncond, потом cond (чтобы совпадать с concat для guidance)
33
+ """
34
  if prompt is None and negative_prompt is None:
35
  raise ValueError("Требуется хотя бы один из параметров: prompt или negative_prompt")
36
+
37
  device = device or self.device
38
+ # приводим к dtype unet (важно для совместимости)
39
  dtype = dtype or next(self.unet.parameters()).dtype
40
+
41
+ # нормализуем входы в списки
42
  if isinstance(prompt, str):
43
  prompt = [prompt]
44
  if isinstance(negative_prompt, str):
45
  negative_prompt = [negative_prompt]
46
+
47
+ # equalize list lengths: если один из них длины 1, расширяем — как в вашем ручном коде
48
  if prompt is not None and negative_prompt is not None:
49
  if len(prompt) != len(negative_prompt):
50
  if len(negative_prompt) == 1:
 
55
  n = min(len(prompt), len(negative_prompt))
56
  prompt = prompt[:n]
57
  negative_prompt = negative_prompt[:n]
58
+
59
  with torch.no_grad():
60
  # --- Позитивные эмбеддинги ---
61
  if prompt is not None:
62
+ pos_inputs = self.tokenizer(
63
  prompt,
64
  return_tensors="pt",
65
+ padding="max_length", # фиксируем длину
66
  truncation=True,
67
+ max_length=self.max_length
68
  ).to(device)
69
+ pos_out = self.text_encoder(
70
+ pos_inputs.input_ids,
71
+ attention_mask=pos_inputs.attention_mask,
72
  output_hidden_states=True
73
+ )
74
+ pos_embeddings = pos_out.hidden_states[-1] # [B, seq_len, dim]
75
  else:
76
  pos_embeddings = None
77
+
78
  # --- Негативные эмбеддинги ---
79
  if negative_prompt is not None:
80
  neg_inputs = self.tokenizer(
81
  negative_prompt,
82
  return_tensors="pt",
83
+ padding="max_length",
84
  truncation=True,
85
+ max_length=self.max_length
86
  ).to(device)
87
+ neg_out = self.text_encoder(
88
  neg_inputs.input_ids,
89
  attention_mask=neg_inputs.attention_mask,
90
  output_hidden_states=True
91
+ )
92
+ neg_embeddings = neg_out.hidden_states[-1] # [B, seq_len, dim]
93
  else:
94
  neg_embeddings = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # Если отсутствует neg_embeddings, создаём нулевой uncond эмбеддинг
97
+ if neg_embeddings is None and pos_embeddings is not None:
98
+ b = pos_embeddings.shape[0]
99
+ seq_len = pos_embeddings.shape[1]
100
+ hid = pos_embeddings.shape[2]
101
+ neg_embeddings = torch.zeros((b, seq_len, hid), device=pos_embeddings.device, dtype=pos_embeddings.dtype)
102
+
103
+ # Если отсутствует pos_embeddings (маловероятно), создаём нулевой cond
104
+ if pos_embeddings is None and neg_embeddings is not None:
105
+ b = neg_embeddings.shape[0]
106
+ seq_len = neg_embeddings.shape[1]
107
+ hid = neg_embeddings.shape[2]
108
+ pos_embeddings = torch.zeros((b, seq_len, hid), device=neg_embeddings.device, dtype=neg_embeddings.dtype)
109
+
110
+ # Приводим dtype к нужному (например float16), чтобы совпадало с unet
111
+ pos_embeddings = pos_embeddings.to(dtype=dtype, device=device)
112
+ neg_embeddings = neg_embeddings.to(dtype=dtype, device=device)
113
+
114
+ # Теперь формируем итоговый тензор: сначала uncond, затем cond
115
+ # -- если батч >1 и один из них длиной 1, расширим до нужного размера в __call__ / generate_latents
116
+ text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0) # -> [B_uncond + B_cond, seq_len, hid]
117
+
118
+ return text_embeddings # уже на device и dtype правильные
119
 
120
  @torch.no_grad()
121
  def generate_latents(
 
129
  batch_size: int = 1,
130
  generator=None,
131
  ):
132
+ """Генерация латентов. Поведение guidance согласовано с encode_prompt (uncond перед cond)."""
133
  device = self.device
134
  dtype = next(self.unet.parameters()).dtype
135
+ do_cfg = guidance_scale > 1e-5 # true если используется guidance
136
+
137
+ # text_embeddings: [B_uncond + B_cond, seq_len, hid]
138
+ # ожидаем, что B_uncond == B_cond == base_batch (или оба равны 1)
139
+ # разделим пополам по батчу: сначала uncond, затем cond
140
+ half = text_embeddings.shape[0] // 2
141
+ neg_embeds = text_embeddings[:half] # uncond
142
+ pos_embeds = text_embeddings[half:] # cond
143
+
144
+ # повторяем эмбеддинги, если нужно увеличить batch_size
145
+ if batch_size > pos_embeds.shape[0]:
146
+ reps = (batch_size + pos_embeds.shape[0] - 1) // pos_embeds.shape[0]
147
+ pos_embeds = pos_embeds.repeat(reps, 1, 1)[:batch_size]
148
+ neg_embeds = neg_embeds.repeat(reps, 1, 1)[:batch_size]
149
+
150
+ # для guidance мы собираем [neg, pos] по батчам (concatenate)
151
  if do_cfg:
152
+ text_embeddings_for_unet = torch.cat([neg_embeds, pos_embeds], dim=0).to(device=device, dtype=dtype)
 
 
 
 
 
 
153
  else:
154
+ # если без guidance, просто используем pos
155
+ text_embeddings_for_unet = pos_embeds.to(device=device, dtype=dtype)
 
156
 
157
  # Установка timesteps
158
  self.scheduler.set_timesteps(num_inference_steps, device=device)
 
169
  # Процесс диффузии
170
  for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
171
  latent_input = torch.cat([latents, latents], dim=0) if do_cfg else latents
172
+ noise_pred = self.unet(latent_input, t, encoder_hidden_states=text_embeddings_for_unet).sample
173
 
174
  if do_cfg:
175
  noise_uncond, noise_text = noise_pred.chunk(2)
 
214
  if text_embeddings is None:
215
  if prompt is None and negative_prompt is None:
216
  raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
217
+ text_embeddings = self.encode_prompt(prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype)
218
 
219
+ # text_embeddings уже имеет структуру [B_uncond + B_cond, seq_len, hid], dtype и device совместимы
220
  latents = self.generate_latents(
221
  text_embeddings=text_embeddings,
222
  height=height,
 
231
  images = self.decode_latents(latents, output_type=output_type)
232
  if not return_dict:
233
  return images
234
+ return SdxsPipelineOutput(images=images)
samples/unet_320x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 01324bf55bfc0cd6404e3b0140eebad1be89309ac60c50b67495bcdad2956731
  • Pointer size: 130 Bytes
  • Size of remote file: 75.4 kB

Git LFS Details

  • SHA256: 659dae574bae66743e6160959404ebbe33d155a87159021233f04846b1f38f89
  • Pointer size: 130 Bytes
  • Size of remote file: 75 kB
samples/unet_384x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 0fd34542f7a2a601b3a3c4f40125f6735d730b409b73984c541f3bc2c7d66eb5
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB

Git LFS Details

  • SHA256: fcd75a85aa29103f4c3d9c346eb9ae3e51fe0be77e9435b3dc18f42aa899848c
  • Pointer size: 131 Bytes
  • Size of remote file: 170 kB
samples/unet_448x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 205cc5b9a8cdfc0217785062d2badf21dcc13bb54bf70eb475719a6c2ebf4cb8
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB

Git LFS Details

  • SHA256: 304f4496e8e22c7123e7db7217763fc6b52577d919aba5f0b9cbc0d6c0210c9a
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB
samples/unet_512x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 02dd505d2336f9ead1039f86695f67c4d6d5ebd634b63a0c7f934d24365e530e
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB

Git LFS Details

  • SHA256: de0e3f38f0e44c7315095286c96b61dbeb0de5e68da18dbba0062ca2d9db25fc
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB
samples/unet_576x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 1d2c5a283d1b324f027614f4e54e49eccd388c98f37074c9ea61f25a0b0e724f
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB

Git LFS Details

  • SHA256: 99a9d649e07cd7fcc0ee48f53b2a9dc70dafed05a4b28eaaccbf822be76897a7
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
samples/unet_640x320_0.jpg CHANGED

Git LFS Details

  • SHA256: c50ada225a9813fa594aa98d1c4143233c8ee95b085d31ede08554f6fb714489
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB

Git LFS Details

  • SHA256: 04caa48f9b8e3f2d3744e85826fbe3ec43b2fcf3916a29687bde801a26b5cf2f
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
samples/unet_640x384_0.jpg CHANGED

Git LFS Details

  • SHA256: d7bc33e02acd82607e00e9ad78e7255f9d93297c002b472fb0c60c7d83befca1
  • Pointer size: 130 Bytes
  • Size of remote file: 78.4 kB

Git LFS Details

  • SHA256: fec152a73f1eaf2807f66f40b549faed7e8a3437343a94ffa95e7ec3f91fd897
  • Pointer size: 130 Bytes
  • Size of remote file: 82.4 kB
samples/unet_640x448_0.jpg CHANGED

Git LFS Details

  • SHA256: b0a630a2cfbd5a100fa9b2cbba3247e3695401167b7ee6593fe3261b608bdd52
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB

Git LFS Details

  • SHA256: 8172fc4f29496cd4a71a3b979f8db7a0111b62218ae41c9aff2f830e40ff1f83
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB
samples/unet_640x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 6dc3b2962f9a6b5a8f916fe3f377ee06ff479edfd8c55f2bca701fd7e7df6d3c
  • Pointer size: 131 Bytes
  • Size of remote file: 168 kB

Git LFS Details

  • SHA256: 5f93edcb50e081dd22873f6737c5e02b6d6aad0d84584295aac388c622194841
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
samples/unet_640x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 498f5650f01b88c5fa750e22581eee3cd174a20eb4da4b7cb7ba17edff0ba368
  • Pointer size: 131 Bytes
  • Size of remote file: 100 kB

Git LFS Details

  • SHA256: c377e615547a8cb1c3d27b97ec1c1058cb7a0ff912d7fef2e5c79aedb052096c
  • Pointer size: 131 Bytes
  • Size of remote file: 237 kB
samples/unet_640x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 4af2610e2c059601338a6dcf672bac6684cde956d5636ed75c5a6d5d5794e242
  • Pointer size: 131 Bytes
  • Size of remote file: 200 kB

Git LFS Details

  • SHA256: 5be162f25c0f78a4964ba0fdd96b47b8af20c57f7e807931b5a3dbcf8308b2b6
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7754efea243246c21b73b743ea55055cff9ed385f22d119ee489931185366cf1
3
- size 8316949
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0f7ceb281d9d78b8ed0085e763df363b106df049ee6830bc40d84e6a1c25b34
3
+ size 8326857
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3debe901242f713f340604da75758d75924154f8a87d65221eee85a2bcef6f8c
3
  size 6184944280
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f51c65967bb570338af3731ea474bbf1d182549ccd33c6136b531a5e383c57e7
3
  size 6184944280