sharper740 committed on
Commit 570e350 · verified · 1 Parent(s): b0fa850

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+RobertML.png filter=lfs diff=lfs merge=lfs -text
+backup.png filter=lfs diff=lfs merge=lfs -text
pyproject.toml ADDED
@@ -0,0 +1,43 @@
+[build-system]
+requires = ["setuptools >= 75.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flux-schnell-edge-inference"
+description = "An edge-maxxing model submission by RobertML for the 4090 Flux contest"
+requires-python = ">=3.10,<3.13"
+version = "8"
+dependencies = [
+    "diffusers==0.31.0",
+    "transformers==4.46.2",
+    "accelerate==1.1.0",
+    "omegaconf==2.3.0",
+    "torch==2.5.1",
+    "protobuf==5.28.3",
+    "sentencepiece==0.2.0",
+    "edge-maxxing-pipelines @ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines",
+    "gitpython>=3.1.43",
+    "hf_transfer==0.1.8",
+    "torchao==0.6.1",
+]
+
+[[tool.edge-maxxing.models]]
+repository = "madebyollin/taef1"
+revision = "2d552378e58c9c94201075708d7de4e1163b2689"
+
+[[tool.edge-maxxing.models]]
+repository = "director432/Flux1-Schnell"
+revision = "f8da0eb2b421c7677c70312bd9dec91a71f411d3"
+
+[[tool.edge-maxxing.models]]
+repository = "director432/Flux1-Transformer2D"
+revision = "803893c49df2bb29d3f0f89ef5467781bef64b25"
+
+[[tool.edge-maxxing.models]]
+repository = "director432/Flux1-T5Encoder"
+revision = "93fa999c3acb891488a05eebfb6a98e31d574d05"
+
+
+[project.scripts]
+start_inference = "main:main"
+
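Note: the [[tool.edge-maxxing.models]] entries pin exact Hub revisions. A minimal sketch (an assumption about tooling, not part of this submission) of prefetching those pins with huggingface_hub so load_pipeline() later finds them in HF_HUB_CACHE:

# Hypothetical prefetch sketch; the contest harness may handle this differently.
from huggingface_hub import snapshot_download

PINNED_MODELS = [
    ("madebyollin/taef1", "2d552378e58c9c94201075708d7de4e1163b2689"),
    ("director432/Flux1-Schnell", "f8da0eb2b421c7677c70312bd9dec91a71f411d3"),
    ("director432/Flux1-Transformer2D", "803893c49df2bb29d3f0f89ef5467781bef64b25"),
    ("director432/Flux1-T5Encoder", "93fa999c3acb891488a05eebfb6a98e31d574d05"),
]

for repo_id, revision in PINNED_MODELS:
    # Download each pinned snapshot into the local Hugging Face cache.
    snapshot_download(repo_id=repo_id, revision=revision)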
src/__pycache__/main.cpython-310.pyc ADDED
Binary file (2.19 kB).
src/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (2.8 kB).
src/flux_schnell_edge_inference.egg-info/PKG-INFO ADDED
@@ -0,0 +1,15 @@
+Metadata-Version: 2.1
+Name: flux-schnell-edge-inference
+Version: 7
+Summary: An edge-maxxing model submission for the 4090 Flux contest
+Requires-Python: <3.13,>=3.10
+Requires-Dist: diffusers==0.31.0
+Requires-Dist: transformers==4.46.2
+Requires-Dist: accelerate==1.1.0
+Requires-Dist: omegaconf==2.3.0
+Requires-Dist: torch==2.5.1
+Requires-Dist: protobuf==5.28.3
+Requires-Dist: sentencepiece==0.2.0
+Requires-Dist: edge-maxxing-pipelines@ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines
+Requires-Dist: gitpython>=3.1.43
+Requires-Dist: torchao>=0.6.1
src/flux_schnell_edge_inference.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,10 @@
+README.md
+pyproject.toml
+src/main.py
+src/pipeline.py
+src/flux_schnell_edge_inference.egg-info/PKG-INFO
+src/flux_schnell_edge_inference.egg-info/SOURCES.txt
+src/flux_schnell_edge_inference.egg-info/dependency_links.txt
+src/flux_schnell_edge_inference.egg-info/entry_points.txt
+src/flux_schnell_edge_inference.egg-info/requires.txt
+src/flux_schnell_edge_inference.egg-info/top_level.txt
src/flux_schnell_edge_inference.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
src/flux_schnell_edge_inference.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+[console_scripts]
+start_inference = main:main
src/flux_schnell_edge_inference.egg-info/requires.txt ADDED
@@ -0,0 +1,10 @@
+diffusers==0.31.0
+transformers==4.46.2
+accelerate==1.1.0
+omegaconf==2.3.0
+torch==2.5.1
+protobuf==5.28.3
+sentencepiece==0.2.0
+edge-maxxing-pipelines@ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines
+gitpython>=3.1.43
+torchao>=0.6.1
src/flux_schnell_edge_inference.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
+main
+pipeline
src/ghanta.py ADDED
@@ -0,0 +1,74 @@
+import torch
+from typing import Tuple, Callable
+def hacer_nada(x: torch.Tensor, modo: str = None):
+    return x
+def brujeria_mps(entrada, dim, index):
+    if entrada.shape[-1] == 1:
+        return torch.gather(entrada.unsqueeze(-1), dim - 1 if dim < 0 else dim, index.unsqueeze(-1)).squeeze(-1)
+    else:
+        return torch.gather(entrada, dim, index)
+def emparejamiento_suave_aleatorio_2d(
+    metrica: torch.Tensor,
+    ancho: int,
+    alto: int,
+    paso_x: int,
+    paso_y: int,
+    radio: int,
+    sin_aleatoriedad: bool = False,
+    generador: torch.Generator = None
+) -> Tuple[Callable, Callable]:
+    lote, num_nodos, _ = metrica.shape
+    if radio <= 0:
+        return hacer_nada, hacer_nada
+    recopilar = brujeria_mps if metrica.device.type == "mps" else torch.gather
+    with torch.no_grad():
+        alto_paso_y, ancho_paso_x = alto // paso_y, ancho // paso_x
+        if sin_aleatoriedad:
+            indice_aleatorio = torch.zeros(alto_paso_y, ancho_paso_x, 1, device=metrica.device, dtype=torch.int64)
+        else:
+            indice_aleatorio = torch.randint(paso_y * paso_x, size=(alto_paso_y, ancho_paso_x, 1), device=generador.device, generator=generador).to(metrica.device)
+        vista_buffer_indice = torch.zeros(alto_paso_y, ancho_paso_x, paso_y * paso_x, device=metrica.device, dtype=torch.int64)
+        vista_buffer_indice.scatter_(dim=2, index=indice_aleatorio, src=-torch.ones_like(indice_aleatorio, dtype=indice_aleatorio.dtype))
+        vista_buffer_indice = vista_buffer_indice.view(alto_paso_y, ancho_paso_x, paso_y, paso_x).transpose(1, 2).reshape(alto_paso_y * paso_y, ancho_paso_x * paso_x)
+        if (alto_paso_y * paso_y) < alto or (ancho_paso_x * paso_x) < ancho:
+            buffer_indice = torch.zeros(alto, ancho, device=metrica.device, dtype=torch.int64)
+            buffer_indice[:(alto_paso_y * paso_y), :(ancho_paso_x * paso_x)] = vista_buffer_indice
+        else:
+            buffer_indice = vista_buffer_indice
+        indice_aleatorio = buffer_indice.reshape(1, -1, 1).argsort(dim=1)
+        del buffer_indice, vista_buffer_indice
+        num_destino = alto_paso_y * ancho_paso_x
+        indices_a = indice_aleatorio[:, num_destino:, :]
+        indices_b = indice_aleatorio[:, :num_destino, :]
+        def dividir(x):
+            canales = x.shape[-1]
+            origen = recopilar(x, dim=1, index=indices_a.expand(lote, num_nodos - num_destino, canales))
+            destino = recopilar(x, dim=1, index=indices_b.expand(lote, num_destino, canales))
+            return origen, destino
+        metrica = metrica / metrica.norm(dim=-1, keepdim=True)
+        a, b = dividir(metrica)
+        puntuaciones = a @ b.transpose(-1, -2)
+        radio = min(a.shape[1], radio)
+        nodo_max, nodo_indice = puntuaciones.max(dim=-1)
+        indice_borde = nodo_max.argsort(dim=-1, descending=True)[..., None]
+        indice_no_emparejado = indice_borde[..., radio:, :]
+        indice_origen = indice_borde[..., :radio, :]
+        indice_destino = recopilar(nodo_indice[..., None], dim=-2, index=indice_origen)
+    def fusionar(x: torch.Tensor, modo="mean") -> torch.Tensor:
+        origen, destino = dividir(x)
+        n, t1, c = origen.shape
+        no_emparejado = recopilar(origen, dim=-2, index=indice_no_emparejado.expand(n, t1 - radio, c))
+        origen = recopilar(origen, dim=-2, index=indice_origen.expand(n, radio, c))
+        destino = destino.scatter_reduce(-2, indice_destino.expand(n, radio, c), origen, reduce=modo)
+        return torch.cat([no_emparejado, destino], dim=1)
+    def desfusionar(x: torch.Tensor) -> torch.Tensor:
+        longitud_no_emparejado = indice_no_emparejado.shape[1]
+        no_emparejado, destino = x[..., :longitud_no_emparejado, :], x[..., longitud_no_emparejado:, :]
+        _, _, c = no_emparejado.shape
+        origen = recopilar(destino, dim=-2, index=indice_destino.expand(lote, radio, c))
+        salida = torch.zeros(lote, num_nodos, c, device=x.device, dtype=x.dtype)
+        salida.scatter_(dim=-2, index=indices_b.expand(lote, num_destino, c), src=destino)
+        salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_no_emparejado).expand(lote, longitud_no_emparejado, c), src=no_emparejado)
+        salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_origen).expand(lote, radio, c), src=origen)
+        return salida
+    return fusionar, desfusionar
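Note: src/ghanta.py is a ToMe-style random 2D soft-matching routine; fusionar merges the most similar tokens and desfusionar scatters them back to the original layout. A small self-contained sketch of the API (tensor sizes are illustrative only, not taken from the pipeline):

# Illustrative usage of the merge/unmerge pair on a dummy token grid.
import torch
from ghanta import emparejamiento_suave_aleatorio_2d

b, h, w, c = 1, 64, 64, 128        # made-up sizes for this example
x = torch.randn(b, h * w, c)       # token sequence laid out as an h x w grid
fusionar, desfusionar = emparejamiento_suave_aleatorio_2d(
    x, ancho=w, alto=h, paso_x=2, paso_y=2,
    radio=(h * w) // 2, sin_aleatoriedad=True,
)
merged = fusionar(x)               # fewer tokens: (1, 2048, 128) here
restored = desfusionar(merged)     # back to (1, 4096, 128)
print(merged.shape, restored.shape)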
src/main.py ADDED
@@ -0,0 +1,55 @@
+import atexit
+from io import BytesIO
+from multiprocessing.connection import Listener
+from os import chmod, remove
+from os.path import abspath, exists
+from pathlib import Path
+from git import Repo
+import torch
+
+from PIL.JpegImagePlugin import JpegImageFile
+from pipelines.models import TextToImageRequest
+from pipeline import load_pipeline, infer
+SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
+
+
+def at_exit():
+    torch.cuda.empty_cache()
+
+
+def main():
+    atexit.register(at_exit)
+
+    print("Loading pipeline")
+    pipeline = load_pipeline()
+
+    print(f"Pipeline loaded, creating socket at '{SOCKET}'")
+
+    if exists(SOCKET):
+        remove(SOCKET)
+
+    with Listener(SOCKET) as listener:
+        chmod(SOCKET, 0o777)
+
+        print("Awaiting connections")
+        with listener.accept() as connection:
+            print("Connected")
+            generator = torch.Generator("cuda")
+            while True:
+                try:
+                    request = TextToImageRequest.model_validate_json(connection.recv_bytes().decode("utf-8"))
+                except EOFError:
+                    print("Inference socket exiting")
+
+                    return
+                image = infer(request, pipeline, generator.manual_seed(request.seed))
+                data = BytesIO()
+                image.save(data, format=JpegImageFile.format)
+
+                packet = data.getvalue()
+
+                connection.send_bytes(packet)
+
+
+if __name__ == '__main__':
+    main()
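Note: main() speaks a very small protocol over the Unix socket: it receives a JSON-serialized TextToImageRequest and replies with raw JPEG bytes. A hypothetical client sketch (socket path and request fields are inferred from the code above; TextToImageRequest may carry additional fields):

# Hypothetical client; adjust the path to wherever inferences.sock is created.
import json
from multiprocessing.connection import Client

with Client("inferences.sock") as connection:
    request = {"prompt": "a lighthouse at dusk", "seed": 0, "width": 1024, "height": 1024}
    connection.send_bytes(json.dumps(request).encode("utf-8"))
    with open("out.jpg", "wb") as f:
        f.write(connection.recv_bytes())  # raw JPEG bytes from the server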
src/pipeline.py ADDED
@@ -0,0 +1,511 @@
1
+ from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
2
+ from diffusers.image_processor import VaeImageProcessor
3
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
4
+ from huggingface_hub.constants import HF_HUB_CACHE
5
+ from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
6
+ import torch
7
+ import torch._dynamo
8
+ import gc
9
+ from PIL import Image as img
10
+ from PIL.Image import Image
11
+ from pipelines.models import TextToImageRequest
12
+ from torch import Generator
13
+ import time
14
+ import os
15
+
16
+ import torch
17
+ import math
18
+ from typing import Type, Dict, Any, Tuple, Callable, Optional, Union, List
19
+ import ghanta
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
27
+ from diffusers.models.attention import FeedForward
28
+ from diffusers.models.attention_processor import (
29
+ Attention,
30
+ AttentionProcessor,
31
+ FluxAttnProcessor2_0,
32
+ FusedFluxAttnProcessor2_0,
33
+ )
34
+ from diffusers.models.modeling_utils import ModelMixin
35
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
36
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers
37
+ from diffusers.utils.import_utils import is_torch_npu_available
38
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
39
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
40
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
41
+
42
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
43
+ from diffusers.models.autoencoders import AutoencoderKL
44
+ # from diffusers.models.transformers import FluxTransformer2DModel
45
+ from diffusers.utils import (
46
+ USE_PEFT_BACKEND,
47
+ is_torch_xla_available,
48
+ replace_example_docstring,
49
+ scale_lora_layers,
50
+ unscale_lora_layers,
51
+ )
52
+ from diffusers.utils.torch_utils import randn_tensor
53
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
54
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
55
+
56
+ torch._dynamo.config.suppress_errors = True
57
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
58
+ dtype = torch.bfloat16
59
+ device = "cuda"
60
+ if is_torch_xla_available():
61
+ import torch_xla.core.xla_model as xm
62
+ XLA_AVAILABLE = True
63
+ else:
64
+ XLA_AVAILABLE = False
65
+
66
+ def calc_shift(img_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
67
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
68
+ b = base_shift - m * base_seq_len
69
+ return img_seq_len * m + b
70
+
71
+ def get_timesteps(scheduler, num_steps: Optional[int] = None, dev: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs):
72
+ if timesteps is not None:
73
+ scheduler.set_timesteps(timesteps=timesteps, device=dev, **kwargs)
74
+ num_steps = len(scheduler.timesteps)
75
+ elif sigmas is not None:
76
+ scheduler.set_timesteps(sigmas=sigmas, device=dev, **kwargs)
77
+ num_steps = len(scheduler.timesteps)
78
+ else:
79
+ scheduler.set_timesteps(num_steps, device=dev, **kwargs)
80
+ return scheduler.timesteps, num_steps
81
+
82
+ def init_generator(dev: torch.device, backup: Optional[torch.Generator] = None):
83
+ if dev.type == "cpu":
84
+ return torch.Generator(device="cpu").set_state(torch.get_rng_state())
85
+ elif dev.type == "cuda":
86
+ return torch.Generator(device=dev).set_state(torch.cuda.get_rng_state())
87
+ else:
88
+ return init_generator(torch.device("cpu")) if backup is None else backup
89
+
90
+ def compute_fusion(x: torch.Tensor, info: Dict[str, Any]) -> Tuple[Callable, ...]:
91
+ h_orig, w_orig = info["size"]
92
+ tokens = h_orig * w_orig
93
+ sub_sample = int(math.ceil(math.sqrt(tokens // x.shape[1])))
94
+ args = info["args"]
95
+ if sub_sample <= args["down"]:
96
+ new_w = int(math.ceil(w_orig / sub_sample))
97
+ new_h = int(math.ceil(h_orig / sub_sample))
98
+ ratio = int(x.shape[1] * args["ratio"])
99
+ if args["generator"] is None:
100
+ args["generator"] = init_generator(x.device)
101
+ elif args["generator"].device != x.device:
102
+ args["generator"] = init_generator(x.device, backup=args["generator"])
103
+ use_rand = args["rando"]
104
+ fusion, defusion = ghanta.emparejamiento_suave_aleatorio_2d(x, new_w, new_h, args["sx"], args["sy"], ratio, sin_aleatoriedad=not use_rand, generador=args["generator"])
105
+ else:
106
+ fusion, defusion = (ghanta.hacer_nada, ghanta.hacer_nada)
107
+ fusion_a, defusion_a = (fusion, defusion) if args["m1"] else (ghanta.hacer_nada, ghanta.hacer_nada)
108
+ fusion_c, defusion_c = (fusion, defusion) if args["m2"] else (ghanta.hacer_nada, ghanta.hacer_nada)
109
+ fusion_m, defusion_m = (fusion, defusion) if args["m3"] else (ghanta.hacer_nada, ghanta.hacer_nada)
110
+ return fusion_a, fusion_c, fusion_m, defusion_a, defusion_c, defusion_m
111
+
112
+ @maybe_allow_in_graph
113
+ class SingleTransformerBlock(nn.Module):
114
+ def __init__(self, dim, num_heads, head_dim, mlp_ratio=4.0):
115
+ super().__init__()
116
+ self.mlp_hidden = int(dim * mlp_ratio)
117
+ self.norm = AdaLayerNormZeroSingle(dim)
118
+ self.mlp_proj = nn.Linear(dim, self.mlp_hidden)
119
+ self.mlp_act = nn.GELU(approximate="tanh")
120
+ self.out_proj = nn.Linear(dim + self.mlp_hidden, dim)
121
+ proc = FluxAttnProcessor2_0()
122
+ self.attn = Attention(query_dim=dim, cross_attention_dim=None, dim_head=head_dim, heads=num_heads, out_dim=dim, bias=True, processor=proc, qk_norm="rms_norm", eps=1e-6, pre_only=True)
123
+ def forward(self, hidden: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, joint_attention_kwargs=None, tinfo: Dict[str, Any] = None):
124
+ if tinfo is not None:
125
+ m_a, m_c, mom, u_a, u_c, u_m = compute_fusion(hidden, tinfo)
126
+ else:
127
+ m_a, m_c, mom, u_a, u_c, u_m = (ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada)
128
+ residual = hidden
129
+ norm_hidden, gate = self.norm(hidden, emb=temb)
130
+ norm_hidden = m_a(norm_hidden)
131
+ mlp_hidden = self.mlp_act(self.mlp_proj(norm_hidden))
132
+ joint_attention_kwargs = joint_attention_kwargs or {}
133
+ attn_out = self.attn(hidden_states=norm_hidden, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs)
134
+ hidden = torch.cat([attn_out, mlp_hidden], dim=2)
135
+ gate = gate.unsqueeze(1)
136
+ hidden = gate * self.out_proj(hidden)
137
+ hidden = u_a(residual + hidden)
138
+ return hidden
139
+
140
+ @maybe_allow_in_graph
141
+ class TransformerBlock(nn.Module):
142
+ def __init__(self, dim, num_heads, head_dim, qk_norm="rms_norm", eps=1e-6):
143
+ super().__init__()
144
+ self.norm1 = AdaLayerNormZero(dim)
145
+ self.norm1_context = AdaLayerNormZero(dim)
146
+ if hasattr(F, "scaled_dot_product_attention"):
147
+ proc = FluxAttnProcessor2_0()
148
+ self.attn = Attention(query_dim=dim, cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=head_dim, heads=num_heads, out_dim=dim, context_pre_only=False, bias=True, processor=proc, qk_norm=qk_norm, eps=eps)
149
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
150
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
151
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
152
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
153
+ self._chunk_size = None
154
+ self._chunk_dim = 0
155
+ def forward(self, hidden: torch.FloatTensor, encoder_hidden: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, joint_attention_kwargs=None, tinfo: Dict[str, Any] = None):
156
+ if tinfo is not None:
157
+ m_a, m_c, mom, u_a, u_c, u_m = compute_fusion(hidden, tinfo)
158
+ else:
159
+ m_a, m_c, mom, u_a, u_c, u_m = (ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada)
160
+ norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden, emb=temb)
161
+ norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden, emb=temb)
162
+ joint_attention_kwargs = joint_attention_kwargs or {}
163
+ norm_hidden = m_a(norm_hidden)
164
+ norm_encoder = m_c(norm_encoder)
165
+ attn_out, ctx_attn_out = self.attn(hidden_states=norm_hidden, encoder_hidden_states=norm_encoder, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs)
166
+ attn_out = gate_msa.unsqueeze(1) * attn_out
167
+ hidden = u_a(attn_out) + hidden
168
+ norm_hidden = self.norm2(hidden)
169
+ norm_hidden = norm_hidden * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
170
+ norm_hidden = mom(norm_hidden)
171
+ ff_out = self.ff(norm_hidden)
172
+ ff_out = gate_mlp.unsqueeze(1) * ff_out
173
+ hidden = u_m(ff_out) + hidden
174
+ ctx_attn_out = c_gate_msa.unsqueeze(1) * ctx_attn_out
175
+ encoder_hidden = u_c(ctx_attn_out) + encoder_hidden
176
+ norm_encoder = self.norm2_context(encoder_hidden)
177
+ norm_encoder = norm_encoder * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
178
+ ctx_ff_out = self.ff_context(norm_encoder)
179
+ encoder_hidden = encoder_hidden + c_gate_mlp.unsqueeze(1) * ctx_ff_out
180
+ return encoder_hidden, hidden
181
+
182
+ class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
183
+ _supports_gradient_checkpointing = True
184
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
185
+ @register_to_config
186
+ def __init__(self, patch_size: int = 1, in_channels: int = 64, out_channels: Optional[int] = None, guidance_embeds: bool = False, axes_dims_rope: Tuple[int] = (16, 56, 56), generator: Optional[torch.Generator] = None):
187
+ super().__init__()
188
+ self.out_channels = out_channels or in_channels
189
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
190
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
191
+ text_time_cls = CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
192
+ self.time_text_embed = text_time_cls(embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim)
193
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
194
+ self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
195
+ self.transformer_blocks = nn.ModuleList([TransformerBlock(dim=self.inner_dim, num_heads=self.config.num_attention_heads, head_dim=self.config.attention_head_dim) for _ in range(self.config.num_layers)])
196
+ self.single_transformer_blocks = nn.ModuleList([SingleTransformerBlock(dim=self.inner_dim, num_heads=self.config.num_attention_heads, head_dim=self.config.attention_head_dim) for _ in range(self.config.num_single_layers)])
197
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
198
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
199
+ ratio: float = 0.5
200
+ down: int = 1
201
+ sx: int = 2
202
+ sy: int = 2
203
+ rando: bool = False
204
+ m1: bool = True
205
+ m2: bool = True
206
+ m3: bool = False
207
+ self.tinfo = {"size": None, "args": {"ratio": ratio, "down": down, "sx": sx, "sy": sy, "rando": rando, "m1": m1, "m2": m2, "m3": m3, "generator": generator}}
208
+ self.gradient_checkpointing = False
209
+ @property
210
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
211
+ procs = {}
212
+ def add_processors(name: str, module: torch.nn.Module, procs: Dict[str, AttentionProcessor]):
213
+ if hasattr(module, "get_processor"):
214
+ procs[f"{name}.processor"] = module.get_processor()
215
+ for sub_name, child in module.named_children():
216
+ add_processors(f"{name}.{sub_name}", child, procs)
217
+ return procs
218
+ for name, module in self.named_children():
219
+ add_processors(name, module, procs)
220
+ return procs
221
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
222
+ cnt = len(self.attn_processors.keys())
223
+ def set_proc(name: str, module: torch.nn.Module, processor):
224
+ if hasattr(module, "set_processor"):
225
+ if not isinstance(processor, dict):
226
+ module.set_processor(processor)
227
+ else:
228
+ module.set_processor(processor.pop(f"{name}.processor"))
229
+ for sub_name, child in module.named_children():
230
+ set_proc(f"{name}.{sub_name}", child, processor)
231
+ for name, module in self.named_children():
232
+ set_proc(name, module, processor)
233
+ def fuse_qkv_projections(self):
234
+ self.original_attn_processors = None
235
+ self.original_attn_processors = self.attn_processors
236
+ for module in self.modules():
237
+ if isinstance(module, Attention):
238
+ module.fuse_projections(fuse=True)
239
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
240
+ def unfuse_qkv_projections(self):
241
+ if self.original_attn_processors is not None:
242
+ self.set_attn_processor(self.original_attn_processors)
243
+ def _set_gradient_checkpointing(self, module, value=False):
244
+ if hasattr(module, "gradient_checkpointing"):
245
+ module.gradient_checkpointing = value
246
+ def forward(self, hidden: torch.Tensor, encoder_hidden: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, controlnet_single_block_samples=None, return_dict: bool = True, controlnet_blocks_repeat: bool = False) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
247
+ if len(hidden.shape) == 4:
248
+ self.tinfo["size"] = (hidden.shape[2], hidden.shape[3])
249
+ if len(hidden.shape) == 3:
250
+ self.tinfo["size"] = (hidden.shape[1], hidden.shape[2])
251
+ if joint_attention_kwargs is not None:
252
+ joint_attention_kwargs = joint_attention_kwargs.copy()
253
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
254
+ else:
255
+ lora_scale = 1.0
256
+ if USE_PEFT_BACKEND:
257
+ scale_lora_layers(self, lora_scale)
258
+ hidden = self.x_embedder(hidden)
259
+ timestep = timestep.to(hidden.dtype) * 1000
260
+ if guidance is not None:
261
+ guidance = guidance.to(hidden.dtype) * 1000
262
+ else:
263
+ guidance = None
264
+ temb = self.time_text_embed(timestep, pooled_projections) if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections)
265
+ encoder_hidden = self.context_embedder(encoder_hidden)
266
+ if img_ids.ndim == 3:
267
+ img_ids = img_ids[0]
268
+ ids = torch.cat((txt_ids, img_ids), dim=0)
269
+ image_rotary_emb = self.pos_embed(ids)
270
+ for index, block in enumerate(self.transformer_blocks):
271
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
272
+ def custom_forward(module, return_dict=None):
273
+ def cf(*inputs):
274
+ return module(*inputs, return_dict=return_dict) if return_dict is not None else module(*inputs)
275
+ return cf
276
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
277
+ encoder_hidden, hidden = torch.utils.checkpoint.checkpoint(custom_forward(block), hidden, encoder_hidden, temb, image_rotary_emb, **ckpt_kwargs)
278
+ else:
279
+ encoder_hidden, hidden = block(hidden_states=hidden, encoder_hidden_states=encoder_hidden, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, tinfo=self.tinfo)
280
+ if controlnet_block_samples is not None:
281
+ interval = len(self.transformer_blocks) / len(controlnet_block_samples)
282
+ interval = int(np.ceil(interval))
283
+ if controlnet_blocks_repeat:
284
+ hidden = hidden + controlnet_block_samples[index % len(controlnet_block_samples)]
285
+ else:
286
+ hidden = hidden + controlnet_block_samples[index // interval]
287
+ hidden = torch.cat([encoder_hidden, hidden], dim=1)
288
+ for index, block in enumerate(self.single_transformer_blocks):
289
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
290
+ def custom_forward(module, return_dict=None):
291
+ def cf(*inputs):
292
+ return module(*inputs, return_dict=return_dict) if return_dict is not None else module(*inputs)
293
+ return cf
294
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
295
+ hidden = torch.utils.checkpoint.checkpoint(custom_forward(block), hidden, temb, image_rotary_emb, **ckpt_kwargs)
296
+ else:
297
+ hidden = block(hidden_states=hidden, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, tinfo=self.tinfo)
298
+ if controlnet_single_block_samples is not None:
299
+ interval = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
300
+ interval = int(np.ceil(interval))
301
+ hidden[:, encoder_hidden.shape[1]:, ...] = hidden[:, encoder_hidden.shape[1]:, ...] + controlnet_single_block_samples[index // interval]
302
+ hidden = hidden[:, encoder_hidden.shape[1]:, ...]
303
+ hidden = self.norm_out(hidden, temb)
304
+ output = self.proj_out(hidden)
305
+ if USE_PEFT_BACKEND:
306
+ unscale_lora_layers(self, lora_scale)
307
+ if not return_dict:
308
+ return (output,)
309
+ return Transformer2DModelOutput(sample=output)
310
+
311
+ class FluxPipeLine(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin):
312
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
313
+ def __init__(self, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: Transformer2DModel):
314
+ super().__init__()
315
+ self.register_modules(vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler)
316
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
317
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
318
+ self.tokenizer_max_length = self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
319
+ self.default_sample_size = 128
320
+ def _get_t5_prompt_embeds(self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 512, dev: Optional[torch.device] = None, dt: Optional[torch.dtype] = None):
321
+ dev = dev or self._execution_device
322
+ dt = dt or self.text_encoder.dtype
323
+ prompt = [prompt] if isinstance(prompt, str) else prompt
324
+ batch_size = len(prompt)
325
+ if isinstance(self, TextualInversionLoaderMixin):
326
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
327
+ text_inputs = self.tokenizer_2(prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt")
328
+ text_input_ids = text_inputs.input_ids
329
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
330
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
331
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
332
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(dev), output_hidden_states=False)[0]
333
+ dt = self.text_encoder_2.dtype
334
+ prompt_embeds = prompt_embeds.to(dtype=dt, device=dev)
335
+ _, seq_len, _ = prompt_embeds.shape
336
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
337
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
338
+ return prompt_embeds
339
+ def _get_clip_prompt_embeds(self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, dev: Optional[torch.device] = None):
340
+ dev = dev or self._execution_device
341
+ prompt = [prompt] if isinstance(prompt, str) else prompt
342
+ batch_size = len(prompt)
343
+ if isinstance(self, TextualInversionLoaderMixin):
344
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
345
+ text_inputs = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer_max_length, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt")
346
+ text_input_ids = text_inputs.input_ids
347
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
348
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
349
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
350
+ prompt_embeds = self.text_encoder(text_input_ids.to(dev), output_hidden_states=False)
351
+ prompt_embeds = prompt_embeds.pooler_output
352
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=dev)
353
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
354
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
355
+ return prompt_embeds
356
+ def encode_prompt(self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], dev: Optional[torch.device] = None, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None):
357
+ dev = dev or self._execution_device
358
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
359
+ self._lora_scale = lora_scale
360
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
361
+ scale_lora_layers(self.text_encoder, lora_scale)
362
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
363
+ scale_lora_layers(self.text_encoder_2, lora_scale)
364
+ prompt = [prompt] if isinstance(prompt, str) else prompt
365
+ if prompt_embeds is None:
366
+ prompt_2 = prompt_2 or prompt
367
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
368
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(prompt=prompt, dev=dev, num_images_per_prompt=num_images_per_prompt)
369
+ prompt_embeds = self._get_t5_prompt_embeds(prompt=prompt_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, dev=dev)
370
+ if self.text_encoder is not None:
371
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
372
+ unscale_lora_layers(self.text_encoder, lora_scale)
373
+ if self.text_encoder_2 is not None:
374
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
375
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
376
+ dt = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
377
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=dev, dtype=dt)
378
+ return prompt_embeds, pooled_prompt_embeds, text_ids
379
+ @staticmethod
380
+ def _prepare_latent_ids(batch_size, height, width, dev, dt):
381
+ latent_ids = torch.zeros(height, width, 3)
382
+ latent_ids[..., 1] = latent_ids[..., 1] + torch.arange(height)[:, None]
383
+ latent_ids[..., 2] = latent_ids[..., 2] + torch.arange(width)[None, :]
384
+ h, w, c = latent_ids.shape
385
+ latent_ids = latent_ids.reshape(h * w, c)
386
+ return latent_ids.to(device=dev, dtype=dt)
387
+ @staticmethod
388
+ def _pack_latents(latents, batch_size, num_channels, height, width):
389
+ latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
390
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
391
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
392
+ return latents
393
+ @staticmethod
394
+ def _unpack_latents(latents, height, width, scale_factor):
395
+ batch_size, num_patches, channels = latents.shape
396
+ height = 2 * (int(height) // (scale_factor * 2))
397
+ width = 2 * (int(width) // (scale_factor * 2))
398
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
399
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
400
+ latents = latents.reshape(batch_size, channels // 4, height, width)
401
+ return latents
402
+ def prepare_latents(self, batch_size, num_channels, height, width, dt, dev, generator, latents=None):
403
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
404
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
405
+ shape = (batch_size, num_channels, height, width)
406
+ if latents is not None:
407
+ latent_ids = self._prepare_latent_ids(batch_size, height // 2, width // 2, dev, dt)
408
+ return latents.to(device=dev, dtype=dt), latent_ids
409
+ latents = randn_tensor(shape, generator=generator, device=dev, dtype=dt)
410
+ latents = self._pack_latents(latents, batch_size, num_channels, height, width)
411
+ latent_ids = self._prepare_latent_ids(batch_size, height // 2, width // 2, dev, dt)
412
+ return latents, latent_ids
413
+ @property
414
+ def guidance_scale(self):
415
+ return self._guidance_scale
416
+ @property
417
+ def joint_attention_kwargs(self):
418
+ return self._joint_attention_kwargs
419
+ @property
420
+ def num_timesteps(self):
421
+ return self._num_timesteps
422
+ @property
423
+ def interrupt(self):
424
+ return self._interrupt
425
+ @torch.no_grad()
426
+ def __call__(self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, sigmas: Optional[List[float]] = None, guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512):
427
+ height = height or self.default_sample_size * self.vae_scale_factor
428
+ width = width or self.default_sample_size * self.vae_scale_factor
429
+ self._guidance_scale = guidance_scale
430
+ self._joint_attention_kwargs = joint_attention_kwargs
431
+ self._interrupt = False
432
+ if prompt is not None and isinstance(prompt, str):
433
+ batch_size = 1
434
+ elif prompt is not None and isinstance(prompt, list):
435
+ batch_size = len(prompt)
436
+ else:
437
+ batch_size = prompt_embeds.shape[0]
438
+ dev = self._execution_device
439
+ lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
440
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(prompt=prompt, prompt_2=prompt_2, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, dev=dev, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale)
441
+ num_channels = self.transformer.config.in_channels // 4
442
+ latents, latent_ids = self.prepare_latents(batch_size * num_images_per_prompt, num_channels, height, width, prompt_embeds.dtype, dev, generator, latents)
443
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
444
+ image_seq_len = latents.shape[1]
445
+ mu = calc_shift(image_seq_len, self.scheduler.config.base_image_seq_len, self.scheduler.config.max_image_seq_len, self.scheduler.config.base_shift, self.scheduler.config.max_shift)
446
+ timesteps, num_inference_steps = get_timesteps(self.scheduler, num_inference_steps, dev, sigmas=sigmas, mu=mu)
447
+ self._num_timesteps = len(timesteps)
448
+ if self.transformer.config.guidance_embeds:
449
+ guidance = torch.full([1], guidance_scale, device=dev, dtype=torch.float32)
450
+ guidance = guidance.expand(latents.shape[0])
451
+ else:
452
+ guidance = None
453
+ for i, t in enumerate(timesteps):
454
+ if self.interrupt:
455
+ continue
456
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
457
+ noise_pred = self.transformer(hidden_states=latents, timestep=timestep, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False)[0]
458
+ latents_dtype = latents.dtype
459
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
460
+ if latents.dtype != latents_dtype:
461
+ if torch.backends.mps.is_available():
462
+ latents = latents.to(latents_dtype)
463
+ if callback_on_step_end is not None:
464
+ callback_kwargs = {}
465
+ for k in callback_on_step_end_tensor_inputs:
466
+ callback_kwargs[k] = locals()[k]
467
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
468
+ latents = callback_outputs.pop("latents", latents)
469
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
470
+ if XLA_AVAILABLE:
471
+ xm.mark_step()
472
+ if output_type == "latent":
473
+ image = latents
474
+ else:
475
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
476
+ latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
477
+ image = self.vae.decode(latents, return_dict=False)[0]
478
+ image = self.image_processor.postprocess(image, output_type=output_type)
479
+ self.maybe_free_model_hooks()
480
+ if not return_dict:
481
+ return (image,)
482
+ return FluxPipelineOutput(images=image)
483
+
484
+ Pipeline = None
485
+ torch.backends.cuda.matmul.allow_tf32 = True
486
+ torch.backends.cudnn.enabled = True
487
+ torch.backends.cudnn.benchmark = True
488
+ ckpt_id = "director432/Flux1-Schnell"
489
+ ckpt_rev = "f8da0eb2b421c7677c70312bd9dec91a71f411d3"
490
+
491
+ def clear_cache():
492
+ gc.collect()
493
+ torch.cuda.empty_cache()
494
+ torch.cuda.reset_max_memory_allocated()
495
+ torch.cuda.reset_peak_memory_stats()
496
+
497
+ def load_pipeline() -> FluxPipeLine:
498
+ clear_cache()
499
+ text_enc2 = T5EncoderModel.from_pretrained("director432/Flux1-T5Encoder", revision="93fa999c3acb891488a05eebfb6a98e31d574d05", torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
500
+ path = os.path.join(HF_HUB_CACHE, "models--director432--Flux1-Transformer2D/snapshots/803893c49df2bb29d3f0f89ef5467781bef64b25")
501
+ gen = torch.Generator(device=device)
502
+ model = Transformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False, generator=gen).to(memory_format=torch.channels_last)
503
+ pipe = FluxPipeLine.from_pretrained(ckpt_id, revision=ckpt_rev, transformer=model, text_encoder_2=text_enc2, torch_dtype=dtype).to(device)
504
+ for _ in range(3):
505
+ pipe(prompt="director cooper follow interests activity recent organize", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
506
+ clear_cache()
507
+ return pipe
508
+
509
+ @torch.no_grad()
510
+ def infer(request: TextToImageRequest, pipe: FluxPipeLine, gen: Generator) -> Image:
511
+ return pipe(request.prompt, generator=gen, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
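Note: pipeline.py exposes load_pipeline() (which also runs three warm-up generations) and infer(). A hypothetical direct invocation outside the socket server, assuming the pinned snapshots are already in HF_HUB_CACHE and that TextToImageRequest accepts these keyword fields:

# Hypothetical direct usage mirroring what src/main.py does per request.
import torch
from pipeline import load_pipeline, infer
from pipelines.models import TextToImageRequest

pipe = load_pipeline()  # loads the pinned snapshots and warms up the pipeline
request = TextToImageRequest(prompt="a lighthouse at dusk", seed=0, width=1024, height=1024)
image = infer(request, pipe, torch.Generator("cuda").manual_seed(request.seed))
image.save("sample.jpg")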
uv.lock ADDED
The diff for this file is too large to render.