manbeast3b committed on
Commit
f074760
·
verified ·
1 Parent(s): 5f3f08c

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +44 -493
src/pipeline.py CHANGED
@@ -112,469 +112,6 @@ def calcular_fusion(x: torch.Tensor, info_tome: Dict[str, Any]) -> Tuple[Callabl
112
  fusion_m, desfusion_m = (fusion, desfusion) if argumentos["m3"] else (hacer_nada, hacer_nada)
113
  return fusion_a, fusion_c, fusion_m, desfusion_a, desfusion_c, desfusion_m
114
 
115
- @maybe_allow_in_graph
116
- class FluxSingleTransformerBlock(nn.Module):
117
-
118
- def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
119
- super().__init__()
120
- self.mlp_hidden_dim = int(dim * mlp_ratio)
121
-
122
- self.norm = AdaLayerNormZeroSingle(dim)
123
- self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
124
- self.act_mlp = nn.GELU(approximate="tanh")
125
- self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
126
-
127
- processor = FluxAttnProcessor2_0()
128
- self.attn = Attention(
129
- query_dim=dim,
130
- cross_attention_dim=None,
131
- dim_head=attention_head_dim,
132
- heads=num_attention_heads,
133
- out_dim=dim,
134
- bias=True,
135
- processor=processor,
136
- qk_norm="rms_norm",
137
- eps=1e-6,
138
- pre_only=True,
139
- )
140
-
141
- def forward(
142
- self,
143
- hidden_states: torch.FloatTensor,
144
- temb: torch.FloatTensor,
145
- image_rotary_emb=None,
146
- joint_attention_kwargs=None,
147
- tinfo: Dict[str, Any] = None,
148
- ):
149
- if tinfo is not None:
150
- m_a, m_c, mom, u_a, u_c, u_m = calcular_fusion(hidden_states, tinfo)
151
- else:
152
- m_a, m_c, mom, u_a, u_c, u_m = (ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada)
153
-
154
- residual = hidden_states
155
- norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
156
- norm_hidden_states = m_a(norm_hidden_states)
157
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
158
- joint_attention_kwargs = joint_attention_kwargs or {}
159
- attn_output = self.attn(
160
- hidden_states=norm_hidden_states,
161
- image_rotary_emb=image_rotary_emb,
162
- **joint_attention_kwargs,
163
- )
164
-
165
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
166
- gate = gate.unsqueeze(1)
167
- hidden_states = gate * self.proj_out(hidden_states)
168
- hidden_states = u_a(residual + hidden_states)
169
-
170
- return hidden_states
171
-
172
-
173
- @maybe_allow_in_graph
174
- class FluxTransformerBlock(nn.Module):
175
-
176
- def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
177
- super().__init__()
178
-
179
- self.norm1 = AdaLayerNormZero(dim)
180
-
181
- self.norm1_context = AdaLayerNormZero(dim)
182
-
183
- if hasattr(F, "scaled_dot_product_attention"):
184
- processor = FluxAttnProcessor2_0()
185
- else:
186
- raise ValueError(
187
- "The current PyTorch version does not support the `scaled_dot_product_attention` function."
188
- )
189
- self.attn = Attention(
190
- query_dim=dim,
191
- cross_attention_dim=None,
192
- added_kv_proj_dim=dim,
193
- dim_head=attention_head_dim,
194
- heads=num_attention_heads,
195
- out_dim=dim,
196
- context_pre_only=False,
197
- bias=True,
198
- processor=processor,
199
- qk_norm=qk_norm,
200
- eps=eps,
201
- )
202
-
203
- self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
204
- self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
205
-
206
- self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
207
- self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
208
- self._chunk_size = None
209
- self._chunk_dim = 0
210
-
211
- def forward(
212
- self,
213
- hidden_states: torch.FloatTensor,
214
- encoder_hidden_states: torch.FloatTensor,
215
- temb: torch.FloatTensor,
216
- image_rotary_emb=None,
217
- joint_attention_kwargs=None,
218
- tinfo: Dict[str, Any] = None, # Add tinfo parameter
219
- ):
220
- if tinfo is not None:
221
- m_a, m_c, mom, u_a, u_c, u_m = calcular_fusion(hidden_states, tinfo)
222
- else:
223
- m_a, m_c, mom, u_a, u_c, u_m = (ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada)
224
-
225
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
226
-
227
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
228
- encoder_hidden_states, emb=temb
229
- )
230
- joint_attention_kwargs = joint_attention_kwargs or {}
231
- norm_hidden_states = m_a(norm_hidden_states)
232
- norm_encoder_hidden_states = m_c(norm_encoder_hidden_states)
233
-
234
- attn_output, context_attn_output = self.attn(
235
- hidden_states=norm_hidden_states,
236
- encoder_hidden_states=norm_encoder_hidden_states,
237
- image_rotary_emb=image_rotary_emb,
238
- **joint_attention_kwargs,
239
- )
240
-
241
- attn_output = gate_msa.unsqueeze(1) * attn_output
242
- hidden_states = u_a(attn_output) + hidden_states
243
-
244
- norm_hidden_states = self.norm2(hidden_states)
245
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
246
-
247
- norm_hidden_states = mom(norm_hidden_states)
248
-
249
- ff_output = self.ff(norm_hidden_states)
250
- ff_output = gate_mlp.unsqueeze(1) * ff_output
251
-
252
- hidden_states = u_m(ff_output) + hidden_states
253
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
254
- encoder_hidden_states = u_c(context_attn_output) + encoder_hidden_states
255
-
256
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
257
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
258
-
259
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
260
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
261
-
262
- return encoder_hidden_states, hidden_states
263
-
264
-
265
- class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
266
-
267
- _supports_gradient_checkpointing = True
268
- _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
269
-
270
- @register_to_config
271
- def __init__(
272
- self,
273
- patch_size: int = 1,
274
- in_channels: int = 64,
275
- out_channels: Optional[int] = None,
276
- num_layers: int = 19,
277
- num_single_layers: int = 38,
278
- attention_head_dim: int = 128,
279
- num_attention_heads: int = 24,
280
- joint_attention_dim: int = 4096,
281
- pooled_projection_dim: int = 768,
282
- guidance_embeds: bool = False,
283
- axes_dims_rope: Tuple[int] = (16, 56, 56),
284
- generator: Optional[torch.Generator] = None,
285
- ):
286
- super().__init__()
287
- self.out_channels = out_channels or in_channels
288
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
289
-
290
- self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
291
-
292
- text_time_guidance_cls = (
293
- CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
294
- )
295
- self.time_text_embed = text_time_guidance_cls(
296
- embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
297
- )
298
-
299
- self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
300
- self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
301
-
302
- self.transformer_blocks = nn.ModuleList(
303
- [
304
- FluxTransformerBlock(
305
- dim=self.inner_dim,
306
- num_attention_heads=self.config.num_attention_heads,
307
- attention_head_dim=self.config.attention_head_dim,
308
- )
309
- for i in range(self.config.num_layers)
310
- ]
311
- )
312
-
313
- self.single_transformer_blocks = nn.ModuleList(
314
- [
315
- FluxSingleTransformerBlock(
316
- dim=self.inner_dim,
317
- num_attention_heads=self.config.num_attention_heads,
318
- attention_head_dim=self.config.attention_head_dim,
319
- )
320
- for i in range(self.config.num_single_layers)
321
- ]
322
- )
323
-
324
- self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
325
- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
326
- ratio: float = 0.3
327
- down: int = 1
328
- sx: int = 2
329
- sy: int = 2
330
- rando: bool = False
331
- m1: bool = False
332
- m2: bool = True
333
- m3: bool = False
334
-
335
- self.tinfo = {
336
- "size": None,
337
- "args": {
338
- "ratio": ratio,
339
- "down": down,
340
- "sx": sx,
341
- "sy": sy,
342
- "rando": rando,
343
- "m1": m1,
344
- "m2": m2,
345
- "m3": m3,
346
- "generator": generator
347
- }
348
- }
349
-
350
- self.gradient_checkpointing = False
351
-
352
- @property
353
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
354
- r"""
355
- Returns:
356
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
357
- indexed by its weight name.
358
- """
359
- processors = {}
360
-
361
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
362
- if hasattr(module, "get_processor"):
363
- processors[f"{name}.processor"] = module.get_processor()
364
-
365
- for sub_name, child in module.named_children():
366
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
367
-
368
- return processors
369
-
370
- for name, module in self.named_children():
371
- fn_recursive_add_processors(name, module, processors)
372
-
373
- return processors
374
-
375
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
376
- count = len(self.attn_processors.keys())
377
-
378
- if isinstance(processor, dict) and len(processor) != count:
379
- raise ValueError(
380
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
381
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
382
- )
383
-
384
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
385
- if hasattr(module, "set_processor"):
386
- if not isinstance(processor, dict):
387
- module.set_processor(processor)
388
- else:
389
- module.set_processor(processor.pop(f"{name}.processor"))
390
-
391
- for sub_name, child in module.named_children():
392
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
393
-
394
- for name, module in self.named_children():
395
- fn_recursive_attn_processor(name, module, processor)
396
-
397
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
398
- def fuse_qkv_projections(self):
399
- self.original_attn_processors = None
400
-
401
- for _, attn_processor in self.attn_processors.items():
402
- if "Added" in str(attn_processor.__class__.__name__):
403
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
404
-
405
- self.original_attn_processors = self.attn_processors
406
-
407
- for module in self.modules():
408
- if isinstance(module, Attention):
409
- module.fuse_projections(fuse=True)
410
-
411
- self.set_attn_processor(FusedFluxAttnProcessor2_0())
412
-
413
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
414
- def unfuse_qkv_projections(self):
415
- if self.original_attn_processors is not None:
416
- self.set_attn_processor(self.original_attn_processors)
417
-
418
- def _set_gradient_checkpointing(self, module, value=False):
419
- if hasattr(module, "gradient_checkpointing"):
420
- module.gradient_checkpointing = value
421
-
422
- def forward(
423
- self,
424
- hidden_states: torch.Tensor,
425
- encoder_hidden_states: torch.Tensor = None,
426
- pooled_projections: torch.Tensor = None,
427
- timestep: torch.LongTensor = None,
428
- img_ids: torch.Tensor = None,
429
- txt_ids: torch.Tensor = None,
430
- guidance: torch.Tensor = None,
431
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
432
- controlnet_block_samples=None,
433
- controlnet_single_block_samples=None,
434
- return_dict: bool = True,
435
- controlnet_blocks_repeat: bool = False,
436
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
437
- if joint_attention_kwargs is not None:
438
- joint_attention_kwargs = joint_attention_kwargs.copy()
439
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
440
- else:
441
- lora_scale = 1.0
442
-
443
- if USE_PEFT_BACKEND:
444
- # weight the lora layers by setting `lora_scale` for each PEFT layer
445
- scale_lora_layers(self, lora_scale)
446
- else:
447
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
448
- logger.warning(
449
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
450
- )
451
-
452
- hidden_states = self.x_embedder(hidden_states)
453
- if len(hidden_states.shape) == 4:
454
- self.tinfo["size"] = (hidden_states.shape[2], hidden_states.shape[3])
455
-
456
- timestep = timestep.to(hidden_states.dtype) * 1000
457
- if guidance is not None:
458
- guidance = guidance.to(hidden_states.dtype) * 1000
459
- else:
460
- guidance = None
461
-
462
- temb = (
463
- self.time_text_embed(timestep, pooled_projections)
464
- if guidance is None
465
- else self.time_text_embed(timestep, guidance, pooled_projections)
466
- )
467
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
468
-
469
- if txt_ids.ndim == 3:
470
- logger.warning(
471
- "Passing `txt_ids` 3d torch.Tensor is deprecated."
472
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
473
- )
474
- txt_ids = txt_ids[0]
475
- if img_ids.ndim == 3:
476
- logger.warning(
477
- "Passing `img_ids` 3d torch.Tensor is deprecated."
478
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
479
- )
480
- img_ids = img_ids[0]
481
-
482
- ids = torch.cat((txt_ids, img_ids), dim=0)
483
- image_rotary_emb = self.pos_embed(ids)
484
-
485
- for index_block, block in enumerate(self.transformer_blocks):
486
- if torch.is_grad_enabled() and self.gradient_checkpointing:
487
-
488
- def create_custom_forward(module, return_dict=None):
489
- def custom_forward(*inputs):
490
- if return_dict is not None:
491
- return module(*inputs, return_dict=return_dict)
492
- else:
493
- return module(*inputs)
494
-
495
- return custom_forward
496
-
497
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
498
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
499
- create_custom_forward(block),
500
- hidden_states,
501
- encoder_hidden_states,
502
- temb,
503
- image_rotary_emb,
504
- **ckpt_kwargs,
505
- )
506
-
507
- else:
508
- encoder_hidden_states, hidden_states = block(
509
- hidden_states=hidden_states,
510
- encoder_hidden_states=encoder_hidden_states,
511
- temb=temb,
512
- image_rotary_emb=image_rotary_emb,
513
- joint_attention_kwargs=joint_attention_kwargs,
514
- )
515
-
516
- if controlnet_block_samples is not None:
517
- interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
518
- interval_control = int(np.ceil(interval_control))
519
- if controlnet_blocks_repeat:
520
- hidden_states = (
521
- hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
522
- )
523
- else:
524
- hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
525
-
526
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
527
-
528
- for index_block, block in enumerate(self.single_transformer_blocks):
529
- if torch.is_grad_enabled() and self.gradient_checkpointing:
530
-
531
- def create_custom_forward(module, return_dict=None):
532
- def custom_forward(*inputs):
533
- if return_dict is not None:
534
- return module(*inputs, return_dict=return_dict)
535
- else:
536
- return module(*inputs)
537
-
538
- return custom_forward
539
-
540
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
541
- hidden_states = torch.utils.checkpoint.checkpoint(
542
- create_custom_forward(block),
543
- hidden_states,
544
- temb,
545
- image_rotary_emb,
546
- **ckpt_kwargs,
547
- )
548
-
549
- else:
550
- hidden_states = block(
551
- hidden_states=hidden_states,
552
- temb=temb,
553
- image_rotary_emb=image_rotary_emb,
554
- joint_attention_kwargs=joint_attention_kwargs,
555
- )
556
-
557
- if controlnet_single_block_samples is not None:
558
- interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
559
- interval_control = int(np.ceil(interval_control))
560
- hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
561
- hidden_states[:, encoder_hidden_states.shape[1] :, ...]
562
- + controlnet_single_block_samples[index_block // interval_control]
563
- )
564
-
565
- hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
566
-
567
- hidden_states = self.norm_out(hidden_states, temb)
568
- output = self.proj_out(hidden_states)
569
-
570
- if USE_PEFT_BACKEND:
571
- unscale_lora_layers(self, lora_scale)
572
-
573
- if not return_dict:
574
- return (output,)
575
-
576
- return Transformer2DModelOutput(sample=output)
577
-
578
  from diffusers import FluxPipeline, FluxTransformer2DModel
579
  Pipeline = None
580
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -594,40 +131,54 @@ def empty_cache():
594
  torch.cuda.reset_max_memory_allocated()
595
  torch.cuda.reset_peak_memory_stats()
596
 
597
- def load_pipeline() -> Pipeline:
598
- empty_cache()
599
- test = "Pneumonoultramicroscopicsilicovolcanoconiosis, Floccinaucinihilipilification, Pseudopseudohypoparathyroidism, Antidisestablishmentarianism, Supercalifragilisticexpialidocious, Honorificabilitudinitatibus"
600
 
601
- dtype, device = torch.bfloat16, "cuda"
602
- text_encoder_2 = T5EncoderModel.from_pretrained(
603
- "city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16
604
- ).to(memory_format=torch.channels_last)
605
- path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer")
606
- model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False, local_files_only=True)
607
- vae = AutoencoderTiny.from_pretrained(
608
- TinyVAE,
609
- revision=TinyVAE_REV,
610
- local_files_only=True,
611
- torch_dtype=torch.bfloat16)
612
- pipeline = FluxPipeline.from_pretrained(
613
- ckpt_id,
614
- revision=ckpt_revision,
615
- transformer=model,
616
- vae=vae,
617
- # text_encoder_2=text_encoder_2,
618
- torch_dtype=dtype,
619
- )
620
- pipeline.transformer.to(memory_format=torch.channels_last)
621
- pipeline.vae.to(memory_format=torch.channels_last)
622
- quantize_(pipeline.vae, int8_weight_only())
623
- pipeline.vae = torch.compile(pipeline.vae, mode="reduce-overhead", fullgraph=True)
624
- pipeline.to(device)
625
- # pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
626
- for _ in range(2):
627
- pipeline(prompt=test, width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
628
 
 
 
 
 
 
 
 
 
 
 
 
 
629
  return pipeline
630
 
 
 
631
  sample = None
632
  @torch.no_grad()
633
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
 
112
  fusion_m, desfusion_m = (fusion, desfusion) if argumentos["m3"] else (hacer_nada, hacer_nada)
113
  return fusion_a, fusion_c, fusion_m, desfusion_a, desfusion_c, desfusion_m
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  from diffusers import FluxPipeline, FluxTransformer2DModel
116
  Pipeline = None
117
  torch.backends.cuda.matmul.allow_tf32 = True
 
131
  torch.cuda.reset_max_memory_allocated()
132
  torch.cuda.reset_peak_memory_stats()
133
 
134
+ # def load_pipeline() -> Pipeline:
135
+ # empty_cache()
136
+ # test = "Pneumonoultramicroscopicsilicovolcanoconiosis, Floccinaucinihilipilification, Pseudopseudohypoparathyroidism, Antidisestablishmentarianism, Supercalifragilisticexpialidocious, Honorificabilitudinitatibus"
137
 
138
+ # dtype, device = torch.bfloat16, "cuda"
139
+ # text_encoder_2 = T5EncoderModel.from_pretrained(
140
+ # "city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16
141
+ # ).to(memory_format=torch.channels_last)
142
+ # path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer")
143
+ # model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False, local_files_only=True)
144
+ # vae = AutoencoderTiny.from_pretrained(
145
+ # TinyVAE,
146
+ # revision=TinyVAE_REV,
147
+ # local_files_only=True,
148
+ # torch_dtype=torch.bfloat16)
149
+ # pipeline = FluxPipeline.from_pretrained(
150
+ # ckpt_id,
151
+ # revision=ckpt_revision,
152
+ # transformer=model,
153
+ # vae=vae,
154
+ # # text_encoder_2=text_encoder_2,
155
+ # torch_dtype=dtype,
156
+ # )
157
+ # pipeline.transformer.to(memory_format=torch.channels_last)
158
+ # pipeline.vae.to(memory_format=torch.channels_last)
159
+ # quantize_(pipeline.vae, int8_weight_only())
160
+ # pipeline.vae = torch.compile(pipeline.vae, mode="reduce-overhead", fullgraph=True)
161
+ # pipeline.to(device)
162
+ # # pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
163
+ # for _ in range(2):
164
+ # pipeline(prompt=test, width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
165
 
166
+ # return pipeline
167
+
168
+
169
def load_pipeline() -> Pipeline:
    """Build, quantize, compile, and warm up the Flux Schnell pipeline.

    Loads the custom transformer weights from a pinned local HF-cache
    snapshot, assembles a ``FluxPipeline`` around them, int8-quantizes the
    VAE, compiles the transformer, and runs a few throwaway generations so
    ``torch.compile`` captures its graphs before the first real request.

    Returns:
        The fully warmed-up ``FluxPipeline`` instance, resident on CUDA.
    """
    # Start from a clean CUDA allocator state before the large model loads
    # (the previous revision of this function did this as well).
    empty_cache()

    # Pinned snapshot inside the local HF hub cache — never hit the network.
    path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer",
    )
    # local_files_only=True for consistency with the pipeline load below;
    # the path points at the local cache, so a network fallback is never wanted.
    transformer = FluxTransformer2DModel.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
        local_files_only=True,
    )
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # int8 weight-only quantization keeps the VAE small; the transformer is
    # compiled whole-graph for steady-state speed.
    quantize_(pipeline.vae, int8_weight_only())
    pipeline.transformer = torch.compile(
        pipeline.transformer, mode="max-autotune", fullgraph=True
    )

    # Warm-up passes: trigger compilation/autotuning at the production
    # resolution and step count so the first user request is fast.
    warmup_prompt = "insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus"
    for _ in range(3):
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
179
 
180
+
181
+
182
  sample = None
183
  @torch.no_grad()
184
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image: