File size: 3,278 Bytes
c37a2ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only
from diffusers.image_processor import VaeImageProcessor
# NOTE(review): `Pipeline` is only used as the annotation for load_pipeline/infer
# below; binding it to None makes those annotations carry no type information —
# presumably a placeholder for the real pipeline class. Confirm intent.
Pipeline = None
import os
# Model repository id passed to every from_pretrained() call in this file.
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
# Where trace_and_save_vae_decoder() saves the traced decoder. This is a relative
# path, so it resolves against the process's current working directory.
traced_vae_decode_path = "traced_vae_decode.pt"  
def empty_cache():
    """Free unreferenced Python objects, release cached CUDA memory, reset peak stats.

    Runs the garbage collector first so that objects kept alive only by
    reference cycles are collected before the CUDA caching allocator is asked
    to release its unused cached blocks. Prints the elapsed time in seconds.
    """
    start = time.perf_counter()  # monotonic; time.time() can jump with clock changes
    gc.collect()
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() supersedes the deprecated
    # reset_max_memory_allocated(). The old call merely forwarded here and
    # emitted a FutureWarning on every invocation, so calling both was redundant.
    torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.perf_counter() - start}")



def load_pipeline() -> Pipeline:
    """Build the FLUX.1-schnell pipeline with an int8 weight-only VAE and warm it up.

    The VAE is loaded on its own so it can be quantized with torchao's
    ``int8_weight_only`` before being handed to ``DiffusionPipeline``. The
    pipeline is then put into sequential CPU offload mode and run twice on a
    fixed prompt. The warm-up uses the same guidance scale, step count and max
    sequence length as ``infer``, presumably so one-time initialization costs
    are paid here rather than on the first real request.

    Returns:
        The loaded, warmed-up pipeline.
    """
    empty_cache()
    dtype = torch.bfloat16
    # Use the same dtype for the VAE as for the rest of the pipeline.
    vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=dtype)
    quantize_(vae, int8_weight_only())
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        torch_dtype=dtype,
    )
    # Per the diffusers docs, submodules are moved to the GPU only while they execute.
    # NOTE(review): confirm this offload mode cooperates with the torchao-quantized
    # VAE weights on the installed diffusers/accelerate/torchao versions.
    pipeline.enable_sequential_cpu_offload()

    warmup_prompt = "onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness"
    for _ in range(2):
        empty_cache()
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline

def trace_and_save_vae_decoder(vae, latents):
    """Try to TorchScript-trace ``vae.decode`` and persist it to ``traced_vae_decode_path``.

    On success the traced callable is returned. If tracing or saving raises any
    exception, the error is printed and the plain, untraced ``vae.decode`` is
    returned instead (best-effort behavior).

    NOTE(review): the example inputs include the Python bool ``True`` (the
    ``return_dict`` argument). torch.jit.trace documents example inputs as
    tensors or nested tuples/lists/dicts of tensors, so this likely always takes
    the fallback path. Making it trace would mean wrapping
    ``vae.decode(z, return_dict=False)[0]``, which changes the returned
    callable's signature. This function is not called anywhere in the visible code.
    """
    example_inputs = (latents, True)
    try:
        traced = torch.jit.trace(vae.decode, example_inputs)
        torch.jit.save(traced, traced_vae_decode_path)
    except Exception as err:
        print(f"JIT tracing failed: {err}")
        return vae.decode  # fall back to the untraced decoder
    return traced
        
def decode_latents_to_image(latents, height: int, width: int, vae):
    """Decode a single packed FLUX latent sample into a PIL image.

    Args:
        latents: One sample's latents as produced by the pipeline with
            ``output_type="latent"`` (i.e. ``.images[0]``). A leading batch
            dimension is re-added here before unpacking.
        height: Target height in pixels. Any falsy value (None, 0) means 1024.
        width: Target width in pixels. Any falsy value (None, 0) means 1024.
        vae: The pipeline's autoencoder. Its config supplies
            ``block_out_channels``, ``scaling_factor`` and ``shift_factor``.

    Returns:
        The first image produced by ``VaeImageProcessor.postprocess`` with
        ``output_type="pil"``.
    """
    height = height or 1024
    width = width or 1024

    # NOTE(review): 2 ** len(block_out_channels) matches the FluxPipeline
    # convention of older diffusers releases. Newer releases use
    # 2 ** (len - 1) and changed _unpack_latents to match, so confirm this
    # against the installed diffusers version.
    block_out_channels = vae.config.block_out_channels
    vae_scale_factor = 2 ** len(block_out_channels) if block_out_channels else 1
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    with torch.no_grad():
        latents = FluxPipeline._unpack_latents(latents.unsqueeze(0), height, width, vae_scale_factor)
        # Presumably inverts the (x - shift_factor) * scaling_factor
        # normalisation applied to latents.
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        # Plain (untraced) decode. With return_dict=False the result is a
        # tuple whose first element is the image tensor.
        image = vae.decode(latents, return_dict=False)[0]
        decoded_image = image_processor.postprocess(image, output_type="pil")[0]

    return decoded_image


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` and return it as a PIL image.

    The pipeline is asked for latents only (``output_type="latent"``). Those
    latents are then decoded by ``decode_latents_to_image`` using the
    pipeline's own VAE. Sampling is seeded from ``request.seed`` via a CUDA
    generator, and the GPU cache is flushed before each call.
    """
    empty_cache()
    seeded_generator = Generator("cuda").manual_seed(request.seed)
    output = pipeline(
        request.prompt,
        generator=seeded_generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="latent",
    )
    first_latent = output.images[0]
    return decode_latents_to_image(first_latent, request.height, request.width, pipeline.vae)