kaupane commited on
Commit
6fc0285
·
verified ·
1 Parent(s): fe1a3e3

Upload folder using huggingface_hub

Browse files
artflow/models/dit_blocks.py CHANGED
@@ -176,8 +176,8 @@ class MSRoPE(nn.Module):
176
 
177
  # Text frequencies start after maximum image position
178
  max_img_pos = max(height, width)
179
- # View as complex for slicing
180
- pos_freqs_complex = torch.view_as_complex(self.pos_freqs)
181
  txt_freqs = pos_freqs_complex[
182
  max_img_pos : max_img_pos + txt_seq_len, :
183
  ] # placing text tokens on a diagonal in the 2D position space
@@ -195,9 +195,10 @@ class MSRoPE(nn.Module):
195
  Frequency tensor [height*width, total_dim] (complex)
196
  """
197
  # Split precomputed frequencies by axis
198
- # pos_freqs is [S, D, 2] (real)
199
  h_dim, w_dim = self.axes_dim
200
- h_freqs, w_freqs = self.pos_freqs.split([h_dim // 2, w_dim // 2], dim=1)
 
201
 
202
  # Select frequencies for the current height and width
203
  h_freqs = h_freqs[:height, :] # [H, h_dim//2, 2]
 
176
 
177
  # Text frequencies start after maximum image position
178
  max_img_pos = max(height, width)
179
+ # View as complex for slicing (must be float32 — view_as_complex doesn't support bf16)
180
+ pos_freqs_complex = torch.view_as_complex(self.pos_freqs.float())
181
  txt_freqs = pos_freqs_complex[
182
  max_img_pos : max_img_pos + txt_seq_len, :
183
  ] # placing text tokens on a diagonal in the 2D position space
 
195
  Frequency tensor [height*width, total_dim] (complex)
196
  """
197
  # Split precomputed frequencies by axis
198
+ # pos_freqs is [S, D, 2] (real) — cast to float32 for view_as_complex
199
  h_dim, w_dim = self.axes_dim
200
+ pos_freqs = self.pos_freqs.float()
201
+ h_freqs, w_freqs = pos_freqs.split([h_dim // 2, w_dim // 2], dim=1)
202
 
203
  # Select frequencies for the current height and width
204
  h_freqs = h_freqs[:height, :] # [H, h_dim//2, 2]
artflow/pipeline/artflow_pipeline.py CHANGED
@@ -49,6 +49,8 @@ class ArtFlowPipeline:
49
  vae_std: Optional[torch.Tensor] = None,
50
  solver: str = "euler",
51
  dtype: torch.dtype = torch.bfloat16,
 
 
52
  ):
53
  self.transformer = transformer
54
  self.vae = vae
@@ -58,15 +60,20 @@ class ArtFlowPipeline:
58
  self.vae_std = vae_std
59
  self.solver = solver
60
  self.dtype = dtype
 
 
61
 
62
  # Move to eval mode
63
  self.transformer.eval()
64
  self.vae.eval()
65
  self.text_encoder.eval()
66
 
67
- def _get_device(self) -> torch.device:
68
- """Get the device of the transformer model."""
69
- return next(self.transformer.parameters()).device
 
 
 
70
 
71
  @classmethod
72
  def from_pretrained(
@@ -91,6 +98,7 @@ class ArtFlowPipeline:
91
 
92
  dtype = kwargs.get("dtype", torch.bfloat16)
93
  device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
 
94
 
95
  # Download all source files from the repo
96
  repo_files = list_repo_files(pretrained_model_name_or_path)
@@ -132,18 +140,17 @@ class ArtFlowPipeline:
132
  state_dict = state_dict["module"]
133
 
134
  transformer.load_state_dict(state_dict)
135
- transformer.to(device=device, dtype=dtype)
136
 
137
  # Load VAE
138
  vae_repo = config.get("vae_repo", "REPA-E/e2e-qwenimage-vae")
139
- vae = AutoencoderKLQwenImage.from_pretrained(vae_repo, torch_dtype=dtype).to(device)
140
 
141
  # Load text encoder
142
  text_encoder_repo = config.get("text_encoder_repo", "Qwen/Qwen3-0.6B")
143
  text_encoder = AutoModelForCausalLM.from_pretrained(
144
  text_encoder_repo,
145
  dtype=dtype,
146
- device_map=device,
147
  )
148
  tokenizer = AutoTokenizer.from_pretrained(text_encoder_repo)
149
 
@@ -152,6 +159,16 @@ class ArtFlowPipeline:
152
  vae_mean = vae_mean.to(device=device, dtype=dtype)
153
  vae_std = vae_std.to(device=device, dtype=dtype)
154
 
 
 
 
 
 
 
 
 
 
 
155
  # Create pipeline
156
  pipe = cls(
157
  transformer=transformer,
@@ -162,6 +179,8 @@ class ArtFlowPipeline:
162
  vae_std=vae_std,
163
  solver=config.get("solver", "euler"),
164
  dtype=dtype,
 
 
165
  )
166
 
167
  return pipe
@@ -220,11 +239,12 @@ class ArtFlowPipeline:
220
  prompt = [prompt]
221
  batch_size = len(prompt)
222
 
223
- # Encode text (with CFG support)
224
  do_cfg = guidance_scale > 1.0
225
 
 
 
226
  if do_cfg:
227
- # Encode prompts and negative prompts together
228
  if negative_prompt is None:
229
  negative_prompt = [""] * batch_size
230
  elif isinstance(negative_prompt, str):
@@ -233,7 +253,6 @@ class ArtFlowPipeline:
233
  all_prompts = prompt + negative_prompt
234
  text_emb, text_mask, _ = self._encode_text(all_prompts)
235
 
236
- # Split into conditional and unconditional
237
  text_emb_cond = text_emb[:batch_size]
238
  text_emb_uncond = text_emb[batch_size:]
239
  text_mask_cond = text_mask[:batch_size]
@@ -242,9 +261,14 @@ class ArtFlowPipeline:
242
  text_emb_cond, text_mask_cond = self._encode_text(prompt)[:2]
243
  text_emb_uncond = None
244
  text_mask_uncond = None
245
-
246
- # Generate latents
247
- device = self._get_device()
 
 
 
 
 
248
  generator = torch.Generator(device=device)
249
  if seed is not None:
250
  generator.manual_seed(seed)
@@ -257,7 +281,6 @@ class ArtFlowPipeline:
257
  )
258
  latents = torch.randn(latents_shape, generator=generator, device=device, dtype=self.dtype)
259
 
260
- # Prepare for CFG - repeat latents if doing CFG
261
  if do_cfg:
262
  latents = torch.cat([latents, latents], dim=0)
263
  text_emb = torch.cat([text_emb_cond, text_emb_uncond], dim=0)
@@ -266,32 +289,38 @@ class ArtFlowPipeline:
266
  text_emb = text_emb_cond
267
  text_mask = text_mask_cond
268
 
269
- # Denoise
270
  def model_fn(x, t):
271
- t_tensor = torch.tensor(t, device=x.device).expand(x.shape[0])
272
  return self.transformer(x, t_tensor, text_emb, txt_mask=text_mask)
273
 
274
- # Run solver
275
  from artflow.flow.solvers import sample_ode
276
 
277
- latents = sample_ode(
278
- model_fn,
279
- latents,
280
- steps=num_inference_steps,
281
- solver=solver,
282
- device=str(self._get_device()),
283
- )
 
284
 
285
- # Apply CFG if enabled
286
  if do_cfg:
287
  latents_cond, latents_uncond = latents.chunk(2)
288
  latents = latents_uncond + guidance_scale * (latents_cond - latents_uncond)
 
 
 
289
 
290
- # Decode
291
  if output_type == "latent":
292
  images = latents
293
  else:
 
 
294
  images = self._decode_latents(latents)
 
 
 
295
 
296
  if not return_dict:
297
  return images
@@ -309,16 +338,17 @@ class ArtFlowPipeline:
309
  prompts, self.text_encoder, self.tokenizer, pooling=pooling
310
  )
311
 
312
- device = self._get_device()
313
- txt_emb = txt_emb.to(device=device, dtype=self.dtype)
314
- txt_mask = txt_mask.to(device=device)
315
  if txt_pooled is not None:
316
- txt_pooled = txt_pooled.to(device=device, dtype=self.dtype)
317
 
318
  return txt_emb, txt_mask, txt_pooled
319
 
320
  def _decode_latents(self, latents: torch.Tensor) -> List[Image.Image]:
321
  """Decode VAE latents to PIL images."""
 
 
322
  # Denormalize
323
  if self.vae_mean is not None and self.vae_std is not None:
324
  latents = latents * self.vae_std + self.vae_mean
 
49
  vae_std: Optional[torch.Tensor] = None,
50
  solver: str = "euler",
51
  dtype: torch.dtype = torch.bfloat16,
52
+ device: Optional[str] = None,
53
+ offload: bool = True,
54
  ):
55
  self.transformer = transformer
56
  self.vae = vae
 
60
  self.vae_std = vae_std
61
  self.solver = solver
62
  self.dtype = dtype
63
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
64
+ self.offload = offload
65
 
66
  # Move to eval mode
67
  self.transformer.eval()
68
  self.vae.eval()
69
  self.text_encoder.eval()
70
 
71
+ def _get_autocast_context(self):
72
+ """Get autocast context manager for inference."""
73
+ device_type = "cuda" if "cuda" in self.device else "cpu"
74
+ if self.dtype in (torch.float16, torch.bfloat16):
75
+ return torch.autocast(device_type=device_type, dtype=self.dtype)
76
+ return torch.no_grad()
77
 
78
  @classmethod
79
  def from_pretrained(
 
98
 
99
  dtype = kwargs.get("dtype", torch.bfloat16)
100
  device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
101
+ offload = kwargs.get("offload", True)
102
 
103
  # Download all source files from the repo
104
  repo_files = list_repo_files(pretrained_model_name_or_path)
 
140
  state_dict = state_dict["module"]
141
 
142
  transformer.load_state_dict(state_dict)
 
143
 
144
  # Load VAE
145
  vae_repo = config.get("vae_repo", "REPA-E/e2e-qwenimage-vae")
146
+ vae = AutoencoderKLQwenImage.from_pretrained(vae_repo, torch_dtype=dtype)
147
 
148
  # Load text encoder
149
  text_encoder_repo = config.get("text_encoder_repo", "Qwen/Qwen3-0.6B")
150
  text_encoder = AutoModelForCausalLM.from_pretrained(
151
  text_encoder_repo,
152
  dtype=dtype,
153
+ low_cpu_mem_usage=True,
154
  )
155
  tokenizer = AutoTokenizer.from_pretrained(text_encoder_repo)
156
 
 
159
  vae_mean = vae_mean.to(device=device, dtype=dtype)
160
  vae_std = vae_std.to(device=device, dtype=dtype)
161
 
162
+ # Load models to appropriate device based on offload setting
163
+ if offload:
164
+ # Keep on CPU, offload to GPU when needed
165
+ transformer.to(dtype=dtype)
166
+ else:
167
+ # Load directly to GPU
168
+ transformer.to(device=device, dtype=dtype)
169
+ vae.to(device=device)
170
+ text_encoder.to(device=device)
171
+
172
  # Create pipeline
173
  pipe = cls(
174
  transformer=transformer,
 
179
  vae_std=vae_std,
180
  solver=config.get("solver", "euler"),
181
  dtype=dtype,
182
+ device=device,
183
+ offload=offload,
184
  )
185
 
186
  return pipe
 
239
  prompt = [prompt]
240
  batch_size = len(prompt)
241
 
242
+ # --- Stage 1: Text encoding (text_encoder on GPU) ---
243
  do_cfg = guidance_scale > 1.0
244
 
245
+ if self.offload:
246
+ self.text_encoder.to(self.device)
247
  if do_cfg:
 
248
  if negative_prompt is None:
249
  negative_prompt = [""] * batch_size
250
  elif isinstance(negative_prompt, str):
 
253
  all_prompts = prompt + negative_prompt
254
  text_emb, text_mask, _ = self._encode_text(all_prompts)
255
 
 
256
  text_emb_cond = text_emb[:batch_size]
257
  text_emb_uncond = text_emb[batch_size:]
258
  text_mask_cond = text_mask[:batch_size]
 
261
  text_emb_cond, text_mask_cond = self._encode_text(prompt)[:2]
262
  text_emb_uncond = None
263
  text_mask_uncond = None
264
+ if self.offload:
265
+ self.text_encoder.to("cpu")
266
+ torch.cuda.empty_cache()
267
+
268
+ # --- Stage 2: Denoising (transformer on GPU) ---
269
+ if self.offload:
270
+ self.transformer.to(self.device)
271
+ device = torch.device(self.device)
272
  generator = torch.Generator(device=device)
273
  if seed is not None:
274
  generator.manual_seed(seed)
 
281
  )
282
  latents = torch.randn(latents_shape, generator=generator, device=device, dtype=self.dtype)
283
 
 
284
  if do_cfg:
285
  latents = torch.cat([latents, latents], dim=0)
286
  text_emb = torch.cat([text_emb_cond, text_emb_uncond], dim=0)
 
289
  text_emb = text_emb_cond
290
  text_mask = text_mask_cond
291
 
 
292
  def model_fn(x, t):
293
+ t_tensor = torch.as_tensor(t, device=x.device).expand(x.shape[0])
294
  return self.transformer(x, t_tensor, text_emb, txt_mask=text_mask)
295
 
 
296
  from artflow.flow.solvers import sample_ode
297
 
298
+ with self._get_autocast_context():
299
+ latents = sample_ode(
300
+ model_fn,
301
+ latents,
302
+ steps=num_inference_steps,
303
+ solver=solver,
304
+ device=self.device,
305
+ )
306
 
 
307
  if do_cfg:
308
  latents_cond, latents_uncond = latents.chunk(2)
309
  latents = latents_uncond + guidance_scale * (latents_cond - latents_uncond)
310
+ if self.offload:
311
+ self.transformer.to("cpu")
312
+ torch.cuda.empty_cache()
313
 
314
+ # --- Stage 3: VAE decode (vae on GPU) ---
315
  if output_type == "latent":
316
  images = latents
317
  else:
318
+ if self.offload:
319
+ self.vae.to(self.device)
320
  images = self._decode_latents(latents)
321
+ if self.offload:
322
+ self.vae.to("cpu")
323
+ torch.cuda.empty_cache()
324
 
325
  if not return_dict:
326
  return images
 
338
  prompts, self.text_encoder, self.tokenizer, pooling=pooling
339
  )
340
 
341
+ txt_emb = txt_emb.to(device=self.device, dtype=self.dtype)
342
+ txt_mask = txt_mask.to(device=self.device)
 
343
  if txt_pooled is not None:
344
+ txt_pooled = txt_pooled.to(device=self.device, dtype=self.dtype)
345
 
346
  return txt_emb, txt_mask, txt_pooled
347
 
348
  def _decode_latents(self, latents: torch.Tensor) -> List[Image.Image]:
349
  """Decode VAE latents to PIL images."""
350
+ latents = latents.to(device=self.device, dtype=self.dtype)
351
+
352
  # Denormalize
353
  if self.vae_mean is not None and self.vae_std is not None:
354
  latents = latents * self.vae_std + self.vae_mean
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cefd3c90533fabe566ce0864105d7fb7c2434d989eb6e714fa5cdff079ed5dd3
3
  size 2715352432
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:619009befe1afec50e1936d18afd3bbb8b7d314265c7975110cd8b17aee49fad
3
  size 2715352432