manbeast3b committed on
Commit
6746668
·
verified ·
1 Parent(s): ae03ef6

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +107 -92
src/pipeline.py CHANGED
@@ -1,35 +1,49 @@
1
- from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from diffusers.image_processor import VaeImageProcessor
3
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
4
- from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
5
- import torch
6
- import torch._dynamo
7
- import gc
8
  from PIL import Image as img
9
- from PIL.Image import Image
10
  from pipelines.models import TextToImageRequest
11
- from torch import Generator
12
- import time
13
- from diffusers import FluxTransformer2DModel, DiffusionPipeline
14
- from torchao.quantization import quantize_, int8_weight_only
15
- import os
16
  from model import Encoder, Decoder
17
- import torchvision
18
- import torch.nn as nn
19
-
20
- os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
21
-
22
- Pipeline = None
23
 
 
 
24
  torch.backends.cudnn.benchmark = True
25
  torch.backends.cuda.matmul.allow_tf32 = True
26
  torch.cuda.set_per_process_memory_fraction(0.95)
27
-
 
 
 
 
 
 
 
 
 
 
 
28
  class BasicQuantization:
29
  def __init__(self, bits=1):
30
  self.bits = bits
31
- self.qmin = -(2**(bits-1))
32
- self.qmax = 2**(bits-1) - 1
33
 
34
  def quantize_tensor(self, tensor):
35
  scale = (tensor.max() - tensor.min()) / (self.qmax - self.qmin)
@@ -45,107 +59,108 @@ class ModelQuantization:
45
 
46
  def quantize_model(self):
47
  for name, module in self.model.named_modules():
48
- if isinstance(module, torch.nn.Linear):
49
- if hasattr(module, 'weightML'):
50
  quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
51
- module.weight = torch.nn.Parameter(quantized_weight)
52
  if hasattr(module, 'bias') and module.bias is not None:
53
  quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
54
- module.bias = torch.nn.Parameter(quantized_bias)
55
-
56
-
57
- ckpt_id = "black-forest-labs/FLUX.1-schnell"
58
- def empty_cache():
59
- start = time.time()
60
- gc.collect()
61
- torch.cuda.empty_cache()
62
- torch.cuda.reset_max_memory_allocated()
63
- torch.cuda.reset_peak_memory_stats()
64
-
65
- def load_pipeline() -> Pipeline:
66
- empty_cache()
67
 
 
 
 
 
68
  dtype, device = torch.bfloat16, "cuda"
69
 
70
- vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", torch_dtype=dtype)
 
71
  vae.encoder = Encoder(16)
72
  vae.decoder = Decoder(16)
73
 
 
74
  encoder_path = "encoder.pth"
75
  decoder_path = "decoder.pth"
76
 
77
- if encoder_path is not None:
78
- encoder_state_dict = torch.load(encoder_path, map_location="cpu", weights_only=True)
79
- filtered_state_dict = {k.strip('encoder.'): v for k, v in encoder_state_dict.items() if k.strip('encoder.') in vae.encoder.state_dict() and v.size() == vae.encoder.state_dict()[k.strip('encoder.')].size()}
80
- print(f" num of keys in filtered: {len(filtered_state_dict)} and in decoder: {len(vae.encoder.state_dict())}")
 
 
 
81
  vae.encoder.load_state_dict(filtered_state_dict, strict=False)
82
  vae.encoder.to(dtype=dtype)
83
-
84
- if decoder_path is not None:
85
- decoder_state_dict = torch.load(decoder_path, map_location="cpu", weights_only=True)
86
- filtered_state_dict = {k.strip('decoder.'): v for k, v in decoder_state_dict.items() if k.strip('decoder.') in vae.decoder.state_dict() and v.size() == vae.decoder.state_dict()[k.strip('decoder.')].size()}
87
- print(f" num of keys in filtered: {len(filtered_state_dict)} and in decoder: {len(vae.decoder.state_dict())}")
 
 
 
88
  vae.decoder.load_state_dict(filtered_state_dict, strict=False)
89
  vae.decoder.to(dtype=dtype)
90
 
91
- vae.decoder.requires_grad_(False)
92
  vae.encoder.requires_grad_(False)
93
-
94
- # quantize_(vae, int8_weight_only())
 
95
  quantizer = ModelQuantization(vae)
96
  quantizer.quantize_model()
97
 
 
 
 
 
98
 
99
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(ckpt_id, subfolder="scheduler")
100
-
101
- ############ Text Encoder ############
102
- text_encoder = CLIPTextModel.from_pretrained(
103
- ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
104
- )
105
- # quantize_(text_encoder, int8_weight_only())
106
-
107
- ############ Text Encoder 2 ############
108
- text_encoder_2 = T5EncoderModel.from_pretrained(
109
- "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
110
  )
111
 
112
- model = FluxTransformer2DModel.from_pretrained(
113
- "/root/.cache/huggingface/hub/models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a", torch_dtype=dtype, use_safetensors=False
114
- )
115
-
116
- pipeline = DiffusionPipeline.from_pretrained(
117
- ckpt_id,
118
  scheduler=scheduler,
119
- transformer=model,
120
  text_encoder=text_encoder,
121
  text_encoder_2=text_encoder_2,
122
- torch_dtype=dtype,
123
  vae=vae,
124
- load_in_8bit=True,
125
- ).to(device)
126
- # pipeline.vae = torch.compile(pipeline.vae, mode="reduce-overhead")
127
- pipeline.vae.to(memory_format=torch.channels_last)
128
- pipeline.text_encoder.to(memory_format=torch.channels_last)
129
- pipeline.text_encoder_2.to(memory_format=torch.channels_last)
130
- pipeline.transformer.to(memory_format=torch.channels_last)
131
-
132
- for _ in range(2):
133
- pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
134
-
135
- empty_cache()
136
- return pipeline
 
 
 
 
137
 
 
 
138
 
139
  @torch.inference_mode()
140
- def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
 
141
  generator = Generator(pipeline.device).manual_seed(request.seed)
142
- image=pipeline(request.prompt,
143
- generator=generator,
144
- guidance_scale=0.0,
145
- num_inference_steps=4,
146
- max_sequence_length=256,
147
- height=request.height,
148
- width=request.width,
149
- output_type="pt").images[0]
150
-
151
- return torchvision.transforms.functional.to_pil_image(image.to(torch.float32).mul_(2).sub_(1))
 
 
1
+ import os
2
+ import gc
3
+ import time
4
+ import torch
5
+ import torchvision
6
+ import torch.nn as nn
7
+ from torch import Generator
8
+ from diffusers import (
9
+ FluxPipeline,
10
+ AutoencoderKL,
11
+ AutoencoderTiny,
12
+ DiffusionPipeline,
13
+ FluxTransformer2DModel
14
+ )
15
  from diffusers.image_processor import VaeImageProcessor
16
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
17
+ from transformers import (
18
+ T5EncoderModel,
19
+ CLIPTextModel
20
+ )
21
  from PIL import Image as img
 
22
  from pipelines.models import TextToImageRequest
 
 
 
 
 
23
  from model import Encoder, Decoder
 
 
 
 
 
 
24
 
25
+ # Environment configuration
26
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
27
  torch.backends.cudnn.benchmark = True
28
  torch.backends.cuda.matmul.allow_tf32 = True
29
  torch.cuda.set_per_process_memory_fraction(0.95)
30
+
31
+ # Constants
32
+ CKPT_ID = "black-forest-labs/FLUX.1-schnell"
33
+
34
+ # Utility functions
35
def clear():
    """Reclaim memory: run Python GC, flush the CUDA cache, reset CUDA stats."""
    # Order matters: collect unreachable Python objects first so their CUDA
    # blocks can actually be released by empty_cache().
    for release in (
        gc.collect,
        torch.cuda.empty_cache,
        torch.cuda.reset_max_memory_allocated,
        torch.cuda.reset_peak_memory_stats,
    ):
        release()
40
+
41
+ # Quantization classes
42
  class BasicQuantization:
43
  def __init__(self, bits=1):
44
  self.bits = bits
45
+ self.qmin = -(2 ** (bits - 1))
46
+ self.qmax = 2 ** (bits - 1) - 1
47
 
48
  def quantize_tensor(self, tensor):
49
  scale = (tensor.max() - tensor.min()) / (self.qmax - self.qmin)
 
59
 
60
  def quantize_model(self):
61
  for name, module in self.model.named_modules():
62
+ if isinstance(module, nn.Linear):
63
+ if hasattr(module, 'weight'):
64
  quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
65
+ module.weight = nn.Parameter(quantized_weight)
66
  if hasattr(module, 'bias') and module.bias is not None:
67
  quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
68
+ module.bias = nn.Parameter(quantized_bias)
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
# Pipeline loading
def load_pipeline():
    """Load and warm up the FLUX.1-schnell diffusion pipeline.

    Builds the pipeline on CUDA in bfloat16 with a custom quantized VAE,
    runs one warm-up generation so kernels are compiled/cached, and
    returns the ready-to-use pipeline.
    """
    clear()
    dtype, device = torch.bfloat16, "cuda"

    # Load VAE and swap in the project's custom encoder/decoder
    vae = AutoencoderTiny.from_pretrained("manbeast3b/flux.1-schnell-vae-quant1", torch_dtype=dtype)
    vae.encoder = Encoder(16)
    vae.decoder = Decoder(16)

    # Restore trained encoder/decoder weights (prefix-stripped, shape-checked)
    encoder_path = "encoder.pth"
    decoder_path = "decoder.pth"

    if encoder_path:
        # weights_only=True: .pth files are pickles — avoid arbitrary code execution
        encoder_state_dict = torch.load(encoder_path, map_location="cpu", weights_only=True)
        target = vae.encoder.state_dict()
        filtered_state_dict = {
            k.replace('encoder.', ''): v
            for k, v in encoder_state_dict.items()
            if k.replace('encoder.', '') in target
            and v.size() == target[k.replace('encoder.', '')].size()  # skip shape mismatches
        }
        vae.encoder.load_state_dict(filtered_state_dict, strict=False)
        vae.encoder.to(dtype=dtype)

    if decoder_path:
        decoder_state_dict = torch.load(decoder_path, map_location="cpu", weights_only=True)
        target = vae.decoder.state_dict()
        filtered_state_dict = {
            k.replace('decoder.', ''): v
            for k, v in decoder_state_dict.items()
            if k.replace('decoder.', '') in target
            and v.size() == target[k.replace('decoder.', '')].size()
        }
        vae.decoder.load_state_dict(filtered_state_dict, strict=False)
        vae.decoder.to(dtype=dtype)

    # Inference only — no gradients needed
    vae.encoder.requires_grad_(False)
    vae.decoder.requires_grad_(False)

    # Quantize the VAE's Linear layers in-place
    quantizer = ModelQuantization(vae)
    quantizer.quantize_model()

    # Scheduler and text encoders
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CKPT_ID, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(CKPT_ID, subfolder="text_encoder", torch_dtype=dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype)

    # Pre-quantized transformer (int8 weight-only) from the local HF cache
    transformer_model = FluxTransformer2DModel.from_pretrained(
        "/root/.cache/huggingface/hub/models--manbeast3b--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a",
        torch_dtype=dtype,
        use_safetensors=False
    )

    # BUG FIX: the pipeline must be built with DiffusionPipeline.from_pretrained(CKPT_ID, ...);
    # calling the base constructor directly raises a TypeError (it accepts no such
    # kwargs) and would never load the tokenizers the checkpoint provides.
    pipeline = DiffusionPipeline.from_pretrained(
        CKPT_ID,
        scheduler=scheduler,
        transformer=transformer_model,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype,
        load_in_8bit=True,
    ).to(device)

    # channels_last memory format tends to speed up conv-heavy modules on CUDA
    for component in [pipeline.vae, pipeline.text_encoder, pipeline.text_encoder_2, pipeline.transformer]:
        component.to(memory_format=torch.channels_last)

    # Warm-up inference so the first real request is not paying compile/caching cost
    pipeline(
        prompt="modificator, drupaceous, jobbernowl, hereness",
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256
    )

    clear()
    return pipeline
151
 
152
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline):
    """Run one text-to-image generation for *request* and return a PIL image."""
    seed_gen = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=seed_gen,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pt",
    )
    tensor = result.images[0].to(torch.float32)
    # NOTE(review): the in-place x*2-1 remaps the pipeline's [0,1] tensor to
    # [-1,1] before to_pil_image — presumably intentional for this VAE/scoring
    # setup; confirm against the consumer of the returned image.
    return torchvision.transforms.functional.to_pil_image(tensor.mul_(2).sub_(1))