manbeast3b commited on
Commit
f34ee03
·
verified ·
1 Parent(s): af3affd

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +14 -11
src/pipeline.py CHANGED
@@ -15,7 +15,7 @@ Pipeline = None
15
  # Configure CUDA settings
16
  torch.backends.cudnn.benchmark = True
17
  torch.backends.cuda.matmul.allow_tf32 = True
18
- torch.cuda.set_per_process_memory_fraction(0.999)
19
 
20
  class BasicQuantization:
21
  def __init__(self, bits=1):
@@ -59,14 +59,14 @@ def load_pipeline() -> Pipeline:
59
  quantizer = ModelQuantization(vae)
60
  quantizer.quantize_model()
61
 
62
- text_encoder_2 = T5EncoderModel.from_pretrained(
63
- "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
64
- )
65
 
66
  # Initialize pipeline
67
  pipeline = FluxPipeline.from_pretrained(
68
  ckpt_id,
69
- text_encoder_2=text_encoder_2,
70
  vae=vae,
71
  torch_dtype=dtype
72
  )
@@ -77,17 +77,20 @@ def load_pipeline() -> Pipeline:
77
  component.to(memory_format=torch.channels_last)
78
 
79
  # Compile and configure pipeline
80
- pipeline.vae = torch.compile(pipeline.vae, fullgraph=True, dynamic=False, mode="max-autotune")
81
  pipeline._exclude_from_cpu_offload = ["vae"]
82
  pipeline.enable_sequential_cpu_offload()
83
-
 
 
 
84
  # Warmup run
85
  empty_cache()
86
- for _ in range(3):
87
  pipeline(
88
  prompt="posteroexternal, eurythmical, inspection, semicotton, specification, Mercatorial, ethylate, misprint",
89
- width=1024,
90
- height=1024,
91
  guidance_scale=0.0,
92
  num_inference_steps=4,
93
  max_sequence_length=256
@@ -111,7 +114,7 @@ def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
111
  empty_cache()
112
  _inference_count = 0
113
 
114
- torch.cuda.reset_peak_memory_stats()
115
  generator = Generator("cuda").manual_seed(request.seed)
116
  return pipeline(
117
  prompt=request.prompt,
 
15
  # Configure CUDA settings
16
  torch.backends.cudnn.benchmark = True
17
  torch.backends.cuda.matmul.allow_tf32 = True
18
+ torch.cuda.set_per_process_memory_fraction(0.99)
19
 
20
  class BasicQuantization:
21
  def __init__(self, bits=1):
 
59
  quantizer = ModelQuantization(vae)
60
  quantizer.quantize_model()
61
 
62
+ # text_encoder_2 = T5EncoderModel.from_pretrained(
63
+ # "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=torch.bfloat16
64
+ # )
65
 
66
  # Initialize pipeline
67
  pipeline = FluxPipeline.from_pretrained(
68
  ckpt_id,
69
+ # text_encoder_2=text_encoder_2,
70
  vae=vae,
71
  torch_dtype=dtype
72
  )
 
77
  component.to(memory_format=torch.channels_last)
78
 
79
  # Compile and configure pipeline
80
+ pipeline.vae = torch.compile(pipeline.vae, mode="reduce-overhead")
81
  pipeline._exclude_from_cpu_offload = ["vae"]
82
  pipeline.enable_sequential_cpu_offload()
83
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
+ pipeline.vae.encoder.to(device)
85
+ pipeline.vae.decoder.to(device)
86
+
87
  # Warmup run
88
  empty_cache()
89
+ for _ in range(2):
90
  pipeline(
91
  prompt="posteroexternal, eurythmical, inspection, semicotton, specification, Mercatorial, ethylate, misprint",
92
+ width=1480,
93
+ height=1480,
94
  guidance_scale=0.0,
95
  num_inference_steps=4,
96
  max_sequence_length=256
 
114
  empty_cache()
115
  _inference_count = 0
116
 
117
+ # torch.cuda.reset_peak_memory_stats()
118
  generator = Generator("cuda").manual_seed(request.seed)
119
  return pipeline(
120
  prompt=request.prompt,