slobers commited on
Commit
c01aef2
·
verified ·
1 Parent(s): 7265477

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +25 -32
src/pipeline.py CHANGED
@@ -1,47 +1,40 @@
 
 
 
 
 
1
  import gc
2
  import os
3
- from typing import TypeAlias
4
-
5
- import torch
6
  from PIL.Image import Image
7
- from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, AutoencoderTiny
8
- from huggingface_hub.constants import HF_HUB_CACHE
9
  from pipelines.models import TextToImageRequest
10
  from torch import Generator
11
- from transformers import T5EncoderModel, CLIPTextModel
12
-
13
- Pipeline: TypeAlias = FluxPipeline
14
 
15
- CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
16
- REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
 
17
 
 
 
 
18
 
19
  def load_pipeline() -> Pipeline:
20
- text_encoder = CLIPTextModel.from_pretrained(CHECKPOINT, revision=REVISION, subfolder="text_encoder", local_files_only=True, torch_dtype=torch.bfloat16,)
21
-
22
- path2 = os.path.join(HF_HUB_CACHE, "models--city96--t5-v1_1-xxl-encoder-bf16/snapshots/1b9c856aadb864af93c1dcdc226c2774fa67bc86")
23
-
24
- text_encoder_2 = T5EncoderModel.from_pretrained(path2, torch_dtype=torch.bfloat16,)
25
-
26
- pathV = os.path.join(HF_HUB_CACHE, "models--madebyollin--taef1/snapshots/5463ee684fd9131a724bea777a2f50d89b0b6b24")
27
-
28
- vae = AutoencoderTiny.from_pretrained(pathV, torch_dtype=torch.bfloat16,)
29
-
30
- pathT = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
31
-
32
- transformer = FluxTransformer2DModel.from_pretrained(pathT, torch_dtype=torch.bfloat16, use_safetensors=False,)
33
-
34
- pipeline = FluxPipeline.from_pretrained(CHECKPOINT, revision=REVISION, local_files_only=True, text_encoder=text_encoder, text_encoder_2=text_encoder_2, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16,).to("cuda")
35
-
36
- pipeline("")
37
-
38
  return pipeline
39
 
 
40
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
41
- gc.collect()
42
- torch.cuda.empty_cache()
43
- torch.cuda.reset_peak_memory_stats()
44
-
45
  generator = Generator(pipeline.device).manual_seed(request.seed)
46
 
47
  return pipeline(
 
1
+ #2
2
+ from huggingface_hub.constants import HF_HUB_CACHE
3
+ from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
4
+ import torch
5
+ import torch._dynamo
6
  import gc
7
  import os
8
+ from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
 
 
9
  from PIL.Image import Image
 
 
10
  from pipelines.models import TextToImageRequest
11
  from torch import Generator
12
+ from diffusers import FluxTransformer2DModel, DiffusionPipeline
13
+ from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
 
14
 
15
+ os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "True"
17
+ torch._dynamo.config.suppress_errors = True
18
 
19
+ Pipeline = None
20
+ ids = "black-forest-labs/FLUX.1-schnell"
21
+ Revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
22
 
23
  def load_pipeline() -> Pipeline:
24
+ vae = AutoencoderKL.from_pretrained(ids,revision=Revision, subfolder="vae", local_files_only=True, torch_dtype=torch.bfloat16,)
25
+ quantize_(vae, int8_weight_only())
26
+ text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
27
+ path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
28
+ transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False).to(memory_format=torch.channels_last)
29
+ pipeline = DiffusionPipeline.from_pretrained(ids, revision=Revision, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16,)
30
+ pipeline.to("cuda")
31
+
32
+ for _ in range(3):
33
+ pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
 
 
 
 
 
 
 
 
34
  return pipeline
35
 
36
+ @torch.no_grad()
37
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
 
 
 
 
38
  generator = Generator(pipeline.device).manual_seed(request.seed)
39
 
40
  return pipeline(