BrenL committed on
Commit
ce03777
·
verified ·
1 Parent(s): 97106df

Initial commit with folder contents

Browse files
Files changed (1) hide show
  1. src/pipeline.py +79 -30
src/pipeline.py CHANGED
@@ -1,49 +1,99 @@
 
1
  import torch
2
  import torch._dynamo
3
- import gc
4
- import os
5
-
6
- from huggingface_hub.constants import HF_HUB_CACHE
7
- from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
8
- from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
9
  from PIL.Image import Image
 
 
 
 
 
 
 
10
  from pipelines.models import TextToImageRequest
11
- from torch import Generator
12
- from diffusers import FluxTransformer2DModel, DiffusionPipeline
13
- from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
14
 
15
- os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
 
16
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
17
  torch._dynamo.config.suppress_errors = True
18
 
19
- Pipeline = None
20
- ids = "black-forest-labs/FLUX.1-schnell"
21
- Revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
 
 
 
 
 
22
 
23
- def load_pipeline() -> Pipeline:
24
- ttimagemodel = "BrenL/extra1IMOO1"
25
- ttimagerevision = "3e33f01cda8a8c207218c2d31853fdc08bebd38f"
26
 
27
- vae = AutoencoderKL.from_pretrained(ids,revision=Revision, subfolder="vae", local_files_only=True, torch_dtype=torch.bfloat16,)
 
 
 
 
 
 
 
 
 
 
 
28
  quantize_(vae, int8_weight_only())
29
- text_encoder_2 = T5EncoderModel.from_pretrained("BrenL/extra2IMOO2", revision = "f7538acf69d8b71458542b22257de6508850ab6d", torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
30
- path = os.path.join(HF_HUB_CACHE, "models--BrenL--extra0IMOO0/snapshots/422ee1f0f85ef1b035f00449540b254df85cd3a6")
31
- transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False).to(memory_format=torch.channels_last)
32
- pipeline = DiffusionPipeline.from_pretrained(ids, revision=Revision, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  pipeline.to("cuda")
34
 
 
35
  for _ in range(2):
36
- pipeline(prompt="satiety, unwitherable, Pygmy, ramlike, Curtis, fingerstone, rewhisper", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
 
 
 
 
 
 
 
 
37
  return pipeline
38
 
 
39
  @torch.no_grad()
40
- def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
41
- generator = Generator(pipeline.device).manual_seed(request.seed)
42
-
43
- try:
44
- prompt = request.prompt
45
- except Exception as e:
46
- prompt = "satiety, unwitherable, Pygmy, ramlike, Curtis, fingerstone, rewhisper"
 
 
 
 
 
 
 
47
 
48
  return pipeline(
49
  prompt,
@@ -54,4 +104,3 @@ def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
54
  height=request.height,
55
  width=request.width,
56
  ).images[0]
57
-
 
1
+ import os
2
  import torch
3
  import torch._dynamo
 
 
 
 
 
 
4
  from PIL.Image import Image
5
+ from huggingface_hub.constants import HF_HUB_CACHE
6
+ from transformers import T5EncoderModel
7
+ from diffusers import (
8
+ AutoencoderKL,
9
+ DiffusionPipeline,
10
+ FluxTransformer2DModel,
11
+ )
12
  from pipelines.models import TextToImageRequest
13
+ from torchao.quantization import quantize_, int8_weight_only
 
 
14
 
15
# Environment setup
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is only honored if set before the first
# CUDA allocation — confirm this module is imported before any CUDA use.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
# Allow tokenizer thread pools even after fork (silences the HF warning).
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# If torch.compile / dynamo tracing fails, fall back to eager instead of raising.
torch._dynamo.config.suppress_errors = True

# Constants
# Base model repo and pinned revision for reproducible loads.
IDS = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
# NOTE(review): TT_IMAGE_MODEL / TT_IMAGE_REVISION are not referenced anywhere
# in this file's visible code — confirm they are used elsewhere or remove.
TT_IMAGE_MODEL = "BrenL/extra1IMOO1"
TT_IMAGE_REVISION = "3e33f01cda8a8c207218c2d31853fdc08bebd38f"
# Replacement T5 text encoder repo and pinned revision (used in load_pipeline).
EXTRA_TEXT_ENCODER = "BrenL/extra2IMOO2"
EXTRA_TEXT_REVISION = "f7538acf69d8b71458542b22257de6508850ab6d"
# Fallback/warm-up prompt (deliberately rare tokens).
DEFAULT_PROMPT = "satiety, unwitherable, Pygmy, ramlike, Curtis, fingerstone, rewhisper"
28
 
 
 
 
29
 
30
def load_pipeline() -> DiffusionPipeline:
    """Load and prepare the FLUX.1-schnell diffusion pipeline.

    Builds an int8-quantized VAE, a replacement T5 text encoder, and a
    locally cached transformer, assembles them into a DiffusionPipeline on
    CUDA, and warms the pipeline up with two dummy generations.

    Returns:
        DiffusionPipeline: ready-to-use pipeline resident on "cuda".
    """
    # Load components
    vae = AutoencoderKL.from_pretrained(
        IDS,
        revision=REVISION,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    # Int8 weight-only quantization to cut VAE memory.
    quantize_(vae, int8_weight_only())

    text_encoder_2 = T5EncoderModel.from_pretrained(
        EXTRA_TEXT_ENCODER,
        revision=EXTRA_TEXT_REVISION,
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)

    # Transformer is loaded straight from a pinned snapshot in the local HF cache.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--BrenL--extra0IMOO0/snapshots/422ee1f0f85ef1b035f00449540b254df85cd3a6",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False
    ).to(memory_format=torch.channels_last)

    # Build pipeline.
    # FIX: pass the quantized VAE explicitly. Previously `vae` was built and
    # quantized but never handed to from_pretrained, so the pipeline silently
    # reloaded its own unquantized VAE and the quantization (and its memory
    # saving) was dead work.
    pipeline = DiffusionPipeline.from_pretrained(
        IDS,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Warm-up: two throwaway generations so compilation/caching costs are
    # paid here rather than on the first real request.
    for _ in range(2):
        pipeline(
            prompt=DEFAULT_PROMPT,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline
80
 
81
+
82
  @torch.no_grad()
83
+ def infer(request: TextToImageRequest, pipeline: DiffusionPipeline) -> Image:
84
+ """
85
+ Perform inference using the diffusion pipeline.
86
+
87
+ Args:
88
+ request (TextToImageRequest): The input request containing parameters like prompt, seed, height, and width.
89
+ pipeline (DiffusionPipeline): The diffusion pipeline to use for inference.
90
+
91
+ Returns:
92
+ Image: Generated image.
93
+ """
94
+ generator = torch.Generator(pipeline.device).manual_seed(request.seed)
95
+
96
+ prompt = request.prompt if hasattr(request, "prompt") else DEFAULT_PROMPT
97
 
98
  return pipeline(
99
  prompt,
 
104
  height=request.height,
105
  width=request.width,
106
  ).images[0]