silencer107 committed on
Commit
abbf98b
·
verified ·
1 Parent(s): 165b1c7

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +5 -2
src/pipeline.py CHANGED
@@ -1,4 +1,6 @@
1
  from torch import Generator
 
 
2
  import torch
3
  from PIL.Image import Image
4
  from pipelines.models import TextToImageRequest
@@ -9,8 +11,10 @@ import os
9
  from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
10
  import torch._dynamo
11
  from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
 
12
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
13
  HOME = os.environ["HOME"]
 
14
  Pipeline = None
15
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
16
 
@@ -22,10 +26,9 @@ def empty_cache():
22
 
23
  def load_pipeline() -> Pipeline:
24
  empty_cache()
25
- dtype, device = torch.bfloat16, "cuda"
26
  text_encoder = CLIPTextModel.from_pretrained(ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
27
  quantize_(text_encoder, int8_weight_only())
28
- vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", torch_dtype=torch.bfloat16)
29
  quantize_(vae, int8_weight_only())
30
  text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16)
31
  quantize_(text_encoder_2, int8_weight_only())
 
1
  from torch import Generator
2
+ from diffusers.image_processor import VaeImageProcessor
3
+ from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
4
  import torch
5
  from PIL.Image import Image
6
  from pipelines.models import TextToImageRequest
 
11
  from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
12
  import torch._dynamo
13
  from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
14
+
15
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
16
  HOME = os.environ["HOME"]
17
+
18
  Pipeline = None
19
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
20
 
 
26
 
27
  def load_pipeline() -> Pipeline:
28
  empty_cache()
 
29
  text_encoder = CLIPTextModel.from_pretrained(ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
30
  quantize_(text_encoder, int8_weight_only())
31
+ vae = AutoencoderTiny.from_pretrained("aifeifei798/taef1", torch_dtype=torch.bfloat16)
32
  quantize_(vae, int8_weight_only())
33
  text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16)
34
  quantize_(text_encoder_2, int8_weight_only())