Upload folder using huggingface_hub
Browse files- artflow/models/dit_blocks.py +5 -4
- artflow/pipeline/artflow_pipeline.py +59 -29
- model.safetensors +1 -1
artflow/models/dit_blocks.py
CHANGED
|
@@ -176,8 +176,8 @@ class MSRoPE(nn.Module):
|
|
| 176 |
|
| 177 |
# Text frequencies start after maximum image position
|
| 178 |
max_img_pos = max(height, width)
|
| 179 |
-
# View as complex for slicing
|
| 180 |
-
pos_freqs_complex = torch.view_as_complex(self.pos_freqs)
|
| 181 |
txt_freqs = pos_freqs_complex[
|
| 182 |
max_img_pos : max_img_pos + txt_seq_len, :
|
| 183 |
] # placing text tokens on a diagonal in the 2D position space
|
|
@@ -195,9 +195,10 @@ class MSRoPE(nn.Module):
|
|
| 195 |
Frequency tensor [height*width, total_dim] (complex)
|
| 196 |
"""
|
| 197 |
# Split precomputed frequencies by axis
|
| 198 |
-
# pos_freqs is [S, D, 2] (real)
|
| 199 |
h_dim, w_dim = self.axes_dim
|
| 200 |
-
|
|
|
|
| 201 |
|
| 202 |
# Select frequencies for the current height and width
|
| 203 |
h_freqs = h_freqs[:height, :] # [H, h_dim//2, 2]
|
|
|
|
| 176 |
|
| 177 |
# Text frequencies start after maximum image position
|
| 178 |
max_img_pos = max(height, width)
|
| 179 |
+
# View as complex for slicing (must be float32 — view_as_complex doesn't support bf16)
|
| 180 |
+
pos_freqs_complex = torch.view_as_complex(self.pos_freqs.float())
|
| 181 |
txt_freqs = pos_freqs_complex[
|
| 182 |
max_img_pos : max_img_pos + txt_seq_len, :
|
| 183 |
] # placing text tokens on a diagonal in the 2D position space
|
|
|
|
| 195 |
Frequency tensor [height*width, total_dim] (complex)
|
| 196 |
"""
|
| 197 |
# Split precomputed frequencies by axis
|
| 198 |
+
# pos_freqs is [S, D, 2] (real) — cast to float32 for view_as_complex
|
| 199 |
h_dim, w_dim = self.axes_dim
|
| 200 |
+
pos_freqs = self.pos_freqs.float()
|
| 201 |
+
h_freqs, w_freqs = pos_freqs.split([h_dim // 2, w_dim // 2], dim=1)
|
| 202 |
|
| 203 |
# Select frequencies for the current height and width
|
| 204 |
h_freqs = h_freqs[:height, :] # [H, h_dim//2, 2]
|
artflow/pipeline/artflow_pipeline.py
CHANGED
|
@@ -49,6 +49,8 @@ class ArtFlowPipeline:
|
|
| 49 |
vae_std: Optional[torch.Tensor] = None,
|
| 50 |
solver: str = "euler",
|
| 51 |
dtype: torch.dtype = torch.bfloat16,
|
|
|
|
|
|
|
| 52 |
):
|
| 53 |
self.transformer = transformer
|
| 54 |
self.vae = vae
|
|
@@ -58,15 +60,20 @@ class ArtFlowPipeline:
|
|
| 58 |
self.vae_std = vae_std
|
| 59 |
self.solver = solver
|
| 60 |
self.dtype = dtype
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Move to eval mode
|
| 63 |
self.transformer.eval()
|
| 64 |
self.vae.eval()
|
| 65 |
self.text_encoder.eval()
|
| 66 |
|
| 67 |
-
def
|
| 68 |
-
"""Get
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
@classmethod
|
| 72 |
def from_pretrained(
|
|
@@ -91,6 +98,7 @@ class ArtFlowPipeline:
|
|
| 91 |
|
| 92 |
dtype = kwargs.get("dtype", torch.bfloat16)
|
| 93 |
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 94 |
|
| 95 |
# Download all source files from the repo
|
| 96 |
repo_files = list_repo_files(pretrained_model_name_or_path)
|
|
@@ -132,18 +140,17 @@ class ArtFlowPipeline:
|
|
| 132 |
state_dict = state_dict["module"]
|
| 133 |
|
| 134 |
transformer.load_state_dict(state_dict)
|
| 135 |
-
transformer.to(device=device, dtype=dtype)
|
| 136 |
|
| 137 |
# Load VAE
|
| 138 |
vae_repo = config.get("vae_repo", "REPA-E/e2e-qwenimage-vae")
|
| 139 |
-
vae = AutoencoderKLQwenImage.from_pretrained(vae_repo, torch_dtype=dtype)
|
| 140 |
|
| 141 |
# Load text encoder
|
| 142 |
text_encoder_repo = config.get("text_encoder_repo", "Qwen/Qwen3-0.6B")
|
| 143 |
text_encoder = AutoModelForCausalLM.from_pretrained(
|
| 144 |
text_encoder_repo,
|
| 145 |
dtype=dtype,
|
| 146 |
-
|
| 147 |
)
|
| 148 |
tokenizer = AutoTokenizer.from_pretrained(text_encoder_repo)
|
| 149 |
|
|
@@ -152,6 +159,16 @@ class ArtFlowPipeline:
|
|
| 152 |
vae_mean = vae_mean.to(device=device, dtype=dtype)
|
| 153 |
vae_std = vae_std.to(device=device, dtype=dtype)
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
# Create pipeline
|
| 156 |
pipe = cls(
|
| 157 |
transformer=transformer,
|
|
@@ -162,6 +179,8 @@ class ArtFlowPipeline:
|
|
| 162 |
vae_std=vae_std,
|
| 163 |
solver=config.get("solver", "euler"),
|
| 164 |
dtype=dtype,
|
|
|
|
|
|
|
| 165 |
)
|
| 166 |
|
| 167 |
return pipe
|
|
@@ -220,11 +239,12 @@ class ArtFlowPipeline:
|
|
| 220 |
prompt = [prompt]
|
| 221 |
batch_size = len(prompt)
|
| 222 |
|
| 223 |
-
#
|
| 224 |
do_cfg = guidance_scale > 1.0
|
| 225 |
|
|
|
|
|
|
|
| 226 |
if do_cfg:
|
| 227 |
-
# Encode prompts and negative prompts together
|
| 228 |
if negative_prompt is None:
|
| 229 |
negative_prompt = [""] * batch_size
|
| 230 |
elif isinstance(negative_prompt, str):
|
|
@@ -233,7 +253,6 @@ class ArtFlowPipeline:
|
|
| 233 |
all_prompts = prompt + negative_prompt
|
| 234 |
text_emb, text_mask, _ = self._encode_text(all_prompts)
|
| 235 |
|
| 236 |
-
# Split into conditional and unconditional
|
| 237 |
text_emb_cond = text_emb[:batch_size]
|
| 238 |
text_emb_uncond = text_emb[batch_size:]
|
| 239 |
text_mask_cond = text_mask[:batch_size]
|
|
@@ -242,9 +261,14 @@ class ArtFlowPipeline:
|
|
| 242 |
text_emb_cond, text_mask_cond = self._encode_text(prompt)[:2]
|
| 243 |
text_emb_uncond = None
|
| 244 |
text_mask_uncond = None
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
generator = torch.Generator(device=device)
|
| 249 |
if seed is not None:
|
| 250 |
generator.manual_seed(seed)
|
|
@@ -257,7 +281,6 @@ class ArtFlowPipeline:
|
|
| 257 |
)
|
| 258 |
latents = torch.randn(latents_shape, generator=generator, device=device, dtype=self.dtype)
|
| 259 |
|
| 260 |
-
# Prepare for CFG - repeat latents if doing CFG
|
| 261 |
if do_cfg:
|
| 262 |
latents = torch.cat([latents, latents], dim=0)
|
| 263 |
text_emb = torch.cat([text_emb_cond, text_emb_uncond], dim=0)
|
|
@@ -266,32 +289,38 @@ class ArtFlowPipeline:
|
|
| 266 |
text_emb = text_emb_cond
|
| 267 |
text_mask = text_mask_cond
|
| 268 |
|
| 269 |
-
# Denoise
|
| 270 |
def model_fn(x, t):
|
| 271 |
-
t_tensor = torch.
|
| 272 |
return self.transformer(x, t_tensor, text_emb, txt_mask=text_mask)
|
| 273 |
|
| 274 |
-
# Run solver
|
| 275 |
from artflow.flow.solvers import sample_ode
|
| 276 |
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
|
|
|
| 284 |
|
| 285 |
-
# Apply CFG if enabled
|
| 286 |
if do_cfg:
|
| 287 |
latents_cond, latents_uncond = latents.chunk(2)
|
| 288 |
latents = latents_uncond + guidance_scale * (latents_cond - latents_uncond)
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
-
#
|
| 291 |
if output_type == "latent":
|
| 292 |
images = latents
|
| 293 |
else:
|
|
|
|
|
|
|
| 294 |
images = self._decode_latents(latents)
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
if not return_dict:
|
| 297 |
return images
|
|
@@ -309,16 +338,17 @@ class ArtFlowPipeline:
|
|
| 309 |
prompts, self.text_encoder, self.tokenizer, pooling=pooling
|
| 310 |
)
|
| 311 |
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
txt_mask = txt_mask.to(device=device)
|
| 315 |
if txt_pooled is not None:
|
| 316 |
-
txt_pooled = txt_pooled.to(device=device, dtype=self.dtype)
|
| 317 |
|
| 318 |
return txt_emb, txt_mask, txt_pooled
|
| 319 |
|
| 320 |
def _decode_latents(self, latents: torch.Tensor) -> List[Image.Image]:
|
| 321 |
"""Decode VAE latents to PIL images."""
|
|
|
|
|
|
|
| 322 |
# Denormalize
|
| 323 |
if self.vae_mean is not None and self.vae_std is not None:
|
| 324 |
latents = latents * self.vae_std + self.vae_mean
|
|
|
|
| 49 |
vae_std: Optional[torch.Tensor] = None,
|
| 50 |
solver: str = "euler",
|
| 51 |
dtype: torch.dtype = torch.bfloat16,
|
| 52 |
+
device: Optional[str] = None,
|
| 53 |
+
offload: bool = True,
|
| 54 |
):
|
| 55 |
self.transformer = transformer
|
| 56 |
self.vae = vae
|
|
|
|
| 60 |
self.vae_std = vae_std
|
| 61 |
self.solver = solver
|
| 62 |
self.dtype = dtype
|
| 63 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 64 |
+
self.offload = offload
|
| 65 |
|
| 66 |
# Move to eval mode
|
| 67 |
self.transformer.eval()
|
| 68 |
self.vae.eval()
|
| 69 |
self.text_encoder.eval()
|
| 70 |
|
| 71 |
+
def _get_autocast_context(self):
|
| 72 |
+
"""Get autocast context manager for inference."""
|
| 73 |
+
device_type = "cuda" if "cuda" in self.device else "cpu"
|
| 74 |
+
if self.dtype in (torch.float16, torch.bfloat16):
|
| 75 |
+
return torch.autocast(device_type=device_type, dtype=self.dtype)
|
| 76 |
+
return torch.no_grad()
|
| 77 |
|
| 78 |
@classmethod
|
| 79 |
def from_pretrained(
|
|
|
|
| 98 |
|
| 99 |
dtype = kwargs.get("dtype", torch.bfloat16)
|
| 100 |
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
| 101 |
+
offload = kwargs.get("offload", True)
|
| 102 |
|
| 103 |
# Download all source files from the repo
|
| 104 |
repo_files = list_repo_files(pretrained_model_name_or_path)
|
|
|
|
| 140 |
state_dict = state_dict["module"]
|
| 141 |
|
| 142 |
transformer.load_state_dict(state_dict)
|
|
|
|
| 143 |
|
| 144 |
# Load VAE
|
| 145 |
vae_repo = config.get("vae_repo", "REPA-E/e2e-qwenimage-vae")
|
| 146 |
+
vae = AutoencoderKLQwenImage.from_pretrained(vae_repo, torch_dtype=dtype)
|
| 147 |
|
| 148 |
# Load text encoder
|
| 149 |
text_encoder_repo = config.get("text_encoder_repo", "Qwen/Qwen3-0.6B")
|
| 150 |
text_encoder = AutoModelForCausalLM.from_pretrained(
|
| 151 |
text_encoder_repo,
|
| 152 |
dtype=dtype,
|
| 153 |
+
low_cpu_mem_usage=True,
|
| 154 |
)
|
| 155 |
tokenizer = AutoTokenizer.from_pretrained(text_encoder_repo)
|
| 156 |
|
|
|
|
| 159 |
vae_mean = vae_mean.to(device=device, dtype=dtype)
|
| 160 |
vae_std = vae_std.to(device=device, dtype=dtype)
|
| 161 |
|
| 162 |
+
# Load models to appropriate device based on offload setting
|
| 163 |
+
if offload:
|
| 164 |
+
# Keep on CPU, offload to GPU when needed
|
| 165 |
+
transformer.to(dtype=dtype)
|
| 166 |
+
else:
|
| 167 |
+
# Load directly to GPU
|
| 168 |
+
transformer.to(device=device, dtype=dtype)
|
| 169 |
+
vae.to(device=device)
|
| 170 |
+
text_encoder.to(device=device)
|
| 171 |
+
|
| 172 |
# Create pipeline
|
| 173 |
pipe = cls(
|
| 174 |
transformer=transformer,
|
|
|
|
| 179 |
vae_std=vae_std,
|
| 180 |
solver=config.get("solver", "euler"),
|
| 181 |
dtype=dtype,
|
| 182 |
+
device=device,
|
| 183 |
+
offload=offload,
|
| 184 |
)
|
| 185 |
|
| 186 |
return pipe
|
|
|
|
| 239 |
prompt = [prompt]
|
| 240 |
batch_size = len(prompt)
|
| 241 |
|
| 242 |
+
# --- Stage 1: Text encoding (text_encoder on GPU) ---
|
| 243 |
do_cfg = guidance_scale > 1.0
|
| 244 |
|
| 245 |
+
if self.offload:
|
| 246 |
+
self.text_encoder.to(self.device)
|
| 247 |
if do_cfg:
|
|
|
|
| 248 |
if negative_prompt is None:
|
| 249 |
negative_prompt = [""] * batch_size
|
| 250 |
elif isinstance(negative_prompt, str):
|
|
|
|
| 253 |
all_prompts = prompt + negative_prompt
|
| 254 |
text_emb, text_mask, _ = self._encode_text(all_prompts)
|
| 255 |
|
|
|
|
| 256 |
text_emb_cond = text_emb[:batch_size]
|
| 257 |
text_emb_uncond = text_emb[batch_size:]
|
| 258 |
text_mask_cond = text_mask[:batch_size]
|
|
|
|
| 261 |
text_emb_cond, text_mask_cond = self._encode_text(prompt)[:2]
|
| 262 |
text_emb_uncond = None
|
| 263 |
text_mask_uncond = None
|
| 264 |
+
if self.offload:
|
| 265 |
+
self.text_encoder.to("cpu")
|
| 266 |
+
torch.cuda.empty_cache()
|
| 267 |
+
|
| 268 |
+
# --- Stage 2: Denoising (transformer on GPU) ---
|
| 269 |
+
if self.offload:
|
| 270 |
+
self.transformer.to(self.device)
|
| 271 |
+
device = torch.device(self.device)
|
| 272 |
generator = torch.Generator(device=device)
|
| 273 |
if seed is not None:
|
| 274 |
generator.manual_seed(seed)
|
|
|
|
| 281 |
)
|
| 282 |
latents = torch.randn(latents_shape, generator=generator, device=device, dtype=self.dtype)
|
| 283 |
|
|
|
|
| 284 |
if do_cfg:
|
| 285 |
latents = torch.cat([latents, latents], dim=0)
|
| 286 |
text_emb = torch.cat([text_emb_cond, text_emb_uncond], dim=0)
|
|
|
|
| 289 |
text_emb = text_emb_cond
|
| 290 |
text_mask = text_mask_cond
|
| 291 |
|
|
|
|
| 292 |
def model_fn(x, t):
|
| 293 |
+
t_tensor = torch.as_tensor(t, device=x.device).expand(x.shape[0])
|
| 294 |
return self.transformer(x, t_tensor, text_emb, txt_mask=text_mask)
|
| 295 |
|
|
|
|
| 296 |
from artflow.flow.solvers import sample_ode
|
| 297 |
|
| 298 |
+
with self._get_autocast_context():
|
| 299 |
+
latents = sample_ode(
|
| 300 |
+
model_fn,
|
| 301 |
+
latents,
|
| 302 |
+
steps=num_inference_steps,
|
| 303 |
+
solver=solver,
|
| 304 |
+
device=self.device,
|
| 305 |
+
)
|
| 306 |
|
|
|
|
| 307 |
if do_cfg:
|
| 308 |
latents_cond, latents_uncond = latents.chunk(2)
|
| 309 |
latents = latents_uncond + guidance_scale * (latents_cond - latents_uncond)
|
| 310 |
+
if self.offload:
|
| 311 |
+
self.transformer.to("cpu")
|
| 312 |
+
torch.cuda.empty_cache()
|
| 313 |
|
| 314 |
+
# --- Stage 3: VAE decode (vae on GPU) ---
|
| 315 |
if output_type == "latent":
|
| 316 |
images = latents
|
| 317 |
else:
|
| 318 |
+
if self.offload:
|
| 319 |
+
self.vae.to(self.device)
|
| 320 |
images = self._decode_latents(latents)
|
| 321 |
+
if self.offload:
|
| 322 |
+
self.vae.to("cpu")
|
| 323 |
+
torch.cuda.empty_cache()
|
| 324 |
|
| 325 |
if not return_dict:
|
| 326 |
return images
|
|
|
|
| 338 |
prompts, self.text_encoder, self.tokenizer, pooling=pooling
|
| 339 |
)
|
| 340 |
|
| 341 |
+
txt_emb = txt_emb.to(device=self.device, dtype=self.dtype)
|
| 342 |
+
txt_mask = txt_mask.to(device=self.device)
|
|
|
|
| 343 |
if txt_pooled is not None:
|
| 344 |
+
txt_pooled = txt_pooled.to(device=self.device, dtype=self.dtype)
|
| 345 |
|
| 346 |
return txt_emb, txt_mask, txt_pooled
|
| 347 |
|
| 348 |
def _decode_latents(self, latents: torch.Tensor) -> List[Image.Image]:
|
| 349 |
"""Decode VAE latents to PIL images."""
|
| 350 |
+
latents = latents.to(device=self.device, dtype=self.dtype)
|
| 351 |
+
|
| 352 |
# Denormalize
|
| 353 |
if self.vae_mean is not None and self.vae_std is not None:
|
| 354 |
latents = latents * self.vae_std + self.vae_mean
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2715352432
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:619009befe1afec50e1936d18afd3bbb8b7d314265c7975110cd8b17aee49fad
|
| 3 |
size 2715352432
|