manbeast3b committed on
Commit
6d5fa4c
·
verified ·
1 Parent(s): 3a3114f

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +4 -76
src/pipeline.py CHANGED
@@ -39,83 +39,16 @@ from diffusers.utils.import_utils import is_torch_npu_available
39
  from diffusers.utils.torch_utils import maybe_allow_in_graph
40
  from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
41
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
42
-
43
  from model import E, D
44
  import torchvision
45
-
46
  import os
 
47
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
48
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
49
  torch._dynamo.config.suppress_errors = True
50
 
51
class BasicQuantization:
    """Uniform affine fake-quantization onto a signed integer grid.

    Maps a float tensor onto ``2**bits`` evenly spaced levels and back,
    returning the dequantized tensor together with the ``(scale,
    zero_point)`` pair so callers can reproduce the mapping.
    """

    def __init__(self, bits=1):
        # Signed integer range for the requested bit width,
        # e.g. bits=8 -> [-128, 127].
        self.bits = bits
        self.qmin = -(2**(bits-1))
        self.qmax = 2**(bits-1) - 1

    def quantize_tensor(self, tensor):
        """Fake-quantize *tensor*.

        Returns ``(dequantized, scale, zero_point)`` where *dequantized*
        has the input's shape but takes at most ``2**bits`` distinct values.
        """
        tmin, tmax = tensor.min(), tensor.max()
        scale = (tmax - tmin) / (self.qmax - self.qmin)
        # Bug fix: a constant tensor yields scale == 0 and the divisions
        # below then produce NaN/Inf.  Fall back to a tiny positive scale
        # so a constant tensor round-trips to (approximately) itself.
        if scale == 0:
            scale = tensor.new_tensor(1e-8)
        zero_point = self.qmin - torch.round(tmin / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        return (qtensor - zero_point) * scale, scale, zero_point
63
-
64
class ModelQuantization:
    """Applies :class:`BasicQuantization` in-place to every ``nn.Linear``
    layer of *model*."""

    def __init__(self, model, bits=7):
        self.model = model
        self.quant = BasicQuantization(bits)

    def quantize_model(self):
        """Fake-quantize weights and biases of all ``nn.Linear`` modules.

        Parameters are replaced in-place with their dequantized values.
        """
        for _, module in self.model.named_modules():
            if not isinstance(module, torch.nn.Linear):
                continue
            # Bug fix: the original tested hasattr(module, 'weightML'),
            # an attribute nn.Linear never defines, so the weight branch
            # was dead code and weights were never quantized.  nn.Linear
            # always exposes 'weight'.
            if module.weight is not None:
                quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                module.weight = torch.nn.Parameter(quantized_weight)
            if module.bias is not None:
                quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
                module.bias = torch.nn.Parameter(quantized_bias)
78
-
79
-
80
def inicializar_generador(dispositivo: torch.device, respaldo: torch.Generator = None):
    """Build a ``torch.Generator`` seeded from the global RNG state of
    *dispositivo*.

    For device types other than ``cpu``/``cuda``, returns *respaldo* when
    supplied, otherwise falls back to a CPU-seeded generator.
    """
    tipo = dispositivo.type
    if tipo == "cpu":
        gen = torch.Generator(device="cpu")
        return gen.set_state(torch.get_rng_state())
    if tipo == "cuda":
        gen = torch.Generator(device=dispositivo)
        return gen.set_state(torch.cuda.get_rng_state())
    # Unknown device type: honour the caller-supplied fallback when present.
    if respaldo is not None:
        return respaldo
    return inicializar_generador(torch.device("cpu"))
90
 
91
def calcular_fusion(x: torch.Tensor, info_tome: Dict[str, Any]) -> Tuple[Callable, ...]:
    """Build per-stage merge/unmerge callables for token merging.

    NOTE(review): names translated from Spanish — fusion=merge,
    desfusion=unmerge.  Appears to implement a ToMe-style token-merging
    schedule; confirm against the ``ghanta`` module.

    Returns six callables: merge functions for three stages (a, c, m)
    followed by their matching unmerge functions.  Stages whose "m1"/"m2"/
    "m3" flag is falsy get the ``hacer_nada`` (no-op) pair.
    """
    alto_original, ancho_original = info_tome["size"]
    tokens_originales = alto_original * ancho_original
    # Downsample factor inferred from how many tokens x actually carries
    # relative to the original spatial grid.
    submuestreo = int(math.ceil(math.sqrt(tokens_originales // x.shape[1])))
    argumentos = info_tome["args"]
    if submuestreo <= argumentos["down"]:
        ancho = int(math.ceil(ancho_original / submuestreo))
        alto = int(math.ceil(alto_original / submuestreo))
        # Number of tokens to merge away at this resolution.
        radio = int(x.shape[1] * argumentos["ratio"])

        # Lazily create (or re-home) the RNG so it lives on x's device.
        # Side effect: mutates info_tome["args"]["generator"] in place.
        if argumentos["generator"] is None:
            argumentos["generator"] = inicializar_generador(x.device)
        elif argumentos["generator"].device != x.device:
            argumentos["generator"] = inicializar_generador(x.device, respaldo=argumentos["generator"])

        usar_aleatoriedad = argumentos["rando"]
        fusion, desfusion = ghanta.emparejamiento_suave_aleatorio_2d(
            x, ancho, alto, argumentos["sx"], argumentos["sy"], radio,
            sin_aleatoriedad=not usar_aleatoriedad, generador=argumentos["generator"]
        )
    else:
        # Resolution too coarse for merging: fall back to no-ops.
        fusion, desfusion = (hacer_nada, hacer_nada)
    fusion_a, desfusion_a = (fusion, desfusion) if argumentos["m1"] else (hacer_nada, hacer_nada)
    fusion_c, desfusion_c = (fusion, desfusion) if argumentos["m2"] else (hacer_nada, hacer_nada)
    fusion_m, desfusion_m = (fusion, desfusion) if argumentos["m3"] else (hacer_nada, hacer_nada)
    return fusion_a, fusion_c, fusion_m, desfusion_a, desfusion_c, desfusion_m
117
-
118
- from diffusers import FluxPipeline, FluxTransformer2DModel
119
  Pipeline = None
120
  torch.backends.cuda.matmul.allow_tf32 = True
121
  torch.backends.cudnn.enabled = True
@@ -127,10 +60,6 @@ ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
127
  TinyVAE = "madebyollin/taef1"
128
  TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
129
 
130
- os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
131
- os.environ["TOKENIZERS_PARALLELISM"] = "True"
132
- torch._dynamo.config.suppress_errors = True
133
-
134
def empty_cache():
    """Run Python garbage collection, then release cached CUDA allocator
    blocks back to the driver."""
    for liberar in (gc.collect, torch.cuda.empty_cache):
        liberar()
@@ -172,8 +101,7 @@ def load_pipeline() -> Pipeline:
172
  pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
173
  return pipeline
174
 
175
-
176
- sample = None
177
  @torch.no_grad()
178
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
179
  global sample
@@ -181,4 +109,4 @@ def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator)
181
  sample=1
182
  empty_cache()
183
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pt").images[0]
184
- return torchvision.transforms.functional.to_pil_image(image.to(torch.float32).mul_(2).sub_(1))# torchvision.transforms.functional.to_pil_image(image)
 
39
  from diffusers.utils.torch_utils import maybe_allow_in_graph
40
  from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
41
  from diffusers.models.modeling_outputs import Transformer2DModelOutput
42
+ from diffusers import FluxPipeline, FluxTransformer2DModel
43
  from model import E, D
44
  import torchvision
 
45
  import os
46
+
47
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
48
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
49
  torch._dynamo.config.suppress_errors = True
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  Pipeline = None
53
  torch.backends.cuda.matmul.allow_tf32 = True
54
  torch.backends.cudnn.enabled = True
 
60
  TinyVAE = "madebyollin/taef1"
61
  TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
62
 
 
 
 
 
63
  def empty_cache():
64
  gc.collect()
65
  torch.cuda.empty_cache()
 
101
  pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
102
  return pipeline
103
 
104
+ sample = 1
 
105
  @torch.no_grad()
106
  def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
107
  global sample
 
109
  sample=1
110
  empty_cache()
111
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pt").images[0]
112
+ return torchvision.transforms.functional.to_pil_image(image.to(torch.float32).mul_(2).sub_(1))