manbeast3b committed on
Commit
7eb7442
·
verified ·
1 Parent(s): c538eba

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +31 -164
src/pipeline.py CHANGED
@@ -1,142 +1,13 @@
1
- # from diffusers import FluxPipeline, AutoencoderKL
2
- # from diffusers.image_processor import VaeImageProcessor
3
- # from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
4
- # import torch
5
- # import gc
6
- # from PIL.Image import Image
7
- # from pipelines.models import TextToImageRequest
8
- # from torch import Generator
9
-
10
- # Pipeline = None
11
-
12
- # CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
13
-
14
- # def empty_cache():
15
- # gc.collect()
16
- # torch.cuda.empty_cache()
17
- # torch.cuda.reset_max_memory_allocated()
18
- # torch.cuda.reset_peak_memory_stats()
19
-
20
- # def load_pipeline() -> Pipeline:
21
- # infer(TextToImageRequest(prompt=""), Pipeline)
22
-
23
- # return Pipeline
24
-
25
-
26
- # def encode_prompt(prompt: str):
27
- # text_encoder = CLIPTextModel.from_pretrained(
28
- # CHECKPOINT,
29
- # subfolder="text_encoder",
30
- # torch_dtype=torch.bfloat16,
31
- # )
32
-
33
- # text_encoder_2 = T5EncoderModel.from_pretrained(
34
- # CHECKPOINT,
35
- # subfolder="text_encoder_2",
36
- # torch_dtype=torch.bfloat16,
37
- # )
38
-
39
- # tokenizer = CLIPTokenizer.from_pretrained(CHECKPOINT, subfolder="tokenizer")
40
- # tokenizer_2 = T5TokenizerFast.from_pretrained(CHECKPOINT, subfolder="tokenizer_2")
41
-
42
- # pipeline = FluxPipeline.from_pretrained(
43
- # CHECKPOINT,
44
- # text_encoder=text_encoder,
45
- # text_encoder_2=text_encoder_2,
46
- # tokenizer=tokenizer,
47
- # tokenizer_2=tokenizer_2,
48
- # transformer=None,
49
- # vae=None,
50
- # ).to("cuda")
51
-
52
- # with torch.no_grad():
53
- # return pipeline.encode_prompt(
54
- # prompt=prompt,
55
- # prompt_2=None,
56
- # max_sequence_length=256,
57
- # )
58
-
59
-
60
- # def infer_latents(prompt_embeds, pooled_prompt_embeds, width: int | None, height: int | None, seed: int | None):
61
- # pipeline = FluxPipeline.from_pretrained(
62
- # CHECKPOINT,
63
- # text_encoder=None,
64
- # text_encoder_2=None,
65
- # tokenizer=None,
66
- # tokenizer_2=None,
67
- # vae=None,
68
- # torch_dtype=torch.bfloat16,
69
- # ).to("cuda")
70
-
71
- # if seed is None:
72
- # generator = None
73
- # else:
74
- # generator = Generator(pipeline.device).manual_seed(seed)
75
-
76
- # return pipeline(
77
- # prompt_embeds=prompt_embeds,
78
- # pooled_prompt_embeds=pooled_prompt_embeds,
79
- # num_inference_steps=4,
80
- # guidance_scale=0.0,
81
- # width=width,
82
- # height=height,
83
- # generator=generator,
84
- # output_type="latent",
85
- # ).images
86
-
87
-
88
- # def infer(request: TextToImageRequest, _pipeline: Pipeline) -> Image:
89
- # empty_cache()
90
-
91
- # prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(request.prompt)
92
-
93
- # empty_cache()
94
-
95
- # latents = infer_latents(prompt_embeds, pooled_prompt_embeds, request.width, request.height, request.seed)
96
-
97
- # empty_cache()
98
-
99
- # vae = AutoencoderKL.from_pretrained(
100
- # CHECKPOINT,
101
- # subfolder="vae",
102
- # torch_dtype=torch.bfloat16,
103
- # ).to("cuda")
104
-
105
- # vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
106
- # image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
107
-
108
- # height = request.height or 64 * vae_scale_factor
109
- # width = request.width or 64 * vae_scale_factor
110
-
111
- # with torch.no_grad():
112
- # latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
113
- # latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
114
-
115
- # image = vae.decode(latents, return_dict=False)[0]
116
- # return image_processor.postprocess(image, output_type="pil")[0]
117
-
118
-
119
- from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
120
- from diffusers.image_processor import VaeImageProcessor
121
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
122
-
123
- from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
124
  import torch
125
  import torch.nn as nn
126
- import torch._dynamo
127
  import gc
128
- from PIL import Image as img
129
  from PIL.Image import Image
130
  from pipelines.models import TextToImageRequest
131
  from torch import Generator
132
- import time
133
- from diffusers import FluxTransformer2DModel, DiffusionPipeline
134
- # from torchao.quantization import quantize_,int8_weight_only
135
  import os
136
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
137
- Pipeline = None
138
-
139
-
140
 
141
  def fp8_linear_forward(cls, original_dtype, input):
142
  weight_dtype = cls.weight.dtype
@@ -171,76 +42,72 @@ def fp8_linear_forward(cls, original_dtype, input):
171
  else:
172
  return cls.original_forward(input)
173
 
174
- def convert_fp8_linear(module, original_dtype):
175
  setattr(module, "fp8_matmul_enabled", True)
176
  for name, module in module.named_modules():
177
  if isinstance(module, nn.Linear):
178
  if "blocks" in name:
179
- print("changing")
180
- #print(module, name)
181
  original_forward = module.forward
182
  setattr(module, "original_forward", original_forward)
183
  setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
184
 
185
 
186
- def replace_with_fp8_linear(transformer, original_dtype):
187
- """
188
- Replace all nn.Linear layers in the transformer with FP8-enabled linear layers.
189
- """
190
  for name, module in transformer.named_modules():
191
  if isinstance(module, nn.Linear):
192
- # Use your custom function to convert to FP8 Linear
193
- convert_fp8_linear(module, original_dtype)
194
  return transformer
195
 
196
- ckpt_id = "black-forest-labs/FLUX.1-schnell"
197
- def empty_cache():
198
- start = time.time()
 
199
  gc.collect()
200
- torch.cuda.empty_cache()
201
  torch.cuda.reset_max_memory_allocated()
202
  torch.cuda.reset_peak_memory_stats()
203
- print(f"Flush took: {time.time() - start}")
204
 
205
  def load_pipeline() -> Pipeline:
206
- empty_cache()
207
- dtype, device = torch.bfloat16, "cuda"
208
-
209
- text_encoder_2 = T5EncoderModel.from_pretrained(
210
- "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
211
- )
212
- vae=AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype)
213
  pipeline = DiffusionPipeline.from_pretrained(
214
- ckpt_id,
215
  vae=vae,
216
- text_encoder_2 = text_encoder_2,
217
- torch_dtype=dtype,
218
  )
 
219
  torch.backends.cudnn.benchmark = True
220
  torch.backends.cuda.matmul.allow_tf32 = True
221
  torch.cuda.set_per_process_memory_fraction(0.9)
222
  pipeline.text_encoder.to(memory_format=torch.channels_last)
 
223
  pipeline.transformer.to(memory_format=torch.channels_last)
224
-
225
- # Replace Linear layers in the Transformer with FP8 Linear layers
226
- # pipeline.text_encoder_2 = replace_with_fp8_linear(pipeline.text_encoder_2, original_dtype=dtype)
227
- pipeline.vae = replace_with_fp8_linear(pipeline.vae, original_dtype=dtype)
228
-
229
-
230
  pipeline.vae.to(memory_format=torch.channels_last)
231
  pipeline.vae = torch.compile(pipeline.vae)
232
-
233
  pipeline._exclude_from_cpu_offload = ["vae"]
234
  pipeline.enable_sequential_cpu_offload()
235
- for _ in range(2):
 
 
 
236
  pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
237
 
238
  return pipeline
239
 
240
 
 
241
  @torch.inference_mode()
242
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
243
- torch.cuda.reset_peak_memory_stats()
 
 
 
244
  generator = Generator("cuda").manual_seed(request.seed)
245
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
246
  return(image)
 
1
import gc
import os

import torch
import torch.nn as nn
from diffusers import AutoencoderKL, DiffusionPipeline
from torch import Generator
from transformers import T5EncoderModel
from PIL.Image import Image

from pipelines.models import TextToImageRequest

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
 
 
 
11
 
12
  def fp8_linear_forward(cls, original_dtype, input):
13
  weight_dtype = cls.weight.dtype
 
42
  else:
43
  return cls.original_forward(input)
44
 
45
def convert(module, original_dtype):
    """Mark *module* for FP8 matmul and wrap eligible Linear forwards.

    Sets an ``fp8_matmul_enabled`` attribute on *module*, then walks its
    submodule tree; every ``nn.Linear`` whose qualified name contains
    ``"blocks"`` gets its original ``forward`` stashed as
    ``original_forward`` and replaced by a wrapper that routes the call
    through ``fp8_linear_forward`` with *original_dtype*.

    Fix: the original loop rebound the name ``module`` (shadowing the
    parameter) on every iteration; renamed to ``submodule`` for clarity —
    behavior is unchanged.
    """
    setattr(module, "fp8_matmul_enabled", True)
    for qual_name, submodule in module.named_modules():
        if not isinstance(submodule, nn.Linear):
            continue
        if "blocks" not in qual_name:
            continue
        # Keep a handle to the stock forward so the wrapper can fall back to it.
        submodule.original_forward = submodule.forward
        # `m=submodule` binds the layer at definition time (avoids the
        # late-binding-closure pitfall).
        submodule.forward = lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input)
53
 
54
 
55
def replace(transformer, original_dtype):
    """Run ``convert`` over every ``nn.Linear`` found in *transformer*.

    Returns the same *transformer* object so the call can be used inline.

    NOTE(review): ``convert`` walks ``named_modules()`` of whatever it is
    handed and only wraps Linears whose qualified name contains
    ``"blocks"``. A bare ``Linear``'s only name is ``""``, so calling
    ``convert`` per-layer here sets the ``fp8_matmul_enabled`` flag on each
    Linear but never actually swaps a forward — confirm this is intended.
    """
    for layer_name, layer in transformer.named_modules():
        if not isinstance(layer, nn.Linear):
            continue
        convert(layer, original_dtype)
    return transformer
60
 
61
# Placeholder used in annotations; the real pipeline object is built by
# load_pipeline() at runtime (there is no concrete Pipeline class here).
Pipeline = None
# Hugging Face model id of the base checkpoint.
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
# Working precision for all model weights.
DTYPE = torch.bfloat16
64
def clear():
    """Release Python garbage and reset/flush CUDA memory bookkeeping."""
    # Collect Python-side garbage first so CUDA tensors held only by dead
    # objects can actually be freed by the steps below.
    gc.collect()
    # Reset the allocator's peak/max statistics so later measurements start clean.
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    # Return cached, unused blocks to the driver.
    torch.cuda.empty_cache()
69
 
70
def load_pipeline() -> Pipeline:
    """Build, optimize, and warm up the FLUX.1-schnell text-to-image pipeline.

    Loads the T5 encoder (bf16 mirror checkpoint) and VAE explicitly, wires
    them into ``DiffusionPipeline``, applies memory/throughput tweaks
    (channels-last, compiled VAE, sequential CPU offload), runs one warm-up
    generation, and returns the ready pipeline.
    """
    # Drop any leftover allocations from a previous run.
    clear()

    # Load the heavyweight components explicitly so their dtype is pinned.
    text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=DTYPE)
    vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=DTYPE)
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        text_encoder_2=text_encoder_2,
        torch_dtype=DTYPE,
    )

    # Global speed knobs: autotuned cuDNN kernels, TF32 matmuls, and a cap
    # on this process's share of GPU memory.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.cuda.set_per_process_memory_fraction(0.9)
    pipeline.text_encoder.to(memory_format=torch.channels_last)
    pipeline.text_encoder_2.to(memory_format=torch.channels_last)
    pipeline.transformer.to(memory_format=torch.channels_last)
    # BUG FIX: the original passed `original_dtype=dtype`, but `dtype` is
    # undefined in this scope (NameError at load time); the module constant
    # DTYPE is clearly what was intended.
    pipeline.vae = replace(pipeline.vae, original_dtype=DTYPE)
    pipeline.vae.to(memory_format=torch.channels_last)
    pipeline.vae = torch.compile(pipeline.vae)
    # Keep the (compiled) VAE resident on GPU; offload everything else.
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()

    # Warm up once so compilation/autotuning costs are paid before the first
    # real request. (The original wrapped this single call in
    # `for _ in range(1):` — a no-op loop, removed.)
    clear()
    pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)

    return pipeline
102
 
103
 
104
# First-call flag: truthy until the first request has been served.
sample = True

@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one PIL image for *request* with the warmed-up pipeline.

    On the very first call, performs a one-time memory cleanup (tracked via
    the module-level ``sample`` flag, which is then set to ``None``).
    Generation is seeded from ``request.seed`` for reproducibility.
    """
    global sample
    if sample:
        clear()
        sample = None
    rng = Generator("cuda").manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]