File size: 3,894 Bytes
372586c
 
 
daa1640
372586c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58b1cf8
61d3add
 
 
 
 
 
 
58b1cf8
372586c
 
 
 
 
 
 
 
 
4318f6f
 
 
 
372586c
 
 
 
2c03636
 
 
ab415dd
 
 
61d3add
 
 
 
ddfef1f
61d3add
 
7d276c3
61d3add
 
7d276c3
61d3add
 
372586c
 
61d3add
 
 
372586c
 
 
 
 
 
7d276c3
 
61d3add
372586c
 
 
 
 
 
 
 
f7d393e
372586c
 
 
 
 
 
 
b08a178
 
372586c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import BitsAndBytesConfig
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_,int8_weight_only
import os
# Settings string for PyTorch's CUDA caching allocator: expandable segments are
# switched off and the garbage-collection threshold is set to 0.01.
# NOTE(review): this is assigned after `import torch`; presumably it still takes
# effect because no CUDA work has happened yet at this point -- confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
# Placeholder name used in the annotations of load_pipeline() and infer().
# It is bound to None, so those annotations carry no real type information.
Pipeline = None

# Define the quantization config
# Alternative 4-bit NF4 configuration, currently disabled and kept for reference:
# nf4_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16
# )
# Active configuration: request 8-bit loading. load_pipeline() passes this as
# `quantization_config` for the transformer and both text encoders.
config = BitsAndBytesConfig(load_in_8bit=True)

# Repository id handed to DiffusionPipeline.from_pretrained() in load_pipeline().
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache() -> None:
    """Collect garbage, release cached CUDA memory, and reset peak memory stats.

    Prints the elapsed time of the flush in seconds.
    """
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments.
    start = time.perf_counter()
    # Collect Python garbage first so tensors that just became unreachable are
    # freed before the allocator is asked to release its cached blocks.
    gc.collect()
    torch.cuda.empty_cache()
    # The deprecated reset_max_memory_allocated() only forwards to
    # reset_peak_memory_stats(), so a single call covers both.
    torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.perf_counter() - start}")


# Absolute path to one pinned snapshot of the "manbeast3b/flux-schnell-int8"
# repository inside the Hugging Face hub cache. load_pipeline() reads the
# "transformer", "text_encoder" and "text_encoder_2" subfolders from here.
# NOTE(review): the path is hard-coded to /root and to a single revision hash;
# loading will fail if the snapshot is not already present at this location.
cache_dir = "/root/.cache/huggingface/hub/models--manbeast3b--flux-schnell-int8/snapshots/eb656b7968de3088ccac7cda876f5782e5a2f721/"


def load_pipeline() -> Pipeline:
    """Assemble the FLUX.1-schnell pipeline, tune it for CUDA, and warm it up.

    The transformer and both text encoders are loaded from the local snapshot
    at ``cache_dir`` using the module-level ``config`` as quantization config.
    All remaining components are resolved from ``ckpt_id``.

    Returns:
        The pipeline with a compiled VAE and sequential CPU offload enabled,
        after two warm-up generations have been run.
    """
    empty_cache()
    dtype = torch.bfloat16

    # NOTE(review): `config` is the BitsAndBytesConfig imported from
    # transformers; confirm that the diffusers transformer loader accepts it.
    transformer = FluxTransformer2DModel.from_pretrained(
        cache_dir, subfolder="transformer", torch_dtype=dtype, quantization_config=config
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        cache_dir, subfolder="text_encoder_2", torch_dtype=dtype, quantization_config=config
    )
    text_encoder = CLIPTextModel.from_pretrained(
        cache_dir, subfolder="text_encoder", torch_dtype=dtype, quantization_config=config
    )

    # Components not passed explicitly are loaded from `ckpt_id`.
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        transformer=transformer,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )

    # Global backend switches: cuDNN autotuning, TF32 matmuls, and a cap on
    # this process's share of device memory.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.cuda.set_per_process_memory_fraction(0.95)

    pipeline.vae.to(memory_format=torch.channels_last)
    pipeline.vae = torch.compile(pipeline.vae)

    # Must be set before offload is enabled so the VAE is left out of it.
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()

    # Two throwaway generations with the same step count and sequence length
    # that infer() uses; presumably this front-loads compilation and other
    # one-time costs so the first real request does not pay them.
    for _ in range(2):
        pipeline(
            prompt="warmup run testing one two three",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for ``request`` with the given pipeline.

    Args:
        request: Supplies ``seed``, ``prompt``, ``height`` and ``width``.
        pipeline: Callable pipeline, as returned by ``load_pipeline()``.

    Returns:
        The first image of the pipeline output, requested as a PIL image.
    """
    # Start peak-memory accounting afresh for this request.
    torch.cuda.reset_peak_memory_stats()

    # A CUDA generator seeded from the request.
    seeded_rng = Generator("cuda").manual_seed(request.seed)

    output = pipeline(
        request.prompt,
        generator=seeded_rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return output.images[0]