slobers commited on
Commit
a824051
·
verified ·
1 Parent(s): f6d42fb

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +8 -21
src/pipeline.py CHANGED
@@ -1,51 +1,38 @@
1
- #5.1
2
  from huggingface_hub.constants import HF_HUB_CACHE
3
  from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
4
  import torch
5
  import torch._dynamo
6
  import gc
7
  import os
8
- from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
9
  from PIL.Image import Image
10
  from pipelines.models import TextToImageRequest
11
  from torch import Generator
12
  from diffusers import FluxTransformer2DModel, DiffusionPipeline
13
- from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
14
-
15
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
16
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
17
  torch._dynamo.config.suppress_errors = True
18
-
19
  Pipeline = None
20
  ids = "slobers/Flux.1.Schnella"
21
  Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"
22
-
23
- def empty_cache():
24
- gc.collect()
25
- torch.cuda.empty_cache()
26
- torch.cuda.reset_max_memory_allocated()
27
- torch.cuda.reset_peak_memory_stats()
28
-
29
  def load_pipeline() -> Pipeline:
30
- empty_cache()
31
  path = os.path.join(HF_HUB_CACHE, "models--slobers--Flux.1.Schnella/snapshots/e34d670e44cecbbc90e4962e7aada2ac5ce8b55b/transformer")
32
  transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
33
- pipeline = FluxPipeline.from_pretrained(ids, revision=Revision, transformer=transformer, local_files_only=True, torch_dtype=torch.bfloat16,)
 
34
  pipeline.to("cuda")
35
- pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True, dynamic=True)
36
- for _ in range(3):
37
- pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=5.0, num_inference_steps=4, max_sequence_length=256)
38
  return pipeline
39
- empty_cache()
40
-
41
  @torch.no_grad()
42
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
43
  generator = Generator(pipeline.device).manual_seed(request.seed)
44
-
45
  return pipeline(
46
  request.prompt,
47
  generator=generator,
48
- guidance_scale=6.5,
49
  num_inference_steps=4,
50
  max_sequence_length=256,
51
  height=request.height,
 
1
+ #5.2
2
  from huggingface_hub.constants import HF_HUB_CACHE
3
  from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
4
  import torch
5
  import torch._dynamo
6
  import gc
7
  import os
8
+ from diffusers import FluxPipeline, AutoencoderKL
9
  from PIL.Image import Image
10
  from pipelines.models import TextToImageRequest
11
  from torch import Generator
12
  from diffusers import FluxTransformer2DModel, DiffusionPipeline
13
+ from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
 
14
  os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
15
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
16
  torch._dynamo.config.suppress_errors = True
 
17
  Pipeline = None
18
  ids = "slobers/Flux.1.Schnella"
19
  Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"
 
 
 
 
 
 
 
20
  def load_pipeline() -> Pipeline:
 
21
  path = os.path.join(HF_HUB_CACHE, "models--slobers--Flux.1.Schnella/snapshots/e34d670e44cecbbc90e4962e7aada2ac5ce8b55b/transformer")
22
  transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
23
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", revision="741f7c3ce8b383c54771c7003378a50191e9efe9", subfolder="vae", torch_dtype=torch.bfloat16)
24
+ pipeline = FluxPipeline.from_pretrained(ids, revision=Revision, transformer=transformer, vae=vae, local_files_only=True, torch_dtype=torch.bfloat16)
25
  pipeline.to("cuda")
26
+ pipeline = apply_cache_on_pipe(pipeline, residual_diff_threshold=0.888)
27
+ pipeline("")
 
28
  return pipeline
 
 
29
  @torch.no_grad()
30
  def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
31
  generator = Generator(pipeline.device).manual_seed(request.seed)
 
32
  return pipeline(
33
  request.prompt,
34
  generator=generator,
35
+ guidance_scale=0.0,
36
  num_inference_steps=4,
37
  max_sequence_length=256,
38
  height=request.height,