File size: 2,965 Bytes
76bde5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bef444
 
 
 
76bde5a
 
 
 
 
 
 
e80ddda
76bde5a
e80ddda
76bde5a
e80ddda
 
 
a52aa92
e80ddda
76bde5a
e80ddda
 
 
 
 
a52aa92
e80ddda
 
 
 
 
 
5c7fb1a
e80ddda
 
 
76bde5a
 
 
 
 
47c0059
76bde5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
import os
# Configure the CUDA caching allocator to use expandable segments. This is an
# environment variable, so it presumably has to be set before the allocator
# is first used — confirm if CUDA work is ever moved above this line.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Ask torch._dynamo not to raise on compilation errors (the transformer is
# wrapped with torch.compile in load_pipeline()).
torch._dynamo.config.suppress_errors = True

# Placeholder bound to None; it is only used as the annotation for the
# pipeline object returned by load_pipeline() and accepted by infer().
Pipeline = None

# Pinned base checkpoint. FluxPipeline.from_pretrained() takes every component
# that load_pipeline() does not override (VAE and transformer are overridden)
# from this repo/revision.
# Previously used alternative:
#   manbeast3b/flux.1-schnell-full1 @ cb1b599b0d712b9aab2c4df3ad27b050a27ec146
ckpt_id, ckpt_revision = (
    "black-forest-labs/FLUX.1-schnell",
    "741f7c3ce8b383c54771c7003378a50191e9efe9",
)


def empty_cache():
    """Free unreferenced memory and reset CUDA peak-memory statistics.

    Python garbage collection runs first, so that objects kept alive only by
    reference cycles are released before the CUDA caching allocator is asked
    to drop its unused cached blocks. The CUDA-specific calls are skipped
    when no CUDA device is available, because ``reset_peak_memory_stats``
    needs a CUDA device to operate on.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated and only forwards to
    # reset_peak_memory_stats(), so one call resets all peak counters
    # without triggering a FutureWarning.
    torch.cuda.reset_peak_memory_stats()

def load_pipeline() -> Pipeline:
    """Build, compile and warm up the FLUX.1-schnell pipeline on CUDA.

    Components:
      * VAE: ``AutoencoderTiny`` loaded from a pinned taef1 snapshot inside
        the local Hugging Face cache (``HF_HUB_CACHE``).
      * Transformer: ``FluxTransformer2DModel`` loaded from a pinned
        int8-weight-only snapshot in the local cache. ``use_safetensors=False``
        is kept from the original; presumably that checkpoint is not stored
        as safetensors — confirm before changing it.
      * Every other component comes from ``ckpt_id`` at ``ckpt_revision``.

    After loading, the transformer is wrapped in
    ``torch.compile(mode="max-autotune")``, and ``pipeline.text_encoder`` and
    ``pipeline.vae`` are int8 weight-only quantized with torchao. Three
    1024x1024, 4-step generations then run so that compilation and autotuning
    happen here and not on the first real request.

    Returns:
        The ready-to-use ``FluxPipeline``, moved to ``"cuda"``.
    """
    empty_cache()

    dtype, device = torch.bfloat16, "cuda"

    # NOTE: an alternative T5 encoder ("city96/t5-v1_1-xxl-encoder-bf16",
    # revision 1b9c856aadb864af93c1dcdc226c2774fa67bc86) used to be loaded
    # here. It was never passed to the pipeline, so it only cost load time
    # and host memory. To use it, load it and pass ``text_encoder_2=...`` to
    # FluxPipeline.from_pretrained below.

    tiny_vae_path = os.path.join(
        HF_HUB_CACHE,
        "models--madebyollin--taef1/snapshots/5463ee684fd9131a724bea777a2f50d89b0b6b24",
    )
    vae = AutoencoderTiny.from_pretrained(tiny_vae_path, torch_dtype=dtype)

    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=dtype, use_safetensors=False
    ).to(memory_format=torch.channels_last)

    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        vae=vae,
        revision=ckpt_revision,
        transformer=transformer,
        torch_dtype=dtype,
    ).to(device)

    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
    quantize_(pipeline.text_encoder, int8_weight_only())
    quantize_(pipeline.vae, int8_weight_only())

    # Warm-up runs use the same fixed settings as infer() so the compiled
    # graphs match real requests (at least for 1024x1024 outputs).
    warmup_prompt = (
        "onomancy, aftergo, spirantic, Platyhelmia, modificator, "
        "drupaceous, jobbernowl, hereness"
    )
    for _ in range(3):
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    empty_cache()
    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """Generate a single PIL image for ``request`` using ``pipeline``.

    The prompt and output size come from ``request``. Sampling uses fixed
    settings (no guidance, 4 steps, prompts capped at 256 tokens) and the
    caller-supplied ``generator``. The call runs under ``torch.no_grad()``.
    """
    result = pipeline(
        request.prompt,
        height=request.height,
        width=request.width,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=256,
        generator=generator,
        output_type="pil",
    )
    first_image = result.images[0]
    return first_image