slobers commited on
Commit
b9ee776
·
verified ·
1 Parent(s): 448be19

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +13 -2
src/pipeline.py CHANGED
@@ -15,19 +15,30 @@ from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
15
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
16
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
17
  torch._dynamo.config.suppress_errors = True
18
-
 
 
19
  Pipeline = None
20
  ids = "black-forest-labs/FLUX.1-schnell"
21
  Revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
 
 
 
 
 
22
 
23
  def load_pipeline() -> Pipeline:
 
24
  vae = AutoencoderTiny.from_pretrained("slobers/tt1",revision="ec746bf42d91e3335760895281f070df54f2196a", torch_dtype=torch.bfloat16,)
25
  text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
26
  path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
27
  transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False).to(memory_format=torch.channels_last)
28
  pipeline = DiffusionPipeline.from_pretrained(ids, revision=Revision, vae=vae, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16,)
29
  pipeline.to("cuda")
 
 
30
 
 
31
  for _ in range(3):
32
  pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
33
  return pipeline
@@ -35,7 +46,7 @@ def load_pipeline() -> Pipeline:
35
  @torch.no_grad()
36
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
37
  generator = Generator(pipeline.device).manual_seed(request.seed)
38
-
39
  return pipeline(
40
  request.prompt,
41
  generator=generator,
 
15
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
16
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
17
  torch._dynamo.config.suppress_errors = True
18
+ torch.backends.cudnn.benchmark = True
19
+ torch.backends.cuda.matmul.allow_tf32 = True
20
+ torch.cuda.set_per_process_memory_fraction(0.95)
21
  Pipeline = None
22
  ids = "black-forest-labs/FLUX.1-schnell"
23
  Revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
24
+ def empty_cache():
25
+ gc.collect()
26
+ torch.cuda.empty_cache()
27
+ torch.cuda.reset_max_memory_allocated()
28
+ torch.cuda.reset_peak_memory_stats()
29
 
30
  def load_pipeline() -> Pipeline:
31
+ empty_cache()
32
  vae = AutoencoderTiny.from_pretrained("slobers/tt1",revision="ec746bf42d91e3335760895281f070df54f2196a", torch_dtype=torch.bfloat16,)
33
  text_encoder_2 = T5EncoderModel.from_pretrained("city96/t5-v1_1-xxl-encoder-bf16", revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86", torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
34
  path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
35
  transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False).to(memory_format=torch.channels_last)
36
  pipeline = DiffusionPipeline.from_pretrained(ids, revision=Revision, vae=vae, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16,)
37
  pipeline.to("cuda")
38
+ pipeline.vae.enable_tiling()
39
+ pipeline.vae.enable_slicing()
40
 
41
+ empty_cache()
42
  for _ in range(3):
43
  pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
44
  return pipeline
 
46
  @torch.no_grad()
47
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
48
  generator = Generator(pipeline.device).manual_seed(request.seed)
49
+ empty_cache()
50
  return pipeline(
51
  request.prompt,
52
  generator=generator,