1te
Browse files- girl.jpg +2 -2
- media/result_grid.jpg +2 -2
- model_index.json +2 -2
- pipeline_sdxs-Copy1.py +372 -0
- pipeline_sdxs.py +8 -99
- test.ipynb +2 -2
- text_encoder/config.json +2 -2
- {text_encoder2 → text_encoder}/generation_config.json +0 -0
- text_encoder/model.safetensors +2 -2
- text_encoder2/config.json +0 -3
- text_encoder2/model.safetensors +0 -3
- tmp/config.json +0 -3
- tmp/diffusion_pytorch_model.safetensors +0 -3
- {tokenizer2 → tokenizer}/chat_template.jinja +0 -0
- tokenizer/merges.txt +0 -0
- {tokenizer2 → tokenizer}/preprocessor_config.json +0 -0
- {tokenizer2 → tokenizer}/processor_config.json +0 -0
- tokenizer/special_tokens_map.json +0 -3
- {tokenizer2 → tokenizer}/tokenizer.json +0 -0
- tokenizer/tokenizer_config.json +2 -2
- tokenizer/vocab.json +2 -2
- tokenizer2/merges.txt +0 -0
- tokenizer2/tokenizer_config.json +0 -3
- tokenizer2/vocab.json +0 -3
- train.py +5 -7
- unet/diffusion_pytorch_model.safetensors +1 -1
- unet_old/config.json +0 -3
- unet_old/diffusion_pytorch_model.safetensors +0 -3
girl.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
media/result_grid.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
model_index.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1b6ce208190165975458f502eb3e7ad6d4a5dd54a507f8dd727e636e363ae93
|
| 3 |
+
size 428
|
pipeline_sdxs-Copy1.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from typing import List, Union, Optional, Tuple
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from diffusers import DiffusionPipeline
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from transformers import Qwen3ForCausalLM, Qwen2Tokenizer
|
| 11 |
+
|
| 12 |
+
@dataclass
class SdxsPipelineOutput(BaseOutput):
    """Output container returned by ``SdxsPipeline.__call__``.

    Attributes:
        images: The generated images, either a list of PIL images or a
            numpy array (depends on the requested ``output_type``).
        prompt: The prompt(s) actually used for generation. When prompt
            refinement is enabled the refined text is echoed back here so
            callers can inspect it; ``None`` when not supplied.
    """

    images: Union[List[Image.Image], np.ndarray]
    prompt: Optional[Union[str, List[str]]] = None
| 16 |
+
|
| 17 |
+
class SdxsPipeline(DiffusionPipeline):
|
| 18 |
+
def __init__(self, vae, text_encoder, text_encoder2, tokenizer, tokenizer2, unet, scheduler):
    """Register all sub-models and pre-compute VAE scaling constants.

    Args:
        vae: Autoencoder used to move between pixel and latent space.
        text_encoder / tokenizer: primary text encoder pair.
        text_encoder2 / tokenizer2: secondary (chat-template) encoder pair.
        unet: Denoising network.
        scheduler: Diffusion noise scheduler.
    """
    super().__init__()
    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder2=text_encoder2,
        tokenizer=tokenizer,
        tokenizer2=tokenizer2,
        unet=unet,
        scheduler=scheduler,
    )
    # Spatial downsampling factor of the VAE (2^(num_blocks - 1)).
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    # Optional per-channel latent normalization constants. Only set when the
    # VAE config declares both mean and std; reshaped to [1, C, 1, 1] for
    # broadcasting over [B, C, H, W] latents.
    latents_mean = getattr(self.vae.config, "latents_mean", None)
    latents_std = getattr(self.vae.config, "latents_std", None)
    if latents_mean is not None and latents_std is not None:
        self.vae_latents_std = torch.tensor(
            latents_std, device=self.unet.device, dtype=self.unet.dtype
        ).view(1, len(latents_std), 1, 1)
        self.vae_latents_mean = torch.tensor(
            latents_mean, device=self.unet.device, dtype=self.unet.dtype
        ).view(1, len(latents_mean), 1, 1)
| 35 |
+
|
| 36 |
+
def preprocess_image(self, image: Image.Image, width: int, height: int):
    """Center-crop, resize and normalize an image for the VAE encoder.

    The target size is ``(width, height)`` rounded down to a multiple of
    ``self.vae_scale_factor``. The image is center-cropped to the target
    aspect ratio, resized with Lanczos, and returned as a float tensor of
    shape [1, C, H, W] scaled to [-1, 1].
    """
    # Snap the requested size down to the VAE's stride.
    tgt_h = (height // self.vae_scale_factor) * self.vae_scale_factor
    tgt_w = (width // self.vae_scale_factor) * self.vae_scale_factor

    src_w, src_h = image.size
    target_ratio = tgt_w / tgt_h

    # Center-crop the dimension that overshoots the target aspect ratio.
    if src_w / src_h > target_ratio:
        crop_w = int(src_h * target_ratio)
        x0 = (src_w - crop_w) // 2
        image = image.crop((x0, 0, x0 + crop_w, src_h))
    else:
        crop_h = int(src_w / target_ratio)
        y0 = (src_h - crop_h) // 2
        image = image.crop((0, y0, src_w, y0 + crop_h))

    image = image.resize((tgt_w, tgt_h), resample=Image.LANCZOS)

    # HWC uint8 -> [1, C, H, W] float32 in [0, 1], then rescale to [-1, 1].
    arr = np.asarray(image, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(arr[None].transpose(0, 3, 1, 2))
    return tensor * 2.0 - 1.0
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def encode_prompt(self, prompt, negative_prompt, device, dtype):
    """Encode positive and negative prompts for classifier-free guidance.

    Both prompts are wrapped in the tokenizer2 chat template, tokenized and
    run through ``text_encoder2``; the second-to-last hidden-state layer is
    used as the sequence embedding and the hidden state at the last
    non-padding token as the pooled embedding.

    Returns:
        (text_embeddings, attention_mask, pooled_embeds) where each tensor
        is the negative batch concatenated before the positive batch along
        dim 0 (CFG ordering: [uncond, cond]).
    """
    # NOTE: the dead ``get_single_encode`` / ``get_pooled_encode`` helpers
    # (their only call sites were commented out) have been removed; the
    # active encoding path below is unchanged.

    def _encode(texts):
        # Normalize input to a list of strings; None -> [""].
        if texts is None:
            texts = ""
        if isinstance(texts, str):
            texts = [texts]

        with torch.no_grad():
            # 1. Wrap each prompt in the chat template expected by encoder2.
            formatted_prompts = []
            for t in texts:
                messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
                res_text = self.tokenizer2.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=False
                )
                formatted_prompts.append(res_text)

            # 2. Tokenize with padding/truncation in one pass.
            # NOTE(review): max_length comes from text_encoder's config even
            # though tokenizer2/text_encoder2 are used — confirm intentional.
            toks = self.tokenizer2(
                formatted_prompts,
                padding="max_length",
                max_length=self.text_encoder.config.max_position_embeddings,
                truncation=True,
                return_tensors="pt"
            ).to(device)

            # 3. Forward pass through the secondary text encoder.
            outputs = self.text_encoder2(
                input_ids=toks.input_ids,
                attention_mask=toks.attention_mask,
                output_hidden_states=True
            )

            # Penultimate layer as sequence embedding; pooled vector is the
            # hidden state at the last real (non-padding) token.
            last_hidden = outputs.hidden_states[-2]
            seq_len = toks.attention_mask.sum(dim=1) - 1
            pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]

            return last_hidden, toks.attention_mask, pooled

    pos_embeds, pos_mask, pos_pooled = _encode(prompt)
    neg_embeds, neg_mask, neg_pooled = _encode(negative_prompt)

    # Broadcast a single negative prompt across the positive batch.
    batch_size = pos_embeds.shape[0]
    if neg_embeds.shape[0] != batch_size:
        neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
        neg_mask = neg_mask.repeat(batch_size, 1)
        neg_pooled = neg_pooled.repeat(batch_size, 1)

    if pos_pooled.shape[0] != batch_size:
        pos_pooled = pos_pooled.repeat(batch_size, 1)

    # CFG ordering: unconditional first, conditional second.
    text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
    final_mask = torch.cat([neg_mask, pos_mask], dim=0)
    pooled_embeds = torch.cat([neg_pooled, pos_pooled], dim=0)

    return text_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64), pooled_embeds.to(dtype=dtype)
| 210 |
+
|
| 211 |
+
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Optional[Union[Image.Image, List[Image.Image]]] = None,
    coef: float = 0.97,  # img2img strength (0.0 = original image, 1.0 = pure noise)
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 40,
    guidance_scale: float = 4.0,
    generator: Optional[torch.Generator] = None,
    seed: Optional[int] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    refine_prompt: bool = False,  # run the LLM prompt-refinement pass first
    # Kept for backward compatibility; unused by the rectified-flow path.
    structure_preservation: float = 0.0,
    **kwargs,
):
    """Generate images from text (and optionally an init image).

    Args:
        prompt: Positive prompt or list of prompts.
        image: Optional init image for img2img; a list uses only its first
            element (NOTE(review): batch img2img is not supported here).
        coef: Rectified-flow strength for img2img.
        refine_prompt: When True, rewrite the prompt with text_encoder2
            (an instruction-tuned LLM) before encoding.
    Returns:
        SdxsPipelineOutput (or a (images, prompt) tuple if return_dict=False).
    """
    device = self.device
    dtype = self.unet.dtype

    if generator is None and seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    # ==================== REFINE PROMPT (INLINE) ====================
    if refine_prompt and prompt:
        sys_msg = (
            "You are a skilled text-to-image prompt engineer whose sole function is to transform the user's input into an aesthetically optimized, detailed, and visually descriptive three-sentence output. "
            "**The primary subject (e.g., 'girl', 'dog', 'house') MUST be the main focus of the revised prompt and MUST be described in rich detail within the first sentence or two.** "
            "Output **only** the final revised prompt in **English**, with absolutely no commentary.\n Don't use cliches like warm,soft,vibrant, wildflowers. Be creative "
            "User input prompt: "
        )
        prompts_list = [prompt] if isinstance(prompt, str) else prompt
        refined_list = []

        for p in prompts_list:
            messages = [{"role": "user", "content": [{"type": "text", "text": sys_msg + p}]}]

            # apply_chat_template inserts the model's system/user/assistant tokens.
            inputs = self.tokenizer2.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt"
            ).to(device)

            generated_ids = self.text_encoder2.generate(
                **inputs,
                max_new_tokens=self.text_encoder.config.max_position_embeddings,
                do_sample=True,
                temperature=0.7,
            )

            # Strip the prompt tokens from each returned sequence.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.tokenizer2.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            # FIX: batch_decode returns List[str]; appending the whole list
            # made `prompt` a list (or list of lists) downstream. We decode a
            # single sequence here, so take element 0.
            refined_list.append(output_text[0])

        prompt = refined_list[0] if isinstance(prompt, str) else refined_list

    # ==================== ENCODE PROMPTS ====================
    text_embeddings, attention_mask, pooled_embeds = self.encode_prompt(
        prompt, negative_prompt, device, dtype
    )
    batch_size = 1 if isinstance(prompt, str) else len(prompt)

    # Scheduler timesteps.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # ==================== TIME IDS ====================
    # time_ids must share the (CFG-doubled) batch size of pooled_embeds.
    time_ids = torch.zeros(
        pooled_embeds.shape[0],
        6,
        device=device,
        dtype=torch.long
    )

    # ==================== IMG2IMG ====================
    if image is not None:
        # Only a single init image is supported; lists use the first entry.
        src = image if isinstance(image, Image.Image) else image[0]
        image_tensor = self.preprocess_image(src, width, height).to(device, self.vae.dtype)

        # Encode to (normalized) latent space.
        latents_clean = self.vae.encode(image_tensor).latent_dist.sample(generator=generator)
        latents_clean = (latents_clean - self.vae_latents_mean.to(device, self.vae.dtype)) / self.vae_latents_std.to(device, self.vae.dtype)
        latents_clean = latents_clean.to(dtype)

        # Rectified-flow noising: x_sigma = (1 - sigma) * x0 + sigma * noise.
        noise = torch.randn_like(latents_clean)
        sigma = coef  # in flow matching, sigma == t
        if hasattr(self.scheduler, "sigma_shift"):  # Flux-style shift, if present
            sigma = self.scheduler.sigma_shift(sigma)

        latents = (1.0 - sigma) * latents_clean + sigma * noise

        # Skip the timesteps already "consumed" by the init image.
        init_timestep = int(num_inference_steps * coef)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = timesteps[t_start:]

    else:
        # txt2img: start from pure noise at the latent resolution.
        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, latent_h, latent_w),
            generator=generator, device=device, dtype=dtype
        )

    # ==================== DENOISING LOOP ====================
    for t in tqdm(timesteps, desc="Sampling"):
        # Duplicate latents for the [uncond, cond] CFG batch.
        latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents

        model_out = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            encoder_attention_mask=attention_mask,
            added_cond_kwargs={"text_embeds": pooled_embeds, "time_ids": time_ids},
            return_dict=False,
        )[0]

        if guidance_scale > 1.0:
            flow_uncond, flow_cond = model_out.chunk(2)
            model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

        # scheduler.step handles the velocity/flow update.
        latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]

    # ==================== DECODE ====================
    if output_type == "latent":
        if not return_dict:
            return (latents, prompt)
        return SdxsPipelineOutput(images=latents, prompt=prompt)

    # Undo latent normalization before VAE decode.
    latents = latents * self.vae_latents_std.to(device, self.vae.dtype) + self.vae_latents_mean.to(device, self.vae.dtype)
    image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]

    image_output = (image_output.clamp(-1, 1) + 1) / 2
    image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()

    if output_type == "pil":
        images = [Image.fromarray((img * 255).round().astype("uint8")) for img in image_np]
    else:
        images = image_np

    if not return_dict:
        return (images, prompt)
    return SdxsPipelineOutput(images=images, prompt=prompt)
pipeline_sdxs.py
CHANGED
|
@@ -7,22 +7,20 @@ from dataclasses import dataclass
|
|
| 7 |
from diffusers import DiffusionPipeline
|
| 8 |
from diffusers.utils import BaseOutput
|
| 9 |
from tqdm import tqdm
|
| 10 |
-
from transformers import
|
| 11 |
|
| 12 |
@dataclass
|
| 13 |
class SdxsPipelineOutput(BaseOutput):
|
| 14 |
images: Union[List[Image.Image], np.ndarray]
|
| 15 |
-
prompt: Optional[Union[str, List[str]]] = None
|
| 16 |
|
| 17 |
class SdxsPipeline(DiffusionPipeline):
|
| 18 |
-
def __init__(self, vae, text_encoder,
|
| 19 |
super().__init__()
|
| 20 |
self.register_modules(
|
| 21 |
vae=vae,
|
| 22 |
text_encoder=text_encoder,
|
| 23 |
-
text_encoder2=text_encoder2,
|
| 24 |
tokenizer=tokenizer,
|
| 25 |
-
tokenizer2=tokenizer2,
|
| 26 |
unet=unet,
|
| 27 |
scheduler=scheduler
|
| 28 |
)
|
|
@@ -59,90 +57,6 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 59 |
|
| 60 |
|
| 61 |
def encode_prompt(self, prompt, negative_prompt, device, dtype):
|
| 62 |
-
def get_single_encode(texts):
|
| 63 |
-
if not texts:
|
| 64 |
-
texts = [""]
|
| 65 |
-
elif isinstance(texts, str):
|
| 66 |
-
texts = [texts]
|
| 67 |
-
|
| 68 |
-
with torch.no_grad():
|
| 69 |
-
toks = self.tokenizer(
|
| 70 |
-
texts,
|
| 71 |
-
padding="max_length",
|
| 72 |
-
max_length=self.text_encoder.config.max_position_embeddings,
|
| 73 |
-
truncation=True,
|
| 74 |
-
return_tensors="pt"
|
| 75 |
-
).to(device)
|
| 76 |
-
|
| 77 |
-
outputs = self.text_encoder(
|
| 78 |
-
input_ids=toks.input_ids,
|
| 79 |
-
attention_mask=toks.attention_mask,
|
| 80 |
-
output_hidden_states=True
|
| 81 |
-
)
|
| 82 |
-
|
| 83 |
-
# 1. Берем -2 слой [Batch, Seq, Dim]
|
| 84 |
-
hidden = outputs.hidden_states[-2]
|
| 85 |
-
|
| 86 |
-
# 2. Достаем pooled вектор (последний токен) [Batch, Dim]
|
| 87 |
-
seq_lens = toks.attention_mask.sum(dim=1) - 1
|
| 88 |
-
pooled = hidden[torch.arange(hidden.shape[0]), seq_lens.clamp(min=0)]
|
| 89 |
-
|
| 90 |
-
# 3. Нормализация
|
| 91 |
-
norm = self.text_encoder.text_model.final_layer_norm
|
| 92 |
-
hidden = norm(hidden)
|
| 93 |
-
pooled = norm(pooled)
|
| 94 |
-
|
| 95 |
-
# 4. Объединяем в матрицу: Пулед (как 1-й токен) + остальные токены
|
| 96 |
-
# pooled.unsqueeze(1) делает [Batch, 1, Dim]
|
| 97 |
-
embeds = torch.cat([pooled.unsqueeze(1), hidden], dim=1)
|
| 98 |
-
|
| 99 |
-
# 5. Расширяем маску для нового токена (добавляем единицы спереди)
|
| 100 |
-
ones = torch.ones((toks.attention_mask.shape[0], 1), dtype=toks.attention_mask.dtype, device=device)
|
| 101 |
-
mask = torch.cat([ones, toks.attention_mask], dim=1)
|
| 102 |
-
|
| 103 |
-
return embeds, mask, pooled
|
| 104 |
-
|
| 105 |
-
def get_pooled_encode(texts):
|
| 106 |
-
if texts is None:
|
| 107 |
-
texts = ""
|
| 108 |
-
|
| 109 |
-
if isinstance(texts, str):
|
| 110 |
-
texts = [texts]
|
| 111 |
-
|
| 112 |
-
with torch.no_grad():
|
| 113 |
-
# 1. Собираем текстовые промпты оборачивая их в Chat Template
|
| 114 |
-
formatted_prompts = []
|
| 115 |
-
for t in texts:
|
| 116 |
-
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 117 |
-
res_text = self.tokenizer2.apply_chat_template(
|
| 118 |
-
messages,
|
| 119 |
-
add_generation_prompt=True,
|
| 120 |
-
tokenize=False
|
| 121 |
-
)
|
| 122 |
-
formatted_prompts.append(res_text)
|
| 123 |
-
|
| 124 |
-
# 2. Токенизируем, режем и добавляем паддинг за один раз
|
| 125 |
-
toks = self.tokenizer2(
|
| 126 |
-
formatted_prompts,
|
| 127 |
-
padding="max_length",
|
| 128 |
-
max_length=self.text_encoder.config.max_position_embeddings,
|
| 129 |
-
truncation=True, # Не забываем обрезать, если вдруг длиннее
|
| 130 |
-
return_tensors="pt"
|
| 131 |
-
).to(device)
|
| 132 |
-
|
| 133 |
-
# 3. Прогоняем через модель
|
| 134 |
-
outputs = self.text_encoder2(
|
| 135 |
-
input_ids=toks.input_ids,
|
| 136 |
-
attention_mask=toks.attention_mask,
|
| 137 |
-
output_hidden_states=True
|
| 138 |
-
)
|
| 139 |
-
|
| 140 |
-
layer_index = -2
|
| 141 |
-
last_hidden = outputs.hidden_states[layer_index]
|
| 142 |
-
seq_len = toks.attention_mask.sum(dim=1) - 1
|
| 143 |
-
pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]
|
| 144 |
-
|
| 145 |
-
return pooled
|
| 146 |
def get_encode(texts):
|
| 147 |
if texts is None:
|
| 148 |
texts = ""
|
|
@@ -155,7 +69,7 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 155 |
formatted_prompts = []
|
| 156 |
for t in texts:
|
| 157 |
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 158 |
-
res_text = self.
|
| 159 |
messages,
|
| 160 |
add_generation_prompt=True,
|
| 161 |
tokenize=False
|
|
@@ -163,16 +77,16 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 163 |
formatted_prompts.append(res_text)
|
| 164 |
|
| 165 |
# 2. Токенизируем, режем и добавляем паддинг за один раз
|
| 166 |
-
toks = self.
|
| 167 |
formatted_prompts,
|
| 168 |
padding="max_length",
|
| 169 |
-
max_length=
|
| 170 |
truncation=True, # Не забываем обрезать, если вдруг длиннее
|
| 171 |
return_tensors="pt"
|
| 172 |
).to(device)
|
| 173 |
|
| 174 |
# 3. Прогоняем через модель
|
| 175 |
-
outputs = self.
|
| 176 |
input_ids=toks.input_ids,
|
| 177 |
attention_mask=toks.attention_mask,
|
| 178 |
output_hidden_states=True
|
|
@@ -185,11 +99,6 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 185 |
|
| 186 |
return last_hidden, toks.attention_mask, pooled
|
| 187 |
|
| 188 |
-
#pos_embeds, pos_mask, pooled_pos = get_single_encode(prompt)
|
| 189 |
-
#neg_embeds, neg_mask, pooled_neg = get_single_encode(negative_prompt)
|
| 190 |
-
# 768 + 2048
|
| 191 |
-
#pos_pooled = get_pooled_encode(prompt) #torch.cat([pooled_pos, get_pooled_encode(prompt)], dim=1)
|
| 192 |
-
#neg_pooled = get_pooled_encode(negative_prompt) #torch.cat([pooled_neg, get_pooled_encode(negative_prompt)], dim=1)
|
| 193 |
pos_embeds, pos_mask, pos_pooled = get_encode(prompt)
|
| 194 |
neg_embeds, neg_mask, neg_pooled = get_encode(negative_prompt)
|
| 195 |
|
|
@@ -223,7 +132,7 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 223 |
seed: Optional[int] = None,
|
| 224 |
output_type: str = "pil",
|
| 225 |
return_dict: bool = True,
|
| 226 |
-
refine_prompt: bool = False,
|
| 227 |
# structure_preservation оставляем для совместимости, но теперь он почти не нужен
|
| 228 |
structure_preservation: float = 0.0, # 0.0 = стандартный линейный путь (лучше всего)
|
| 229 |
**kwargs,
|
|
|
|
| 7 |
from diffusers import DiffusionPipeline
|
| 8 |
from diffusers.utils import BaseOutput
|
| 9 |
from tqdm import tqdm
|
| 10 |
+
from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Tokenizer
|
| 11 |
|
| 12 |
@dataclass
|
| 13 |
class SdxsPipelineOutput(BaseOutput):
|
| 14 |
images: Union[List[Image.Image], np.ndarray]
|
| 15 |
+
prompt: Optional[Union[str, List[str]]] = None
|
| 16 |
|
| 17 |
class SdxsPipeline(DiffusionPipeline):
|
| 18 |
+
def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
|
| 19 |
super().__init__()
|
| 20 |
self.register_modules(
|
| 21 |
vae=vae,
|
| 22 |
text_encoder=text_encoder,
|
|
|
|
| 23 |
tokenizer=tokenizer,
|
|
|
|
| 24 |
unet=unet,
|
| 25 |
scheduler=scheduler
|
| 26 |
)
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
def encode_prompt(self, prompt, negative_prompt, device, dtype):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def get_encode(texts):
|
| 61 |
if texts is None:
|
| 62 |
texts = ""
|
|
|
|
| 69 |
formatted_prompts = []
|
| 70 |
for t in texts:
|
| 71 |
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 72 |
+
res_text = self.tokenizer.apply_chat_template(
|
| 73 |
messages,
|
| 74 |
add_generation_prompt=True,
|
| 75 |
tokenize=False
|
|
|
|
| 77 |
formatted_prompts.append(res_text)
|
| 78 |
|
| 79 |
# 2. Токенизируем, режем и добавляем паддинг за один раз
|
| 80 |
+
toks = self.tokenizer(
|
| 81 |
formatted_prompts,
|
| 82 |
padding="max_length",
|
| 83 |
+
max_length=255,
|
| 84 |
truncation=True, # Не забываем обрезать, если вдруг длиннее
|
| 85 |
return_tensors="pt"
|
| 86 |
).to(device)
|
| 87 |
|
| 88 |
# 3. Прогоняем через модель
|
| 89 |
+
outputs = self.text_encoder(
|
| 90 |
input_ids=toks.input_ids,
|
| 91 |
attention_mask=toks.attention_mask,
|
| 92 |
output_hidden_states=True
|
|
|
|
| 99 |
|
| 100 |
return last_hidden, toks.attention_mask, pooled
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
pos_embeds, pos_mask, pos_pooled = get_encode(prompt)
|
| 103 |
neg_embeds, neg_mask, neg_pooled = get_encode(negative_prompt)
|
| 104 |
|
|
|
|
| 132 |
seed: Optional[int] = None,
|
| 133 |
output_type: str = "pil",
|
| 134 |
return_dict: bool = True,
|
| 135 |
+
refine_prompt: bool = False,
|
| 136 |
# structure_preservation оставляем для совместимости, но теперь он почти не нужен
|
| 137 |
structure_preservation: float = 0.0, # 0.0 = стандартный линейный путь (лучше всего)
|
| 138 |
**kwargs,
|
test.ipynb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b1dc22c65c672008be6159c53028670acace729df6e26c3d1199a8929a07060
|
| 3 |
+
size 6130032
|
text_encoder/config.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:199bacf59248a05c934c618cacd62b6cc2f60e1637563f037206c2b09330a4ff
|
| 3 |
+
size 2613
|
{text_encoder2 → text_encoder}/generation_config.json
RENAMED
|
File without changes
|
text_encoder/model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:020f49a22fe87e482485f27e57a01782f1faf7d0f312cc740d83faac36babcba
|
| 3 |
+
size 4426558248
|
text_encoder2/config.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:199bacf59248a05c934c618cacd62b6cc2f60e1637563f037206c2b09330a4ff
|
| 3 |
-
size 2613
|
|
|
|
|
|
|
|
|
|
|
|
text_encoder2/model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:020f49a22fe87e482485f27e57a01782f1faf7d0f312cc740d83faac36babcba
|
| 3 |
-
size 4426558248
|
|
|
|
|
|
|
|
|
|
|
|
tmp/config.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:7d9d2485cfbdfe8c4a3d350907422ade854ffa1c8b06bd39b896ff45fc444cda
|
| 3 |
-
size 1858
|
|
|
|
|
|
|
|
|
|
|
|
tmp/diffusion_pytorch_model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:4ef8fd48bd3075eae7202915c3b05b3ed8e90c70c41625f3f70a2d868a1013c9
|
| 3 |
-
size 3159550424
|
|
|
|
|
|
|
|
|
|
|
|
{tokenizer2 → tokenizer}/chat_template.jinja
RENAMED
|
File without changes
|
tokenizer/merges.txt
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
{tokenizer2 → tokenizer}/preprocessor_config.json
RENAMED
|
File without changes
|
{tokenizer2 → tokenizer}/processor_config.json
RENAMED
|
File without changes
|
tokenizer/special_tokens_map.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:2cdb3b8331a60c92fc1e55a13e9fd61fd2293c5a51275fdcccd62b780052530e
|
| 3 |
-
size 588
|
|
|
|
|
|
|
|
|
|
|
|
{tokenizer2 → tokenizer}/tokenizer.json
RENAMED
|
File without changes
|
tokenizer/tokenizer_config.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a4f52c2103a685c3d9f3022b3954246fbab865abd709ac69d8b8a98f79580564
|
| 3 |
+
size 1140
|
tokenizer/vocab.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce99b4cb2983d118806ce0a8b777a35b093e2000a503ebde25853284c9dfa003
|
| 3 |
+
size 6722759
|
tokenizer2/merges.txt
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer2/tokenizer_config.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a4f52c2103a685c3d9f3022b3954246fbab865abd709ac69d8b8a98f79580564
|
| 3 |
-
size 1140
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer2/vocab.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:ce99b4cb2983d118806ce0a8b777a35b093e2000a503ebde25853284c9dfa003
|
| 3 |
-
size 6722759
|
|
|
|
|
|
|
|
|
|
|
|
train.py
CHANGED
|
@@ -149,10 +149,8 @@ if accelerator.is_main_process:
|
|
| 149 |
# --------------------------- Загрузка моделей ---------------------------
|
| 150 |
#vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
|
| 151 |
vae = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
|
| 152 |
-
tokenizer =
|
| 153 |
-
text_encoder =
|
| 154 |
-
tokenizer2 = Qwen3_5Tokenizer.from_pretrained("tokenizer2")
|
| 155 |
-
text_encoder2 = Qwen3_5ForConditionalGeneration.from_pretrained("text_encoder2", torch_dtype=torch.float16).to(device).eval()
|
| 156 |
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
|
| 157 |
|
| 158 |
def encode_texts(texts, max_length=max_length):
|
|
@@ -168,7 +166,7 @@ def encode_texts(texts, max_length=max_length):
|
|
| 168 |
formatted_prompts = []
|
| 169 |
for t in texts:
|
| 170 |
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 171 |
-
res_text =
|
| 172 |
messages,
|
| 173 |
add_generation_prompt=True,
|
| 174 |
tokenize=False
|
|
@@ -176,7 +174,7 @@ def encode_texts(texts, max_length=max_length):
|
|
| 176 |
formatted_prompts.append(res_text)
|
| 177 |
|
| 178 |
# 2. Токенизируем, режем и добавляем паддинг за один раз
|
| 179 |
-
toks =
|
| 180 |
formatted_prompts,
|
| 181 |
padding="max_length",
|
| 182 |
max_length=max_length,
|
|
@@ -185,7 +183,7 @@ def encode_texts(texts, max_length=max_length):
|
|
| 185 |
).to(device)
|
| 186 |
|
| 187 |
# 3. Прогоняем через модель
|
| 188 |
-
outputs =
|
| 189 |
input_ids=toks.input_ids,
|
| 190 |
attention_mask=toks.attention_mask,
|
| 191 |
output_hidden_states=True
|
|
|
|
| 149 |
# --------------------------- Загрузка моделей ---------------------------
|
| 150 |
#vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
|
| 151 |
vae = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
|
| 152 |
+
tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer2")
|
| 153 |
+
text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained("text_encoder2", torch_dtype=torch.float16).to(device).eval()
|
|
|
|
|
|
|
| 154 |
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
|
| 155 |
|
| 156 |
def encode_texts(texts, max_length=max_length):
|
|
|
|
| 166 |
formatted_prompts = []
|
| 167 |
for t in texts:
|
| 168 |
messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
|
| 169 |
+
res_text = tokenizer.apply_chat_template(
|
| 170 |
messages,
|
| 171 |
add_generation_prompt=True,
|
| 172 |
tokenize=False
|
|
|
|
| 174 |
formatted_prompts.append(res_text)
|
| 175 |
|
| 176 |
# 2. Токенизируем, режем и добавляем паддинг за один раз
|
| 177 |
+
toks = tokenizer(
|
| 178 |
formatted_prompts,
|
| 179 |
padding="max_length",
|
| 180 |
max_length=max_length,
|
|
|
|
| 183 |
).to(device)
|
| 184 |
|
| 185 |
# 3. Прогоняем через модель
|
| 186 |
+
outputs = text_encoder(
|
| 187 |
input_ids=toks.input_ids,
|
| 188 |
attention_mask=toks.attention_mask,
|
| 189 |
output_hidden_states=True
|
unet/diffusion_pytorch_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 6318956752
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1a8e5563a08991becd74bb41292f800b383fb757cafe97ce5b87b6004dfb87a
|
| 3 |
size 6318956752
|
unet_old/config.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:e0c1f603d4dfb8d759010daf7d83df29ff148ba48693400fb14c0f4dbd9b7d2f
|
| 3 |
-
size 1884
|
|
|
|
|
|
|
|
|
|
|
|
unet_old/diffusion_pytorch_model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:4fc036a58605f5073498525bf4f8af44552ad1f57fcb7f545863d45bf95d5ce2
|
| 3 |
-
size 5960474736
|
|
|
|
|
|
|
|
|
|
|
|