File size: 5,365 Bytes
c864951
2cc59cd
 
 
 
 
c864951
 
 
 
 
2cc59cd
 
 
 
 
 
 
 
 
 
c864951
 
2cc59cd
 
 
 
 
 
c864951
 
2cc59cd
 
 
c864951
2cc59cd
 
 
 
 
c864951
 
 
 
 
 
 
 
 
 
 
2cc59cd
 
 
 
 
 
 
 
 
 
 
e3a3cdd
2cc59cd
 
 
 
 
 
c864951
2cc59cd
 
c864951
 
2cc59cd
 
 
 
c864951
 
 
 
 
 
 
55b23d4
2cc59cd
c864951
2cc59cd
 
 
 
 
 
 
 
 
 
c864951
2cc59cd
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee
import os
import torch
import torch._dynamo
import gc

torch._dynamo.config.suppress_errors = True
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"


from huggingface_hub.constants import HF_HUB_CACHE
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline

from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from pipelines.models import TextToImageRequest
from optimum.quanto import requantize
import json
import transformers
from functools import wraps



# Runtime configuration. The same three settings are also applied right after
# the first imports at the top of the file; they are re-applied here unchanged.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
torch._dynamo.config.suppress_errors = True

# Hub repo id and pinned revision passed to DiffusionPipeline.from_pretrained
# in load_pipeline().
MAIN_ID = "RichardWilliam/FullyFLUXSCH"
REV = "c5f4f70c6cb9228a9c258799aadc660dde417af6"
# Placeholder that is only used as an annotation on load_pipeline() and infer().
Pipeline = None
# Not read anywhere in the visible part of this file.
apply_quanto = 1

def to_hell():
    """Free unused memory and reset the CUDA peak-memory statistics.

    Always runs a Python garbage-collection pass. When CUDA is available it
    additionally releases the unoccupied cached memory held by PyTorch's
    caching allocator and resets the peak-memory counters. On a host without
    CUDA only the garbage collection runs.

    Returns:
        None.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # Nothing to release or reset on a CPU-only host.
        return
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() resets every peak counter, including the
    # "max allocated" one. The deprecated reset_max_memory_allocated() merely
    # forwards to it, so a single call is sufficient.
    torch.cuda.reset_peak_memory_stats()

def error_handler(func):
    """Decorate ``func`` so that exceptions are reported instead of raised.

    The returned wrapper forwards all positional and keyword arguments to
    ``func`` and hands back its result. If ``func`` raises an ``Exception``,
    a message naming the function and the error is printed and ``None`` is
    returned instead.
    """
    @wraps(func)
    def guarded(*call_args, **call_kwargs):
        outcome = None  # value handed back when the call fails
        try:
            outcome = func(*call_args, **call_kwargs)
        except Exception as e:
            print(f"Error in {func.__name__}: {str(e)}")
        return outcome

    return guarded

@error_handler
def load_quanto_text_encoder_2(text_repo_path):
    """Build a T5 encoder on the meta device and requantize it onto CUDA.

    Reads ``quantization_map.json`` from the current working directory and
    ``config.json`` from ``text_repo_path``, instantiates a bfloat16
    ``transformers.T5EncoderModel`` under ``torch.device("meta")`` and passes
    it to ``optimum.quanto.requantize``.

    Args:
        text_repo_path: Directory that contains the encoder's ``config.json``.

    Returns:
        The requantized encoder. Because of the ``error_handler`` decorator,
        any exception raised along the way is printed and ``None`` is
        returned instead.
    """
    with open("quantization_map.json", "r") as qmap_file:
        qmap = json.load(qmap_file)

    with open(os.path.join(text_repo_path, "config.json"), "r") as cfg_file:
        cfg_kwargs = json.load(cfg_file)
    encoder_config = transformers.T5Config(**cfg_kwargs)

    with torch.device("meta"):
        encoder = transformers.T5EncoderModel(encoder_config).to(torch.bfloat16)

    # NOTE(review): no weights are loaded here; None is passed as the state
    # dict. Confirm that requantize accepts this, or supply the real weights.
    requantize(encoder, None, qmap, device=torch.device("cuda"))
    return encoder


def load_pipeline() -> Pipeline:
    """Load the diffusion pipeline onto CUDA and warm it up.

    Steps:
      1. Load the ``FluxTransformer2DModel`` from a fixed snapshot directory
         inside the Hugging Face hub cache (bfloat16, channels-last).
      2. Build the ``DiffusionPipeline`` for ``MAIN_ID`` at revision ``REV``
         with that transformer and move it to CUDA.
      3. Try to swap in the quanto text encoder; if the loader returns
         ``None`` the pipeline keeps its own ``text_encoder_2``.
      4. Run three 1024x1024, 4-step warm-up generations.

    Returns:
        The ready-to-use pipeline object.
    """
    main_path = os.path.join(
        HF_HUB_CACHE,
        "models--RichardWilliam--XULF_Transfomer/snapshots/6860c51af40329808f270e159a0d018559a1204f",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        main_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    pipeline = DiffusionPipeline.from_pretrained(
        MAIN_ID,
        revision=REV,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # NOTE(review): text_repo_path=None makes the loader raise (os.path.join
    # rejects None), and its error_handler decorator turns that into a None
    # return, so the replacement below appears never to happen. Confirm
    # whether a real encoder directory should be passed here.
    text_encoder_v2 = load_quanto_text_encoder_2(text_repo_path=None)

    # Identity check: None is a singleton and must not go through __eq__.
    if text_encoder_v2 is None:
        print("Something wrong")
    else:
        pipeline.text_encoder_2 = text_encoder_v2

    # Warm-up generations with the same step count, guidance scale and
    # sequence length that infer() uses; the outputs are discarded.
    for _ in range(3):
        pipeline(
            prompt="I am the worst",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` using ``pipeline``.

    Calls ``to_hell()`` first, seeds a ``torch.Generator`` on the pipeline's
    device with ``request.seed``, then runs the pipeline with 4 inference
    steps, guidance scale 0.0 and a maximum sequence length of 256 at the
    requested height and width.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: The callable pipeline returned by ``load_pipeline()``.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    to_hell()

    seeded_rng = Generator(pipeline.device)
    seeded_rng.manual_seed(request.seed)

    sampling_options = dict(
        generator=seeded_rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    output = pipeline(request.prompt, **sampling_options)
    return output.images[0]