recoilme committed on
Commit
4961b95
·
1 Parent(s): 4d26529
pipeline_sdxs-Copy1.py DELETED
@@ -1,210 +0,0 @@
1
- from diffusers import DiffusionPipeline
2
- import torch
3
- from diffusers.utils import BaseOutput
4
- from dataclasses import dataclass
5
- from typing import List, Union, Optional
6
- from PIL import Image
7
- import numpy as np
8
- from tqdm import tqdm
9
-
10
- @dataclass
11
- class SdxsPipelineOutput(BaseOutput):
12
- images: Union[List[Image.Image], np.ndarray]
13
-
14
- class SdxsPipeline(DiffusionPipeline):
15
- def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None):
16
- super().__init__()
17
- self.register_modules(
18
- vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
19
- unet=unet, scheduler=scheduler
20
- )
21
- self.vae_scale_factor = 8
22
-
23
- def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
24
- """Кодирование текстовых промптов в эмбеддинги с выравниванием seq_len."""
25
- if prompt is None and negative_prompt is None:
26
- raise ValueError("Требуется хотя бы один из параметров: prompt или negative_prompt")
27
-
28
- device = device or self.device
29
- dtype = dtype or next(self.unet.parameters()).dtype
30
-
31
- # Преобразуем в списки
32
- if isinstance(prompt, str):
33
- prompt = [prompt]
34
- if isinstance(negative_prompt, str):
35
- negative_prompt = [negative_prompt]
36
-
37
- # Выравнивание размеров позитивных/негативных списков
38
- if prompt is not None and negative_prompt is not None:
39
- if len(prompt) != len(negative_prompt):
40
- if len(negative_prompt) == 1:
41
- negative_prompt = negative_prompt * len(prompt)
42
- elif len(prompt) == 1:
43
- prompt = prompt * len(negative_prompt)
44
- else:
45
- n = min(len(prompt), len(negative_prompt))
46
- prompt = prompt[:n]
47
- negative_prompt = negative_prompt[:n]
48
-
49
- with torch.no_grad():
50
- # --- Позитивные эмбеддинги ---
51
- if prompt is not None:
52
- text_inputs = self.tokenizer(
53
- prompt,
54
- return_tensors="pt",
55
- padding=True, # динамический паддинг
56
- truncation=True,
57
- max_length=512
58
- ).to(device)
59
- pos_embeddings = self.text_encoder(
60
- text_inputs.input_ids,
61
- attention_mask=text_inputs.attention_mask,
62
- output_hidden_states=True
63
- ).hidden_states[-1] # [batch, seq_len, dim]
64
- else:
65
- pos_embeddings = None
66
-
67
- # --- Негативные эмбеддинги ---
68
- if negative_prompt is not None:
69
- neg_inputs = self.tokenizer(
70
- negative_prompt,
71
- return_tensors="pt",
72
- padding=True,
73
- truncation=True,
74
- max_length=512
75
- ).to(device)
76
- neg_embeddings = self.text_encoder(
77
- neg_inputs.input_ids,
78
- attention_mask=neg_inputs.attention_mask,
79
- output_hidden_states=True
80
- ).hidden_states[-1] # [batch, seq_len, dim]
81
- else:
82
- neg_embeddings = None
83
-
84
- # --- Выравниваем seq_len ---
85
- if pos_embeddings is not None and neg_embeddings is not None:
86
- max_len = max(pos_embeddings.shape[1], neg_embeddings.shape[1])
87
- if pos_embeddings.shape[1] < max_len:
88
- pad = torch.zeros(pos_embeddings.shape[0], max_len - pos_embeddings.shape[1], pos_embeddings.shape[2], device=pos_embeddings.device, dtype=pos_embeddings.dtype)
89
- pos_embeddings = torch.cat([pos_embeddings, pad], dim=1)
90
- if neg_embeddings.shape[1] < max_len:
91
- pad = torch.zeros(neg_embeddings.shape[0], max_len - neg_embeddings.shape[1], neg_embeddings.shape[2], device=neg_embeddings.device, dtype=neg_embeddings.dtype)
92
- neg_embeddings = torch.cat([neg_embeddings, pad], dim=1)
93
- text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
94
- elif pos_embeddings is not None:
95
- text_embeddings = pos_embeddings
96
- else:
97
- text_embeddings = neg_embeddings
98
-
99
- return text_embeddings.to(device=device, dtype=dtype)
100
-
101
-
102
- @torch.no_grad()
103
- def generate_latents(
104
- self,
105
- text_embeddings,
106
- height: int = 640,
107
- width: int = 640,
108
- num_inference_steps: int = 50,
109
- guidance_scale: float = 5.0,
110
- latent_channels: int = 16,
111
- batch_size: int = 1,
112
- generator=None,
113
- ):
114
- """Генерация латентов с учётом любого batch_size и guidance."""
115
- device = self.device
116
- dtype = next(self.unet.parameters()).dtype
117
- do_cfg = guidance_scale > 0
118
-
119
- # Разделяем эмбеддинги на условные и безусловные для guidance
120
- if do_cfg:
121
- neg_embeds, pos_embeds = text_embeddings.chunk(2)
122
- # Повторяем, если batch_size больше эмбеддингов
123
- if batch_size > pos_embeds.shape[0]:
124
- reps = (batch_size + pos_embeds.shape[0] - 1) // pos_embeds.shape[0]
125
- pos_embeds = pos_embeds.repeat(reps, 1, 1)[:batch_size]
126
- neg_embeds = neg_embeds.repeat(reps, 1, 1)[:batch_size]
127
- text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
128
- else:
129
- if batch_size > text_embeddings.shape[0]:
130
- reps = (batch_size + text_embeddings.shape[0] - 1) // text_embeddings.shape[0]
131
- text_embeddings = text_embeddings.repeat(reps, 1, 1)[:batch_size]
132
-
133
- # Установка timesteps
134
- self.scheduler.set_timesteps(num_inference_steps, device=device)
135
-
136
- # Инициализация латентов
137
- latent_shape = (
138
- batch_size,
139
- latent_channels,
140
- height // self.vae_scale_factor,
141
- width // self.vae_scale_factor
142
- )
143
- latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
144
-
145
- # Процесс диффузии
146
- for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
147
- latent_input = torch.cat([latents, latents], dim=0) if do_cfg else latents
148
- noise_pred = self.unet(latent_input, t, text_embeddings).sample
149
-
150
- if do_cfg:
151
- noise_uncond, noise_text = noise_pred.chunk(2)
152
- noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
153
-
154
- latents = self.scheduler.step(noise_pred, t, latents).prev_sample
155
-
156
- return latents
157
-
158
- def decode_latents(self, latents, output_type="pil"):
159
- """Декодирование латентов в изображения."""
160
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
161
- with torch.no_grad():
162
- images = self.vae.decode(latents).sample
163
- images = (images / 2 + 0.5).clamp(0, 1)
164
-
165
- if output_type == "pil":
166
- images = images.cpu().permute(0, 2, 3, 1).float().numpy()
167
- images = (images * 255).round().astype("uint8")
168
- return [Image.fromarray(image) for image in images]
169
- return images.cpu().permute(0, 2, 3, 1).float().numpy()
170
-
171
- @torch.no_grad()
172
- def __call__(
173
- self,
174
- prompt: Optional[Union[str, List[str]]] = None,
175
- height: int = 640,
176
- width: int = 512,
177
- num_inference_steps: int = 40,
178
- guidance_scale: float = 4.0,
179
- latent_channels: int = 16,
180
- output_type: str = "pil",
181
- return_dict: bool = True,
182
- batch_size: int = 1,
183
- seed: Optional[int] = None,
184
- negative_prompt: Optional[Union[str, List[str]]] = None,
185
- text_embeddings: Optional[torch.FloatTensor] = None,
186
- ):
187
- device = self.device
188
- generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
189
-
190
- if text_embeddings is None:
191
- if prompt is None and negative_prompt is None:
192
- raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
193
- text_embeddings = self.encode_prompt(prompt, negative_prompt, device=device)
194
-
195
- text_embeddings = text_embeddings.to(device)
196
- latents = self.generate_latents(
197
- text_embeddings=text_embeddings,
198
- height=height,
199
- width=width,
200
- num_inference_steps=num_inference_steps,
201
- guidance_scale=guidance_scale,
202
- latent_channels=latent_channels,
203
- batch_size=batch_size,
204
- generator=generator
205
- )
206
-
207
- images = self.decode_latents(latents, output_type=output_type)
208
- if not return_dict:
209
- return images
210
- return SdxsPipelineOutput(images=images)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
samples/unet_320x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 659dae574bae66743e6160959404ebbe33d155a87159021233f04846b1f38f89
  • Pointer size: 130 Bytes
  • Size of remote file: 75 kB

Git LFS Details

  • SHA256: eeadb86468f7914072fdddc17c50dfbfca87de678af713ffc84a2a808f4e6f1a
  • Pointer size: 130 Bytes
  • Size of remote file: 73.4 kB
samples/unet_384x640_0.jpg CHANGED

Git LFS Details

  • SHA256: fcd75a85aa29103f4c3d9c346eb9ae3e51fe0be77e9435b3dc18f42aa899848c
  • Pointer size: 131 Bytes
  • Size of remote file: 170 kB

Git LFS Details

  • SHA256: 6fd3dfda1e606fec1288d7c25ff4907bf6e3a69a8982dd5ac2c08f94bf562322
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
samples/unet_448x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 304f4496e8e22c7123e7db7217763fc6b52577d919aba5f0b9cbc0d6c0210c9a
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB

Git LFS Details

  • SHA256: 367c70edad4771937b88785d875b15ae81320a5e4a69eeee7f9e078a28694c13
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
samples/unet_512x640_0.jpg CHANGED

Git LFS Details

  • SHA256: de0e3f38f0e44c7315095286c96b61dbeb0de5e68da18dbba0062ca2d9db25fc
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB

Git LFS Details

  • SHA256: 1ba222448969382144515f16388f1e32f373aa44a6fcae83fce27a9c59903ff0
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
samples/unet_576x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 99a9d649e07cd7fcc0ee48f53b2a9dc70dafed05a4b28eaaccbf822be76897a7
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB

Git LFS Details

  • SHA256: 37a288d1b3baebfee8f3966d86fb2e0e08121d83272b37b70735cfeec90cc876
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
samples/unet_640x320_0.jpg CHANGED

Git LFS Details

  • SHA256: 04caa48f9b8e3f2d3744e85826fbe3ec43b2fcf3916a29687bde801a26b5cf2f
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB

Git LFS Details

  • SHA256: af35b0c83345f526e31e0362f6b70c52bd826ad3b4c47f595628a955164061eb
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
samples/unet_640x384_0.jpg CHANGED

Git LFS Details

  • SHA256: fec152a73f1eaf2807f66f40b549faed7e8a3437343a94ffa95e7ec3f91fd897
  • Pointer size: 130 Bytes
  • Size of remote file: 82.4 kB

Git LFS Details

  • SHA256: 1a715b290e41d573f15960686629e0afbd80fcdd44ce9d37971c31e25a488dc0
  • Pointer size: 130 Bytes
  • Size of remote file: 74.6 kB
samples/unet_640x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 8172fc4f29496cd4a71a3b979f8db7a0111b62218ae41c9aff2f830e40ff1f83
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB

Git LFS Details

  • SHA256: 0212c09329aaf2076d300ac283fff8c6a00516d98ccca765ee87c0065920f1e6
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
samples/unet_640x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 5f93edcb50e081dd22873f6737c5e02b6d6aad0d84584295aac388c622194841
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB

Git LFS Details

  • SHA256: 6c01319ad5d34900b611b7d832534b03650fee424a777c5444694475d9668f3a
  • Pointer size: 131 Bytes
  • Size of remote file: 121 kB
samples/unet_640x576_0.jpg CHANGED

Git LFS Details

  • SHA256: c377e615547a8cb1c3d27b97ec1c1058cb7a0ff912d7fef2e5c79aedb052096c
  • Pointer size: 131 Bytes
  • Size of remote file: 237 kB

Git LFS Details

  • SHA256: 8905a8c78d65192b578df1dfe45701b652285503a0b07cb921b6547b4a10840f
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
samples/unet_640x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 5be162f25c0f78a4964ba0fdd96b47b8af20c57f7e807931b5a3dbcf8308b2b6
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB

Git LFS Details

  • SHA256: a48badb72d2c22cd76dc413a62d95dd0b04bfacf9fa6035c4eaed1acf7e23ae9
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f51c65967bb570338af3731ea474bbf1d182549ccd33c6136b531a5e383c57e7
3
  size 6184944280
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:589662b8b18471fa1f1868b0cba4edadfe05325784b987eab64fb5915c1546d6
3
  size 6184944280