manbeast3b commited on
Commit
3a3114f
·
0 Parent(s):

Initial commit

Browse files
Files changed (9) hide show
  1. .gitattributes +37 -0
  2. ko.pth +3 -0
  3. ok.pth +3 -0
  4. pyproject.toml +55 -0
  5. src/ghanta.py +74 -0
  6. src/main.py +55 -0
  7. src/model.py +52 -0
  8. src/pipeline.py +184 -0
  9. uv.lock +0 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ RobertML.png filter=lfs diff=lfs merge=lfs -text
37
+ backup.png filter=lfs diff=lfs merge=lfs -text
ko.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2077712511cbeb96f4d0a6a0898b78345302ddaaf196d384f69c3d9c1adad6f9
3
+ size 4951464
ok.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b15aacc35d11e08803e9fdf07a2eed0c7f861250352f23e91cbeda4be07ad914
3
+ size 1800013
pyproject.toml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools >= 75.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "flux-schnell-edge-inference"
7
+ description = "An edge-maxxing model submission by RobertML for the 4090 Flux contest"
8
+ requires-python = ">=3.10,<3.13"
9
+ version = "8"
10
+ dependencies = [
11
+ "diffusers==0.31.0",
12
+ "transformers==4.46.2",
13
+ "accelerate==1.1.0",
14
+ "omegaconf==2.3.0",
15
+ "torch==2.5.1",
16
+ "protobuf==5.28.3",
17
+ "sentencepiece==0.2.0",
18
+ "edge-maxxing-pipelines @ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines",
19
+ "gitpython>=3.1.43",
20
+ "hf_transfer==0.1.8",
21
+ "torchao==0.6.1",
22
+ "setuptools>=75.3.0",
23
+ "torchvision"
24
+ ]
25
+
26
+ [[tool.edge-maxxing.models]]
27
+ repository = "black-forest-labs/FLUX.1-schnell"
28
+ revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
29
+ exclude = ["transformer"]
30
+
31
+ [[tool.edge-maxxing.models]]
32
+ repository = "RobertML/FLUX.1-schnell-int8wo"
33
+ revision = "307e0777d92df966a3c0f99f31a6ee8957a9857a"
34
+
35
+ [[tool.edge-maxxing.models]]
36
+ repository = "city96/t5-v1_1-xxl-encoder-bf16"
37
+ revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"
38
+
39
+ [[tool.edge-maxxing.models]]
40
+ repository = "RobertML/FLUX.1-schnell-vae_e3m2"
41
+ revision = "da0d2cd7815792fb40d084dbd8ed32b63f153d8d"
42
+
43
+ [[tool.edge-maxxing.models]]
44
+ repository = "madebyollin/taef1"
45
+ revision = "2d552378e58c9c94201075708d7de4e1163b2689"
46
+
47
+ [[tool.edge-maxxing.models]]
48
+ repository = "manbeast3b/flux.1-schnell-full1"
49
+ revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"
50
+
51
+
52
+
53
+ [project.scripts]
54
+ start_inference = "main:main"
55
+
src/ghanta.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple, Callable
3
+ def hacer_nada(x: torch.Tensor, modo: str = None):
4
+ return x
5
+ def brujeria_mps(entrada, dim, indice):
6
+ if entrada.shape[-1] == 1:
7
+ return torch.gather(entrada.unsqueeze(-1), dim - 1 if dim < 0 else dim, indice.unsqueeze(-1)).squeeze(-1)
8
+ else:
9
+ return torch.gather(entrada, dim, indice)
10
+ def emparejamiento_suave_aleatorio_2d(
11
+ metrica: torch.Tensor,
12
+ ancho: int,
13
+ alto: int,
14
+ paso_x: int,
15
+ paso_y: int,
16
+ radio: int,
17
+ sin_aleatoriedad: bool = False,
18
+ generador: torch.Generator = None
19
+ ) -> Tuple[Callable, Callable]:
20
+ lote, num_nodos, _ = metrica.shape
21
+ if radio <= 0:
22
+ return hacer_nada, hacer_nada
23
+ recopilar = brujeria_mps if metrica.device.type == "mps" else torch.gather
24
+ with torch.no_grad():
25
+ alto_paso_y, ancho_paso_x = alto // paso_y, ancho // paso_x
26
+ if sin_aleatoriedad:
27
+ indice_aleatorio = torch.zeros(alto_paso_y, ancho_paso_x, 1, device=metrica.device, dtype=torch.int64)
28
+ else:
29
+ indice_aleatorio = torch.randint(paso_y * paso_x, size=(alto_paso_y, ancho_paso_x, 1), device=generador.device, generator=generador).to(metrica.device)
30
+ vista_buffer_indice = torch.zeros(alto_paso_y, ancho_paso_x, paso_y * paso_x, device=metrica.device, dtype=torch.int64)
31
+ vista_buffer_indice.scatter_(dim=2, index=indice_aleatorio, src=-torch.ones_like(indice_aleatorio, dtype=indice_aleatorio.dtype))
32
+ vista_buffer_indice = vista_buffer_indice.view(alto_paso_y, ancho_paso_x, paso_y, paso_x).transpose(1, 2).reshape(alto_paso_y * paso_y, ancho_paso_x * paso_x)
33
+ if (alto_paso_y * paso_y) < alto or (ancho_paso_x * paso_x) < ancho:
34
+ buffer_indice = torch.zeros(alto, ancho, device=metrica.device, dtype=torch.int64)
35
+ buffer_indice[:(alto_paso_y * paso_y), :(ancho_paso_x * paso_x)] = vista_buffer_indice
36
+ else:
37
+ buffer_indice = vista_buffer_indice
38
+ indice_aleatorio = buffer_indice.reshape(1, -1, 1).argsort(dim=1)
39
+ del buffer_indice, vista_buffer_indice
40
+ num_destino = alto_paso_y * ancho_paso_x
41
+ indices_a = indice_aleatorio[:, num_destino:, :]
42
+ indices_b = indice_aleatorio[:, :num_destino, :]
43
+ def dividir(x):
44
+ canales = x.shape[-1]
45
+ origen = recopilar(x, dim=1, index=indices_a.expand(lote, num_nodos - num_destino, canales))
46
+ destino = recopilar(x, dim=1, index=indices_b.expand(lote, num_destino, canales))
47
+ return origen, destino
48
+ metrica = metrica / metrica.norm(dim=-1, keepdim=True)
49
+ a, b = dividir(metrica)
50
+ puntuaciones = a @ b.transpose(-1, -2)
51
+ radio = min(a.shape[1], radio)
52
+ nodo_max, nodo_indice = puntuaciones.max(dim=-1)
53
+ indice_borde = nodo_max.argsort(dim=-1, descending=True)[..., None]
54
+ indice_no_emparejado = indice_borde[..., radio:, :]
55
+ indice_origen = indice_borde[..., :radio, :]
56
+ indice_destino = recopilar(nodo_indice[..., None], dim=-2, index=indice_origen)
57
+ def fusionar(x: torch.Tensor, modo="mean") -> torch.Tensor:
58
+ origen, destino = dividir(x)
59
+ n, t1, c = origen.shape
60
+ no_emparejado = recopilar(origen, dim=-2, index=indice_no_emparejado.expand(n, t1 - radio, c))
61
+ origen = recopilar(origen, dim=-2, index=indice_origen.expand(n, radio, c))
62
+ destino = destino.scatter_reduce(-2, indice_destino.expand(n, radio, c), origen, reduce=modo)
63
+ return torch.cat([no_emparejado, destino], dim=1)
64
+ def desfusionar(x: torch.Tensor) -> torch.Tensor:
65
+ longitud_no_emparejado = indice_no_emparejado.shape[1]
66
+ no_emparejado, destino = x[..., :longitud_no_emparejado, :], x[..., longitud_no_emparejado:, :]
67
+ _, _, c = no_emparejado.shape
68
+ origen = recopilar(destino, dim=-2, index=indice_destino.expand(lote, radio, c))
69
+ salida = torch.zeros(lote, num_nodos, c, device=x.device, dtype=x.dtype)
70
+ salida.scatter_(dim=-2, index=indices_b.expand(lote, num_destino, c), src=destino)
71
+ salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_no_emparejado).expand(lote, longitud_no_emparejado, c), src=no_emparejado)
72
+ salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_origen).expand(lote, radio, c), src=origen)
73
+ return salida
74
+ return fusionar, desfusionar
src/main.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ from io import BytesIO
3
+ from multiprocessing.connection import Listener
4
+ from os import chmod, remove
5
+ from os.path import abspath, exists
6
+ from pathlib import Path
7
+ from git import Repo
8
+ import torch
9
+
10
+ from PIL.JpegImagePlugin import JpegImageFile
11
+ from pipelines.models import TextToImageRequest
12
+ from pipeline import load_pipeline, infer
13
+ SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
14
+
15
+
16
+ def at_exit():
17
+ torch.cuda.empty_cache()
18
+
19
+
20
+ def main():
21
+ atexit.register(at_exit)
22
+
23
+ print(f"Loading pipeline")
24
+ pipeline = load_pipeline()
25
+
26
+ print(f"Pipeline loaded, creating socket at '{SOCKET}'")
27
+
28
+ if exists(SOCKET):
29
+ remove(SOCKET)
30
+
31
+ with Listener(SOCKET) as listener:
32
+ chmod(SOCKET, 0o777)
33
+
34
+ print(f"Awaiting connections")
35
+ with listener.accept() as connection:
36
+ print(f"Connected")
37
+ generator = torch.Generator("cuda")
38
+ while True:
39
+ try:
40
+ request = TextToImageRequest.model_validate_json(connection.recv_bytes().decode("utf-8"))
41
+ except EOFError:
42
+ print(f"Inference socket exiting")
43
+
44
+ return
45
+ image = infer(request, pipeline, generator.manual_seed(request.seed))
46
+ data = BytesIO()
47
+ image.save(data, format=JpegImageFile.format)
48
+
49
+ packet = data.getvalue()
50
+
51
+ connection.send_bytes(packet )
52
+
53
+
54
+ if __name__ == '__main__':
55
+ main()
src/model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as t, torch.nn as nn, torch.nn.functional as F
2
+ def C(n_in, n_out, **kwargs):
3
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
4
+ class Clamp(nn.Module):
5
+ def forward(self, x):
6
+ return t.tanh(x / 3) * 3
7
+ class B(nn.Module):
8
+ def __init__(self, n_in, n_out):
9
+ super().__init__()
10
+ self.conv = nn.Sequential(C(n_in, n_out), nn.ReLU(), C(n_out, n_out), nn.ReLU(), C(n_out, n_out))
11
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
12
+ self.fuse = nn.ReLU()
13
+ def forward(self, x):
14
+ return self.fuse(self.conv(x) + self.skip(x))
15
+ def E(latent_channels=4):
16
+ return nn.Sequential(
17
+ C(3, 64), B(64, 64),
18
+ C(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
19
+ C(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
20
+ C(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
21
+ C(64, latent_channels),
22
+ )
23
+ def D(latent_channels=16):
24
+ return nn.Sequential(
25
+ Clamp(),
26
+ C(latent_channels, 48),nn.ReLU(),B(48, 48), B(48, 48),
27
+ nn.Upsample(scale_factor=2), C(48, 48, bias=False),B(48, 48), B(48, 48),
28
+ nn.Upsample(scale_factor=2), C(48, 48, bias=False),B(48, 48),
29
+ nn.Upsample(scale_factor=2), C(48, 48, bias=False),B(48, 48),
30
+ C(48, 3),
31
+ )
32
+ class M(nn.Module):
33
+ lm, ls = 3, 0.5
34
+ def __init__(s, ep="encoder.pth", dp="decoder.pth", lc=None):
35
+ super().__init__()
36
+ if lc is None: lc = s.glc(str(ep))
37
+ s.e, s.d = E(lc), D(lc)
38
+ def f(sd, mod, pfx):
39
+ f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
40
+ mod.load_state_dict(f_sd, strict=False)
41
+ if ep: f(t.load(ep, map_location="cpu", weights_only=True), s.e, "encoder.")
42
+ if dp: f(t.load(dp, map_location="cpu", weights_only=True), s.d, "decoder.")
43
+ s.e.requires_grad_(False)
44
+ s.d.requires_grad_(False)
45
+ def glc(s, ep): return 16 if "taef1" in ep or "taesd3" in ep else 4
46
+ @staticmethod
47
+ def sl(x): return x.div(2 * M.lm).add(M.ls).clamp(0, 1)
48
+ @staticmethod
49
+ def ul(x): return x.sub(M.ls).mul(2 * M.lm)
50
+ def forward(s, x, rl=False):
51
+ l, o = s.e(x), s.d(s.e(x))
52
+ return (o.clamp(0, 1), l) if rl else o.clamp(0, 1)
src/pipeline.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
2
+ from diffusers.image_processor import VaeImageProcessor
3
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
4
+ from huggingface_hub.constants import HF_HUB_CACHE
5
+ from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
6
+ import torch
7
+ import torch._dynamo
8
+ import gc
9
+ from PIL import Image as img
10
+ from PIL.Image import Image
11
+ from pipelines.models import TextToImageRequest
12
+ from torch import Generator
13
+ import time
14
+ from diffusers import DiffusionPipeline
15
+ from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
16
+
17
+ import torch
18
+ import math
19
+ from typing import Type, Dict, Any, Tuple, Callable, Optional, Union
20
+ import ghanta
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
28
+ from diffusers.models.attention import FeedForward
29
+ from diffusers.models.attention_processor import (
30
+ Attention,
31
+ AttentionProcessor,
32
+ FluxAttnProcessor2_0,
33
+ FusedFluxAttnProcessor2_0,
34
+ )
35
+ from diffusers.models.modeling_utils import ModelMixin
36
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
37
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
38
+ from diffusers.utils.import_utils import is_torch_npu_available
39
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
40
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
41
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
42
+
43
+ from model import E, D
44
+ import torchvision
45
+
46
+ import os
47
+ os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
48
+ os.environ["TOKENIZERS_PARALLELISM"] = "True"
49
+ torch._dynamo.config.suppress_errors = True
50
+
51
+ class BasicQuantization:
52
+ def __init__(self, bits=1):
53
+ self.bits = bits
54
+ self.qmin = -(2**(bits-1))
55
+ self.qmax = 2**(bits-1) - 1
56
+
57
+ def quantize_tensor(self, tensor):
58
+ scale = (tensor.max() - tensor.min()) / (self.qmax - self.qmin)
59
+ zero_point = self.qmin - torch.round(tensor.min() / scale)
60
+ qtensor = torch.round(tensor / scale + zero_point)
61
+ qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
62
+ return (qtensor - zero_point) * scale, scale, zero_point
63
+
64
+ class ModelQuantization:
65
+ def __init__(self, model, bits=7):
66
+ self.model = model
67
+ self.quant = BasicQuantization(bits)
68
+
69
+ def quantize_model(self):
70
+ for name, module in self.model.named_modules():
71
+ if isinstance(module, torch.nn.Linear):
72
+ if hasattr(module, 'weightML'):
73
+ quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
74
+ module.weight = torch.nn.Parameter(quantized_weight)
75
+ if hasattr(module, 'bias') and module.bias is not None:
76
+ quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
77
+ module.bias = torch.nn.Parameter(quantized_bias)
78
+
79
+
80
+ def inicializar_generador(dispositivo: torch.device, respaldo: torch.Generator = None):
81
+ if dispositivo.type == "cpu":
82
+ return torch.Generator(device="cpu").set_state(torch.get_rng_state())
83
+ elif dispositivo.type == "cuda":
84
+ return torch.Generator(device=dispositivo).set_state(torch.cuda.get_rng_state())
85
+ else:
86
+ if respaldo is None:
87
+ return inicializar_generador(torch.device("cpu"))
88
+ else:
89
+ return respaldo
90
+
91
+ def calcular_fusion(x: torch.Tensor, info_tome: Dict[str, Any]) -> Tuple[Callable, ...]:
92
+ alto_original, ancho_original = info_tome["size"]
93
+ tokens_originales = alto_original * ancho_original
94
+ submuestreo = int(math.ceil(math.sqrt(tokens_originales // x.shape[1])))
95
+ argumentos = info_tome["args"]
96
+ if submuestreo <= argumentos["down"]:
97
+ ancho = int(math.ceil(ancho_original / submuestreo))
98
+ alto = int(math.ceil(alto_original / submuestreo))
99
+ radio = int(x.shape[1] * argumentos["ratio"])
100
+
101
+ if argumentos["generator"] is None:
102
+ argumentos["generator"] = inicializar_generador(x.device)
103
+ elif argumentos["generator"].device != x.device:
104
+ argumentos["generator"] = inicializar_generador(x.device, respaldo=argumentos["generator"])
105
+
106
+ usar_aleatoriedad = argumentos["rando"]
107
+ fusion, desfusion = ghanta.emparejamiento_suave_aleatorio_2d(
108
+ x, ancho, alto, argumentos["sx"], argumentos["sy"], radio,
109
+ sin_aleatoriedad=not usar_aleatoriedad, generador=argumentos["generator"]
110
+ )
111
+ else:
112
+ fusion, desfusion = (hacer_nada, hacer_nada)
113
+ fusion_a, desfusion_a = (fusion, desfusion) if argumentos["m1"] else (hacer_nada, hacer_nada)
114
+ fusion_c, desfusion_c = (fusion, desfusion) if argumentos["m2"] else (hacer_nada, hacer_nada)
115
+ fusion_m, desfusion_m = (fusion, desfusion) if argumentos["m3"] else (hacer_nada, hacer_nada)
116
+ return fusion_a, fusion_c, fusion_m, desfusion_a, desfusion_c, desfusion_m
117
+
118
+ from diffusers import FluxPipeline, FluxTransformer2DModel
119
+ Pipeline = None
120
+ torch.backends.cuda.matmul.allow_tf32 = True
121
+ torch.backends.cudnn.enabled = True
122
+ torch.backends.cudnn.benchmark = True
123
+
124
+ ckpt_id = "black-forest-labs/FLUX.1-schnell"
125
+ ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
126
+
127
+ TinyVAE = "madebyollin/taef1"
128
+ TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
129
+
130
+ os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
131
+ os.environ["TOKENIZERS_PARALLELISM"] = "True"
132
+ torch._dynamo.config.suppress_errors = True
133
+
134
+ def empty_cache():
135
+ gc.collect()
136
+ torch.cuda.empty_cache()
137
+ torch.cuda.reset_max_memory_allocated()
138
+ torch.cuda.reset_peak_memory_stats()
139
+
140
+
141
+ def load_pipeline() -> Pipeline:
142
+ path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer")
143
+ transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
144
+ vae = AutoencoderTiny.from_pretrained(
145
+ TinyVAE,
146
+ revision=TinyVAE_REV,
147
+ local_files_only=True,
148
+ torch_dtype=torch.bfloat16)
149
+ vae.encoder = E(16)
150
+ vae.decoder = D(16)
151
+ ko_state_dict = torch.load("ko.pth", map_location="cpu", weights_only=True)
152
+ filtered_state_dict = {k.strip('encoder.'): v for k, v in ko_state_dict.items() if k.strip('encoder.') in vae.encoder.state_dict() and v.size() == vae.encoder.state_dict()[k.strip('encoder.')].size()}
153
+ vae.encoder.load_state_dict(filtered_state_dict, strict=False)
154
+ vae.encoder.requires_grad_(False).to(dtype=torch.bfloat16)
155
+ ok_state_dict = torch.load("ok.pth", map_location="cpu", weights_only=True)
156
+ filtered_state_dict = {k.strip('decoder.'): v for k, v in ok_state_dict.items() if k.strip('decoder.') in vae.decoder.state_dict() and v.size() == vae.decoder.state_dict()[k.strip('decoder.')].size()}
157
+ vae.decoder.load_state_dict(filtered_state_dict, strict=False)
158
+ vae.decoder.requires_grad_(False).to(dtype=torch.bfloat16)
159
+
160
+ pipeline = FluxPipeline.from_pretrained(ckpt_id, revision=ckpt_revision, transformer=transformer, vae=vae, local_files_only=True, torch_dtype=torch.bfloat16,)
161
+ pipeline.to("cuda")
162
+
163
+ # Optimize memory format
164
+ for component in [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.transformer, pipeline.vae]:
165
+ component.to(memory_format=torch.channels_last)
166
+
167
+ # quantize_(pipeline.vae, int8_weight_only())
168
+ pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
169
+ pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)
170
+
171
+ for _ in range(2):
172
+ pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
173
+ return pipeline
174
+
175
+
176
+ sample = None
177
+ @torch.no_grad()
178
+ def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
179
+ global sample
180
+ if not sample:
181
+ sample=1
182
+ empty_cache()
183
+ image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pt").images[0]
184
+ return torchvision.transforms.functional.to_pil_image(image.to(torch.float32).mul_(2).sub_(1))# torchvision.transforms.functional.to_pil_image(image)
uv.lock ADDED
The diff for this file is too large to render. See raw diff