TrendForge committed on
Commit
524d6b8
·
verified ·
1 Parent(s): 8e3f1b7

Initial commit with folder contents

Browse files
Files changed (1) hide show
  1. src/pipeline.py +61 -55
src/pipeline.py CHANGED
@@ -1,4 +1,3 @@
1
- # Coding
2
  import os
3
  import torch
4
  import torch._dynamo
@@ -6,31 +5,27 @@ import gc
6
  from PIL.Image import Image
7
  from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
8
  from huggingface_hub.constants import HF_HUB_CACHE
9
- from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
10
-
11
-
12
- from PIL.Image import Image
13
- from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
 
14
  from pipelines.models import TextToImageRequest
15
-
16
- from PIL.Image import Image
17
  from torch import Generator
18
- from diffusers import FluxTransformer2DModel, DiffusionPipeline
19
-
20
 
 
21
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
 
22
  torch._dynamo.config.suppress_errors = True
23
-
24
- os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
25
-
26
-
27
-
28
  Pipeline = None
 
 
29
  CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
30
  REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
31
 
32
- class QuantativeAnalysis:
33
 
 
34
  def __init__(self, model, num_bins=256, scale_ratio=1.0):
35
  self.model = model
36
  self.num_bins = num_bins
@@ -42,10 +37,12 @@ class QuantativeAnalysis:
42
  with torch.no_grad():
43
  param_min = param.min()
44
  param_max = param.max()
 
45
  if param_range > 0:
46
- params = 0.8*param_min + 0.2*param_max
47
  return self.model
48
 
 
49
  class AttentionQuant:
50
  def __init__(self, model, att_config):
51
  self.model = model
@@ -58,71 +55,80 @@ class AttentionQuant:
58
  if layer_name in self.att_config:
59
  num_bins, scale_factor = self.att_config[layer_name]
60
  with torch.no_grad():
61
- # Normalize weights, apply binning, and rescale
62
  param_min = param.min()
63
  param_max = param.max()
64
  param_range = param_max - param_min
65
-
66
  if param_range > 0:
67
  normalized = (param - param_min) / param_range
68
  binned = torch.round(normalized * (num_bins - 1)) / (num_bins - 1)
69
- rescaled = binned * param_range + param_mins
70
- params.data.copy_(rescaled * scale_factor)
71
  else:
72
- params.data.zero_()
73
-
74
  return self.model
75
 
76
- def load_pipeline() -> Pipeline:
77
-
78
- __t5_model = T5EncoderModel.from_pretrained("TrendForge/extra1manQ1",
79
- revision = "d302b6e39214ed4532be34ec337f93c7eef3eaa6",
80
- torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
81
 
 
 
 
 
 
 
 
82
  __text_encoder_2 = __t5_model
83
 
84
- base_vae = AutoencoderTiny.from_pretrained("TrendForge/extra2manQ2",
85
- revision="cef012d2db2f5a006567e797a0b9130aea5449c1",
86
- torch_dtype=torch.bfloat16)
87
-
 
 
88
 
 
89
  path = os.path.join(HF_HUB_CACHE, "models--TrendForge--extra0manQ0/snapshots/dc2cda167b8f53792a98020a3ef2f21808b09bb4")
90
- base_trans = FluxTransformer2DModel.from_pretrained(path,
91
- torch_dtype=torch.bfloat16,
92
- use_safetensors=False).to(memory_format=torch.channels_last)
93
-
94
  try:
95
  att_config = {
96
  "transformer_blocks.15.attn.norm_added_k.weight": (64, 0.1),
97
  "transformer_blocks.15.attn.norm_added_q.weight": (64, 0.1),
98
  "transformer_blocks.15.attn.norm_added_v.weight": (64, 0.1)
99
  }
100
- transformer = AttentionQuant(transformer, att_config).apply()
101
- except:
102
-
103
  transformer = base_trans
104
 
105
- pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
106
- revision=REVISION,
107
- vae=base_vae,
108
- transformer=transformer,
109
- text_encoder_2=__text_encoder_2,
110
- torch_dtype=torch.bfloat16)
 
 
 
111
  pipeline.to("cuda")
112
 
113
- for _warmup_batch in range(3):
114
- pipeline(prompt="forswearer, skullcap, Juglandales, bluelegs, cunila, carbro, Ammonites",
115
- width=1024,
116
- height=1024,
117
- guidance_scale=0.0,
118
- num_inference_steps=4,
119
- max_sequence_length=256)
 
 
 
 
120
  return pipeline
121
 
 
122
  @torch.no_grad()
123
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
124
  generator = Generator(pipeline.device).manual_seed(request.seed)
125
-
126
  return pipeline(
127
  request.prompt,
128
  generator=generator,
@@ -130,5 +136,5 @@ def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
130
  num_inference_steps=4,
131
  max_sequence_length=256,
132
  height=request.height,
133
- width=request.width,
134
- ).images[0]
 
 
1
  import os
2
  import torch
3
  import torch._dynamo
 
5
  from PIL.Image import Image
6
  from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
7
  from huggingface_hub.constants import HF_HUB_CACHE
8
+ from transformers import (
9
+ T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
10
+ )
11
+ from diffusers import (
12
+ FluxPipeline, AutoencoderKL, AutoencoderTiny, FluxTransformer2DModel, DiffusionPipeline
13
+ )
14
  from pipelines.models import TextToImageRequest
 
 
15
  from torch import Generator
 
 
16
 
17
+ # Set environment variables
18
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
19
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
20
  torch._dynamo.config.suppress_errors = True
 
 
 
 
 
21
  Pipeline = None
22
+
23
+ # Define constants
24
  CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
25
  REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
26
 
 
27
 
28
+ class QuantativeAnalysis:
29
  def __init__(self, model, num_bins=256, scale_ratio=1.0):
30
  self.model = model
31
  self.num_bins = num_bins
 
37
  with torch.no_grad():
38
  param_min = param.min()
39
  param_max = param.max()
40
+ param_range = param_max - param_min
41
  if param_range > 0:
42
+ params = 0.8 * param_min + 0.2 * param_max
43
  return self.model
44
 
45
+
46
  class AttentionQuant:
47
  def __init__(self, model, att_config):
48
  self.model = model
 
55
  if layer_name in self.att_config:
56
  num_bins, scale_factor = self.att_config[layer_name]
57
  with torch.no_grad():
 
58
  param_min = param.min()
59
  param_max = param.max()
60
  param_range = param_max - param_min
61
+
62
  if param_range > 0:
63
  normalized = (param - param_min) / param_range
64
  binned = torch.round(normalized * (num_bins - 1)) / (num_bins - 1)
65
+ rescaled = binned * param_range + param_min
66
+ param.data.copy_(rescaled * scale_factor)
67
  else:
68
+ param.data.zero_()
 
69
  return self.model
70
 
 
 
 
 
 
71
 
72
+ def load_pipeline() -> Pipeline:
73
+ # Load T5 model
74
+ __t5_model = T5EncoderModel.from_pretrained(
75
+ "TrendForge/extra1manQ1",
76
+ revision="d302b6e39214ed4532be34ec337f93c7eef3eaa6",
77
+ torch_dtype=torch.bfloat16
78
+ ).to(memory_format=torch.channels_last)
79
  __text_encoder_2 = __t5_model
80
 
81
+ # Load VAE
82
+ base_vae = AutoencoderTiny.from_pretrained(
83
+ "TrendForge/extra2manQ2",
84
+ revision="cef012d2db2f5a006567e797a0b9130aea5449c1",
85
+ torch_dtype=torch.bfloat16
86
+ )
87
 
88
+ # Load Transformer Model
89
  path = os.path.join(HF_HUB_CACHE, "models--TrendForge--extra0manQ0/snapshots/dc2cda167b8f53792a98020a3ef2f21808b09bb4")
90
+ base_trans = FluxTransformer2DModel.from_pretrained(
91
+ path, torch_dtype=torch.bfloat16, use_safetensors=False
92
+ ).to(memory_format=torch.channels_last)
93
+
94
  try:
95
  att_config = {
96
  "transformer_blocks.15.attn.norm_added_k.weight": (64, 0.1),
97
  "transformer_blocks.15.attn.norm_added_q.weight": (64, 0.1),
98
  "transformer_blocks.15.attn.norm_added_v.weight": (64, 0.1)
99
  }
100
+ transformer = AttentionQuant(base_trans, att_config).apply()
101
+ except Exception:
 
102
  transformer = base_trans
103
 
104
+ # Load pipeline
105
+ pipeline = DiffusionPipeline.from_pretrained(
106
+ CHECKPOINT,
107
+ revision=REVISION,
108
+ vae=base_vae,
109
+ transformer=transformer,
110
+ text_encoder_2=__text_encoder_2,
111
+ torch_dtype=torch.bfloat16
112
+ )
113
  pipeline.to("cuda")
114
 
115
+ # Warmup
116
+ for _ in range(3):
117
+ pipeline(
118
+ prompt="forswearer, skullcap, Juglandales, bluelegs, cunila, carbro, Ammonites",
119
+ width=1024,
120
+ height=1024,
121
+ guidance_scale=0.0,
122
+ num_inference_steps=4,
123
+ max_sequence_length=256
124
+ )
125
+
126
  return pipeline
127
 
128
+
129
  @torch.no_grad()
130
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
131
  generator = Generator(pipeline.device).manual_seed(request.seed)
 
132
  return pipeline(
133
  request.prompt,
134
  generator=generator,
 
136
  num_inference_steps=4,
137
  max_sequence_length=256,
138
  height=request.height,
139
+ width=request.width
140
+ ).images[0]