sharper740 commited on
Commit
8f9a9c3
·
verified ·
1 Parent(s): a6f3b7c

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. pyproject.toml +40 -0
  3. src/main.py +54 -0
  4. src/pipeline.py +226 -0
  5. uv.lock +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+
pyproject.toml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools >= 75.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "flux-schnell-edge-inference"
7
+ description = "Optimization"
8
+ requires-python = ">=3.10,<3.13"
9
+ version = "8"
10
+ dependencies = [
11
+ "diffusers==0.31.0",
12
+ "transformers==4.46.2",
13
+ "accelerate==1.1.0",
14
+ "omegaconf==2.3.0",
15
+ "torch==2.5.1",
16
+ "protobuf==5.28.3",
17
+ "sentencepiece==0.2.0",
18
+ "edge-maxxing-pipelines @ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines",
19
+ "gitpython>=3.1.43",
20
+ "hf_transfer==0.1.8",
21
+ "torchao==0.6.1",
22
+ ]
23
+
24
+ [[tool.edge-maxxing.models]]
25
+ repository = "black-forest-labs/FLUX.1-schnell"
26
+ revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
27
+
28
+ [[tool.edge-maxxing.models]]
29
+ repository = "city96/t5-v1_1-xxl-encoder-bf16"
30
+ revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"
31
+
32
+ [[tool.edge-maxxing.models]]
33
+ repository = "park234/FLUX1-SCHENELL-INT8"
34
+ revision = "59c2f006f045d9ccdc2e3ab02150b8df0adfafc6"
35
+
36
+
37
+
38
+ [project.scripts]
39
+ start_inference = "main:main"
40
+
src/main.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ from io import BytesIO
3
+ from multiprocessing.connection import Listener
4
+ from os import chmod, remove
5
+ from os.path import abspath, exists
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ from PIL.JpegImagePlugin import JpegImageFile
10
+ from pipelines.models import TextToImageRequest
11
+ from pipeline import load_pipeline, inference
12
+ SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
13
+
14
+
15
+ def at_exit():
16
+ torch.cuda.empty_cache()
17
+
18
+
19
def main():
    """Serve text-to-image requests over a Unix-domain socket.

    Loads the diffusion pipeline once, accepts a single connection on
    SOCKET, then loops: receive a JSON-encoded TextToImageRequest, run
    inference, and send the generated image back as JPEG bytes.  Returns
    cleanly when the peer closes the connection (EOFError).
    """
    atexit.register(at_exit)

    print("Loading pipeline")
    pipeline = load_pipeline()

    print(f"Pipeline loaded, creating socket at '{SOCKET}'")

    # Remove a stale socket file left over from a previous run.
    if exists(SOCKET):
        remove(SOCKET)

    with Listener(SOCKET) as listener:
        # Make the socket connectable by a client running under any uid.
        chmod(SOCKET, 0o777)

        print("Awaiting connections")
        with listener.accept() as connection:
            print("Connected")
            while True:
                try:
                    request = TextToImageRequest.model_validate_json(
                        connection.recv_bytes().decode("utf-8")
                    )
                except EOFError:
                    print("Inference socket exiting")
                    return

                # Bug fix: the original passed a third generator argument,
                # but pipeline.inference() takes only (request, pipeline)
                # and seeds its own generator from request.seed — the
                # extra argument raised a TypeError on every request.
                image = inference(request, pipeline)

                buffer = BytesIO()
                image.save(buffer, format=JpegImageFile.format)

                connection.send_bytes(buffer.getvalue())
51
+
52
+
53
+ if __name__ == '__main__':
54
+ main()
src/pipeline.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import json
4
+ import math
5
+ from typing import Any, Dict
6
+
7
+ import torch
8
+ from torch import Generator
9
+ import torch._dynamo
10
+
11
+ import transformers
12
+ from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
13
+ from huggingface_hub.constants import HF_HUB_CACHE
14
+
15
+ from diffusers import DiffusionPipeline, FluxTransformer2DModel, AutoencoderTiny
16
+ from pipelines.models import TextToImageRequest
17
+
18
+ from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
19
+
20
+ from PIL.Image import Image
21
+
22
+ # -----------------------------------------------------------------------------
23
+ # Environment Configuration & Global Constants
24
+ # -----------------------------------------------------------------------------
25
+ torch._dynamo.config.suppress_errors = True
26
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
27
+ os.environ["TOKENIZERS_PARALLELISM"] = "True"
28
+
29
+ # Identifiers for the diffusion model checkpoint.
30
+ MODEL_ID = "black-forest-labs/FLUX.1-schnell"
31
+ MODEL_REV = "741f7c3ce8b383c54771c7003378a50191e9efe9"
32
+
33
+
34
+ # -----------------------------------------------------------------------------
35
+ # Quantization and Linear Transformation Utilities
36
+ # -----------------------------------------------------------------------------
37
def perform_linear_quant(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    w_scale: float,
    w_zero: int,
    in_scale: float,
    in_zero: int,
    out_scale: float,
    out_zero: int,
) -> torch.Tensor:
    """
    Apply a quantization-aware linear transform.

    Both operands are shifted by their zero-points, passed through a
    bias-free linear layer, and the result is rescaled by the combined
    factor ``(in_scale * w_scale) / out_scale``, re-centred on
    ``out_zero``, rounded, and clamped to the uint8 range [0, 255].

    Parameters:
        input_tensor (torch.Tensor): Quantized input tensor.
        weight_tensor (torch.Tensor): Quantized weight tensor.
        w_scale (float): Scale factor for the weights.
        w_zero (int): Zero-point for the weights.
        in_scale (float): Scale factor for the input.
        in_zero (int): Zero-point for the input.
        out_scale (float): Scale factor for the output.
        out_zero (int): Zero-point for the output.

    Returns:
        torch.Tensor: Float tensor of requantized values in [0, 255].
    """
    # Shift each operand to its zero-centred real-valued form.
    shifted_input = input_tensor.float() - in_zero
    shifted_weight = weight_tensor.float() - w_zero

    # Bias-free linear transform: y = x @ W^T.
    product = torch.nn.functional.linear(shifted_input, shifted_weight)

    # Fold all three scale factors into one multiplier, re-centre on the
    # output zero-point, then round and clamp into the 8-bit range.
    rescale = (in_scale * w_scale) / out_scale
    quantized = torch.round(product * rescale + out_zero)
    return quantized.clamp(0, 255)
76
+
77
+
78
+ # -----------------------------------------------------------------------------
79
+ # Model Initialization Functions
80
+ # -----------------------------------------------------------------------------
81
def initialize_text_encoder() -> T5EncoderModel:
    """
    Load the bf16 T5-XXL text encoder in channels-last memory layout.

    The snapshot revision is pinned so the download is reproducible.
    """
    print("Initializing T5 text encoder...")
    model = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    model = model.to(memory_format=torch.channels_last)
    return model
92
+
93
+
94
def initialize_transformer(transformer_dir: str) -> FluxTransformer2DModel:
    """
    Load the Flux transformer found under *transformer_dir*.

    Loaded with ``use_safetensors=False`` — presumably the pinned
    snapshot ships non-safetensors weights; confirm if the checkpoint
    changes.  The model is returned in channels-last memory layout.
    """
    print("Initializing Flux transformer...")
    model = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    model = model.to(memory_format=torch.channels_last)
    return model
105
+
106
+
107
+ # -----------------------------------------------------------------------------
108
+ # Pipeline Construction
109
+ # -----------------------------------------------------------------------------
110
def load_pipeline() -> DiffusionPipeline:
    """
    Construct the FLUX.1-schnell diffusion pipeline.

    Combines the locally cached INT8 transformer snapshot with the bf16
    T5 encoder, moves everything to CUDA, enables VAE tiling, runs a
    best-effort dummy quantization pass over the transformer's Linear
    layers, and performs several warm-up generations so later requests
    hit steady-state performance.

    Returns:
        DiffusionPipeline: The configured, warmed-up diffusion pipeline.
    """
    # Build the path to the transformer snapshot inside the HF cache.
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--park234--FLUX1-SCHENELL-INT8/snapshots/59c2f006f045d9ccdc2e3ab02150b8df0adfafc6",
    )
    transformer_model = initialize_transformer(transformer_dir)

    encoder = initialize_text_encoder()

    pipeline_instance = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=MODEL_REV,
        transformer=transformer_model,
        text_encoder_2=encoder,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Bug fix: the original called pipeline.vae.enable_vae_tiling() (the
    # method lives on the pipeline, not the VAE) and only after the
    # quantization loop, which had already aborted on an AttributeError —
    # so tiling was never enabled.  Enable it first, in its own guard.
    try:
        pipeline_instance.enable_vae_tiling()
    except Exception as err:
        print("Warning: enabling VAE tiling failed:", err)

    try:
        # Best-effort dummy quantization pass over the transformer's
        # Linear submodules.  Bug fix: `mod.__classname__` is not a real
        # attribute (and the transformer has no `.layers`); identify
        # Linear layers with isinstance over Module.modules() instead.
        linear_modules = [
            mod for mod in pipeline_instance.transformer.modules()
            if isinstance(mod, torch.nn.Linear)
        ]
        for mod in linear_modules:
            # Dummy input shaped to the layer and placed on its device so
            # the demonstration pass actually executes.
            dummy_input = torch.randn(1, mod.in_features, device=mod.weight.device)
            _ = perform_linear_quant(
                input_tensor=dummy_input,
                weight_tensor=mod.weight,
                w_scale=1e-1,
                w_zero=0,
                in_scale=1e-1,
                in_zero=0,
                out_scale=1e-1,
                out_zero=0,
            )
    except Exception as err:
        print("Warning: Quantization adjustments failed:", err)

    # Run several warm-up inferences (kernel compilation / caching).
    warmup_prompt = "unrectangular, uneucharistical, pouchful, uplay, person"
    for _ in range(3):
        _ = pipeline_instance(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline_instance
174
+
175
+
176
+ # -----------------------------------------------------------------------------
177
+ # Inference Function
178
+ # -----------------------------------------------------------------------------
179
@torch.no_grad()
def inference(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator | None = None) -> Image:
    """
    Generate an image for the given text-to-image request.

    Clears the CUDA allocator cache, then runs the pipeline for 4 steps
    with guidance disabled at the requested resolution.

    Parameters:
        request (TextToImageRequest): Contains prompt, height, width, and seed.
        pipeline (DiffusionPipeline): The diffusion pipeline to run inference.
        generator (Generator | None): Optional pre-seeded RNG.  When omitted,
            a generator on the pipeline's device is seeded from request.seed.
            (Backward-compatible addition: main.py passes a seeded CUDA
            generator as a third positional argument, which the original
            two-parameter signature rejected with a TypeError.)

    Returns:
        Image: The generated PIL image.
    """
    torch.cuda.empty_cache()
    if generator is None:
        generator = Generator(pipeline.device).manual_seed(request.seed)
    output = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil"
    )
    return output.images[0]
207
+
208
+
209
+ # -----------------------------------------------------------------------------
210
+ # Example Main Flow (Optional)
211
+ # -----------------------------------------------------------------------------
212
+ if __name__ == "__main__":
213
+ # Construct the diffusion pipeline.
214
+ diffusion_pipe = load_pipeline()
215
+
216
+ # Create a sample request (assuming TextToImageRequest is appropriately defined).
217
+ sample_request = TextToImageRequest(
218
+ prompt="a scenic view of mountains at sunrise",
219
+ height=512,
220
+ width=512,
221
+ seed=1234
222
+ )
223
+
224
+ # Generate an image.
225
+ result_image = inference(sample_request, diffusion_pipe)
226
+ # Here, you may save or display 'result_image' as desired.
uv.lock ADDED
The diff for this file is too large to render. See raw diff