File size: 2,571 Bytes
3adf119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from diffusers import FluxPipeline, AutoencoderKL, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel, CLIPTextConfig, T5Config
import torch
import gc
from PIL import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from time import perf_counter

# Set the PYTORCH_CUDA_ALLOC_CONF option "expandable_segments:True" for this
# process.
# NOTE(review): this assignment runs after `import torch` above; presumably it
# only needs to be in place before the first CUDA allocation -- confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

class EightQuantize:
    """Simulate uniform low-bit quantization of a tensor (quantize, then dequantize).

    The value range is anchored so that it always contains zero: it spans
    ``min(x.min(), 0)`` to ``max(x.max(), 0)`` and is divided into
    ``2**bits - 1`` equal steps. For tensors without negative entries this
    reduces to ``scale = x.max() / qmax`` with levels clipped to ``[0, qmax]``.

    Args:
        bits: Number of bits per quantized value (default 8).
    """

    def __init__(self, bits=8):
        self.bits = bits
        # Largest integer level representable with `bits` bits.
        self.qmax = (1 << bits) - 1

    def __call__(self, x):
        """Return a new tensor holding ``x`` snapped to the quantization grid.

        Args:
            x: Tensor to quantize.

        Returns:
            A tensor of the same shape whose values lie on the grid. Empty and
            all-zero inputs are returned as an unchanged copy.
        """
        # `max()`/`min()` are undefined on an empty tensor; nothing to quantize.
        if x.numel() == 0:
            return x.clone()

        # Include zero in the range so negative values are kept instead of
        # being clipped away, while non-negative inputs keep offset 0.
        lo = torch.clamp(x.min(), max=0)
        hi = torch.clamp(x.max(), min=0)
        scale = (hi - lo) / self.qmax

        # An all-zero tensor gives a zero step; dividing by it would yield NaN.
        if scale == 0:
            return x.clone()

        levels = torch.clip(torch.round((x - lo) / scale), 0, self.qmax)
        return levels * scale + lo




# Model identifier passed to FluxPipeline.from_pretrained() in load_pipeline().
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
# torch dtype the pipeline is loaded with in load_pipeline().
DTYPE = torch.bfloat16
# Number of inference steps requested for each generation in infer().
NUM_STEPS = 4

def empty_cache():
    """Run Python garbage collection and release cached CUDA memory.

    After collecting garbage, the CUDA caching allocator is emptied and its
    peak-memory statistics are reset. When CUDA is not available only the
    garbage collection is performed.
    """
    gc.collect()
    # Without CUDA there is no allocator state to clear, and the reset call
    # below would raise while trying to initialise a device.
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # A single reset is enough: the deprecated reset_max_memory_allocated()
    # merely forwards to reset_peak_memory_stats().
    torch.cuda.reset_peak_memory_stats()

def load_pipeline(*, quantize_vae: bool = False) -> FluxPipeline:
    """Build, configure and warm up the text-to-image pipeline.

    The pipeline is loaded from ``CHECKPOINT`` in ``DTYPE``. Its second text
    encoder, transformer and VAE are switched to ``channels_last`` memory
    format, the VAE is wrapped with ``torch.compile`` and listed in
    ``_exclude_from_cpu_offload``, sequential CPU offload is enabled, and one
    warm-up generation is run before the pipeline is returned.

    Args:
        quantize_vae: When true, pass every VAE parameter through
            ``EightQuantize`` before offloading is enabled. Defaults to
            ``False``, which matches the previously hard-coded behaviour.

    Returns:
        The configured pipeline.
    """
    empty_cache()
    pipe = FluxPipeline.from_pretrained(CHECKPOINT, torch_dtype=DTYPE)

    pipe.text_encoder_2.to(memory_format=torch.channels_last)
    pipe.transformer.to(memory_format=torch.channels_last)

    pipe.vae.to(memory_format=torch.channels_last)
    pipe.vae = torch.compile(pipe.vae)
    pipe._exclude_from_cpu_offload = ["vae"]

    if quantize_vae:
        # Best effort: a quantization failure is reported but must not stop
        # the pipeline from loading.
        try:
            quantizer = EightQuantize()
            with torch.no_grad():
                for param in pipe.vae.parameters():
                    param.data = quantizer(param.data)
        except Exception as e:
            print(f"Quantization warning: {e}")

    pipe.enable_sequential_cpu_offload()

    empty_cache()
    # Warm-up run with the same step count and sequence length that infer() uses.
    pipe("dog", guidance_scale=0.0, max_sequence_length=256, num_inference_steps=NUM_STEPS)
    return pipe

@torch.inference_mode()
def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image.Image:
    """Generate one image for ``request`` with an already-loaded pipeline.

    Args:
        request: Supplies ``prompt``, ``width``, ``height`` and ``seed``.
        _pipeline: Pipeline object that is called to run the generation.

    Returns:
        The first entry of the pipeline result's ``images`` (PIL output is
        requested from the pipeline).
    """
    # Build a seeded CUDA generator only when a seed is given; otherwise the
    # pipeline receives ``None``.
    generator = (
        None
        if request.seed is None
        else Generator(device="cuda").manual_seed(request.seed)
    )

    # empty_cache() also resets the CUDA peak-memory statistics, so no
    # separate reset is needed before generation starts.
    empty_cache()
    result = _pipeline(
        prompt=request.prompt,
        width=request.width,
        height=request.height,
        guidance_scale=0.0,
        generator=generator,
        output_type="pil",
        max_sequence_length=256,
        num_inference_steps=NUM_STEPS,
    )
    return result.images[0]