# File size: 2,261 Bytes
# 7f92df6

from diffusers import FluxPipeline, AutoencoderTiny
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import DiffusionPipeline
#from torchao.quantization import quantize_, fpx_weight_only, int8_weight_only
# Placeholder name used as the annotation for `load_pipeline`'s return value
# and `infer`'s `pipeline` parameter. Its value is None, so the annotation
# carries no real type information.
# NOTE(review): presumably meant to alias the pipeline class — confirm before changing.
Pipeline = None

# Hugging Face Hub repo id of the base model; the CLIP text encoder and the
# remaining pipeline components are loaded from it in `load_pipeline`.
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache():
    """Collect Python garbage, release cached CUDA memory and reset peak stats.

    Prints how long the flush took, in seconds. When CUDA is not available the
    CUDA-specific steps are skipped instead of raising, so the helper is safe
    to call on CPU-only hosts.
    """
    start = time.perf_counter()  # monotonic clock, unaffected by wall-clock jumps
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # `reset_max_memory_allocated()` is deprecated and merely delegates to
        # `reset_peak_memory_stats()`, so one call resets every peak counter.
        torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.perf_counter() - start}")

def load_pipeline() -> Pipeline:
    """Build the FLUX.1-schnell text-to-image pipeline and warm it up.

    Components loaded explicitly (all in bfloat16):
      * VAE: ``AutoencoderTiny`` from ``RobertML/FLUX.1-schnell-vae_e3m2``.
      * ``text_encoder``: ``CLIPTextModel`` from the ``ckpt_id`` repo.
      * ``text_encoder_2``: ``T5EncoderModel`` from
        ``city96/t5-v1_1-xxl-encoder-bf16``.
    These are passed to ``DiffusionPipeline.from_pretrained(ckpt_id, ...)``,
    which loads the remaining components from ``ckpt_id``. Sequential CPU
    offload is then enabled and the pipeline is run twice on a fixed prompt
    at 1024x1024 so first-call overhead is paid here rather than on the
    first real request.

    Returns:
        The loaded diffusers pipeline.
    """
    empty_cache()

    dtype = torch.bfloat16

    vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", torch_dtype=dtype)

    ############ Text Encoder ############
    text_encoder = CLIPTextModel.from_pretrained(
        ckpt_id, subfolder="text_encoder", torch_dtype=dtype
    )
    ############ Text Encoder 2 ############
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
    )

    empty_cache()

    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype,
    )
    # Per diffusers docs, sequential offload keeps submodules on the CPU and
    # moves each to the GPU only while it executes (lower VRAM, slower).
    pipeline.enable_sequential_cpu_offload()

    # Warm-up runs use the same sampler settings as `infer` (4 steps, no
    # guidance, 256-token T5 limit).
    warmup_prompt = (
        "onomancy, aftergo, spirantic, Platyhelmia, modificator, "
        "drupaceous, jobbernowl, hereness"
    )
    for _ in range(2):
        gc.collect()
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for *request*, falling back to a placeholder on error.

    Runs the pipeline for 4 steps with ``guidance_scale=0.0`` and a 256-token
    sequence limit. It uses a CUDA ``Generator`` seeded with ``request.seed``
    and produces an image of ``request.height`` x ``request.width``.

    If generation raises an ``Exception``, the traceback is printed and the
    image at ``./RobertML.png`` (relative to the current working directory)
    is returned instead. ``KeyboardInterrupt`` and ``SystemExit`` propagate.

    Returns:
        A PIL image.
    """
    gc.collect()
    try:
        generator = Generator("cuda").manual_seed(request.seed)
        return pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
            output_type="pil",
        ).images[0]
    except Exception:
        # Best-effort fallback: keep serving a placeholder image, but surface
        # the cause instead of discarding it silently.
        import traceback

        traceback.print_exc()
        return img.open("./RobertML.png")