File size: 3,865 Bytes
9cfa421
 
 
0899c61
 
9cfa421
0899c61
9cfa421
0899c61
 
 
9cfa421
 
 
 
0899c61
c62559e
a80c523
9cfa421
0899c61
9cfa421
0899c61
 
 
 
9cfa421
0899c61
 
 
9cfa421
0899c61
9cfa421
0899c61
a80c523
9cfa421
 
 
a80c523
9cfa421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f5fd4a
 
 
 
 
 
 
 
0899c61
9cfa421
97b5a44
 
 
 
 
 
 
 
0899c61
97b5a44
 
0899c61
97b5a44
0899c61
65f0854
0899c61
9cfa421
 
 
 
0899c61
9cfa421
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only
from diffusers.image_processor import VaeImageProcessor
# Placeholder bound to None; it is only referenced in annotations below.
Pipeline = None
import os

# Hugging Face repository id the weights are loaded from.
MODEL_ID: str = "black-forest-labs/FLUX.1-schnell"
# File a TorchScript-traced VAE decoder would be cached to.
traced_vae_decode_path: str = "traced_vae_decode.pt"
def empty_cache():
    """Free unreferenced memory and reset CUDA peak-memory statistics.

    Runs the Python garbage collector; when CUDA is available it also
    releases cached CUDA allocator blocks and resets the peak-memory
    counters. Prints how long the flush took.
    """
    start = time.perf_counter()  # monotonic clock, suited to interval timing
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() resets every peak counter, including
        # max_memory_allocated; the old reset_max_memory_allocated() is
        # deprecated and merely forwards to this call.
        torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.perf_counter() - start}")



def load_pipeline() -> Pipeline:
    """Load the FLUX.1-schnell pipeline with an int8-weight VAE and warm it up.

    The VAE is loaded separately, weight-quantized to int8 via torchao, and
    injected into the pipeline. Sequential CPU offload is then enabled.
    Finally, two full 1024x1024, 4-step generations are run before the
    pipeline is returned.

    Returns:
        The ready-to-use diffusers pipeline.
    """
    empty_cache()
    dtype = torch.bfloat16
    vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=dtype)
    # torchao int8 weight-only quantization of the VAE's weights.
    # NOTE(review): confirm the quantized tensors behave correctly under the
    # sequential CPU offload hooks enabled below.
    quantize_(vae, int8_weight_only())
    pipeline = DiffusionPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=dtype)
    # Per diffusers docs, this keeps submodules on CPU and moves them to the
    # GPU only while they run, trading speed for lower GPU memory.
    pipeline.enable_sequential_cpu_offload()
    # Warm-up generations, presumably so the first real request does not pay
    # one-time initialization costs.
    for _ in range(2):
        empty_cache()
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline

def trace_and_save_vae_decoder(vae, latents):
    """Trace the VAE decoder with TorchScript, cache it to disk, and return it.

    Args:
        vae: VAE whose ``decode(z, return_dict=False)`` returns a tuple whose
            first element is the decoded image tensor.
        latents: Example latent tensor used for tracing. The traced graph only
            records the operations executed for this example input.

    Returns:
        A callable used as ``fn(latents, return_dict=False)[0]``, the same way
        ``vae.decode`` is used in this file. If tracing fails, the untraced
        ``vae.decode`` is returned instead.
    """

    def _decode(z):
        # torch.jit.trace only accepts tensors (or containers of tensors) as
        # inputs and outputs. Fix return_dict=False here instead of passing a
        # Python bool as a traced input, which makes tracing always fail.
        return vae.decode(z, return_dict=False)

    try:
        traced = torch.jit.trace(_decode, (latents,))
    except Exception as e:
        print(f"JIT tracing failed: {e}")
        return vae.decode  # Fall back to untraced decoder.

    try:
        torch.jit.save(traced, traced_vae_decode_path)
    except Exception as e:
        # Caching to disk is best-effort; the in-memory traced decoder is
        # still usable, so do not throw it away.
        print(f"Saving traced VAE decoder failed: {e}")

    def traced_decode(z, return_dict=False):
        # Mirror vae.decode(..., return_dict=False): returns a tuple of tensors.
        # return_dict is accepted only for call compatibility and is ignored.
        return traced(z)

    return traced_decode
def decode_latents_to_image(latents, height: int, width: int, vae):
    """Decode one packed FLUX latent into a PIL image.

    Args:
        latents: A single sample's packed latent, as produced by the pipeline
            with ``output_type="latent"``. It has no batch dimension;
            ``unsqueeze(0)`` adds one.
        height: Target image height in pixels. A falsy value (None / 0)
            defaults to 1024.
        width: Target image width in pixels. A falsy value defaults to 1024.
        vae: The pipeline's VAE. Its config provides ``block_out_channels``,
            ``scaling_factor`` and ``shift_factor``.

    Returns:
        The decoded ``PIL.Image.Image``.
    """
    height = height or 1024
    width = width or 1024

    # NOTE(review): ``2 ** len(block_out_channels)`` matches the packing
    # convention of older diffusers FluxPipeline releases; newer releases use
    # ``2 ** (len - 1)`` together with a rewritten ``_unpack_latents``.
    # Confirm this against the installed diffusers version.
    block_out_channels = vae.config.block_out_channels
    vae_scale_factor = 2 ** len(block_out_channels) if block_out_channels else 1
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    with torch.no_grad():
        latents = FluxPipeline._unpack_latents(latents.unsqueeze(0), height, width, vae_scale_factor)
        # Apply the VAE config's scaling/shift factors before decoding.
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
        decoded_image = image_processor.postprocess(image, output_type="pil")[0]

    return decoded_image


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for *request* with the loaded pipeline.

    The pipeline runs 4 denoising steps with guidance disabled and returns
    latents. These are then decoded to a PIL image by
    ``decode_latents_to_image``. The request's seed drives a CUDA generator.
    """
    empty_cache()
    seeded_generator = Generator("cuda").manual_seed(request.seed)
    output = pipeline(
        request.prompt,
        generator=seeded_generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="latent",
    )
    packed_latent = output.images[0]
    return decode_latents_to_image(packed_latent, request.height, request.width, pipeline.vae)