File size: 6,763 Bytes
8360f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19f6187
8360f1c
 
 
19f6187
b14280a
19f6187
8360f1c
19f6187
 
 
8360f1c
 
 
ae29cdb
 
09b7ba9
ae29cdb
8360f1c
19f6187
8360f1c
 
 
 
19f6187
 
 
 
8360f1c
19f6187
ae29cdb
 
b14280a
8360f1c
 
b14280a
8360f1c
 
b14280a
8360f1c
b14280a
19f6187
ae29cdb
30cfd95
 
5f944f0
8360f1c
 
 
 
b14280a
 
8360f1c
b14280a
 
8360f1c
 
 
 
2b16b85
 
 
8360f1c
 
2b16b85
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# from diffusers import AutoencoderKL, FluxTransformer2DModel, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
# from diffusers.image_processor import VaeImageProcessor
# import torch
# import torch._dynamo
# import gc
# import os
# from PIL.Image import Image
# from pipelines.models import TextToImageRequest
# from torch import Generator
# from diffusers import DiffusionPipeline
# from torchao.quantization import quant_api
# # from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
# from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
# from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference
# from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
# # from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
# from torchao.quantization.quant_api import PerTensor
# from torchao.quantization import quantize_, float8_weight_only

# HOME = os.environ["HOME"]
# Pipeline = None
# MODEL_ID = "black-forest-labs/FLUX.1-schnell"
# def clear():
#     gc.collect()
#     torch.cuda.empty_cache()
#     torch.cuda.reset_max_memory_allocated()
#     torch.cuda.reset_peak_memory_stats()

# def conv_filter_fn(mod, *args):
#     return (isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1) and 128 in [mod.in_channels, mod.out_channels])
# def dynamic_quant_filter_fn(mod, *args):
#     return (isinstance(mod, torch.nn.Linear) and mod.in_features > 16 and (mod.in_features, mod.out_features)
#                 not in [(1280, 640), (1920, 1280), (1920, 640), (2048, 1280), (2048, 2560), (2560, 1280), (256, 128), (2816, 1280), (320, 640), (512, 1536), (512, 256), (512, 512), (640, 1280), (640, 1920), (640, 320), (640, 5120), (640, 640), (960, 320), (960, 640)])

# @torch.inference_mode()
# def load_pipeline() -> Pipeline:    
#     clear()
#     dtype, device = torch.bfloat16, "cuda"
#     pipeline = DiffusionPipeline.from_pretrained(
#         MODEL_ID,
#         torch_dtype=dtype,
#         )
    
#     # quantize_(pipeline.vae, int8_dynamic_activation_int8_weight())
#     # quant_api.change_linear_weights_to_int8_dqtensors(pipeline.vae, dynamic_quant_filter_fn) #2.4 pytorch dep
#     # quantize_(pipeline.vae, int8_dynamic_activation_int8_weight())

#     # smooth_fq_linear_to_inference(pipeline.transformer)
#     # quantizer = Int8DynActInt4WeightQuantizer(groupsize=1024)
#     # pipeline.vae = quantizer.quantize(pipeline.vae)

#     # quantize_(pipeline.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()))
#     # quantize_(pipeline.vae, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
#     quantize_(pipeline.vae, float8_weight_only())
    
#     # quant_api.swap_conv2d_1x1_to_linear(pipeline.vae, conv_filter_fn)
    
#     # quant_api.apply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)
#     # quant_api.apply_weight_only_int8_quant(pipeline.vae, dynamic_quant_filter_fn)
#     # clear()
#     # for param in pipeline.vae.parameters():
#     #     param.detach()
#     # for param in pipeline.transformer.parameters():
#     #     param.detach()
#     # for param in pipeline.text_encoder.parameters():
#     #     param.detach()
#     # for param in pipeline.text_encoder_2.parameters():
#     #     param.detach()
#     # pipeline.enable_sequential_cpu_offload()

#     # swap_linear_with_smooth_fq_linear(pipeline.transformer)
#     # pipeline.transformer.train()
#     for _ in range(2):
#         pipeline(prompt="unpervaded, unencumber, froggish, groundneedle, transnatural, fatherhood, outjump, cinerator", width=1024, height=1024, guidance_scale=0.1, num_inference_steps=4, max_sequence_length=256)
#     # smooth_fq_linear_to_inference(pipeline.transformer)
#     pipeline.enable_sequential_cpu_offload()
#     clear()
#     return pipeline

# @torch.inference_mode()
# def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
#     clear()
#     dir(pipeline)
#     generator = Generator("cuda").manual_seed(request.seed)
#     image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
#     return image


from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only

from torchao.quantization import quant_api
from deps import f

#from torchao.quantization import autoquant
# Placeholder name used only in the type annotations of load_pipeline (return)
# and infer (`pipeline` parameter). Because it is bound to None, those
# annotations evaluate to None rather than to a concrete pipeline class.
Pipeline = None

# Model identifier handed to DiffusionPipeline.from_pretrained in load_pipeline.
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def empty_cache():
    """Collect garbage, release cached CUDA memory, and print how long it took.

    Runs a full ``gc.collect()`` pass, then, when a CUDA device is available,
    empties the CUDA caching allocator and resets the peak-memory statistics.
    The elapsed time is printed as ``Flush took: <seconds>``.

    Returns:
        None.
    """
    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump if the system clock is adjusted mid-call.
    start = time.perf_counter()
    gc.collect()
    # Guard the CUDA calls: resetting peak stats initializes CUDA and raises
    # on hosts without a usable device.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() already resets every peak counter; the
        # deprecated reset_max_memory_allocated() merely forwarded to it and
        # emitted a warning, so it is no longer called separately.
        torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.perf_counter() - start}")

def conv_filter_fn(mod, *args):
    """Return True for a 1x1 ``torch.nn.Conv2d`` with 128 input or output channels.

    Args:
        mod: The module to inspect.
        *args: Extra positional arguments; accepted and ignored.

    Returns:
        bool: True only when ``mod`` is a ``Conv2d`` whose kernel size is
        ``(1, 1)`` and whose ``in_channels`` or ``out_channels`` equals 128.
    """
    if not isinstance(mod, torch.nn.Conv2d):
        return False
    if mod.kernel_size != (1, 1):
        return False
    return mod.in_channels == 128 or mod.out_channels == 128
def load_pipeline() -> Pipeline:
    """Load the pipeline named by ``ckpt_id`` and warm it up.

    Steps, in order:
      1. Flush memory twice via ``empty_cache()``.
      2. Load the checkpoint with ``DiffusionPipeline.from_pretrained`` using
         bfloat16 weights.
      3. Call ``torch.compile`` on the VAE (return value discarded, see note).
      4. Enable sequential CPU offload on the pipeline.
      5. Run two warm-up generations (1024x1024, 4 steps), flushing memory
         before each one.

    Returns:
        The loaded, warmed-up pipeline object.
    """
    weights_dtype = torch.bfloat16
    target_device = "cuda"  # currently unused; kept from the original setup

    # Two consecutive flushes precede the checkpoint load.
    for _ in range(2):
        empty_cache()

    pipe = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=weights_dtype)

    # A 1x1-conv-to-linear swap on the VAE (quant_api.swap_conv2d_1x1_to_linear)
    # was tried at this point and is currently disabled.

    # NOTE(review): the value returned by torch.compile is discarded here.
    # torch.compile hands back the optimized module instead of altering its
    # argument, so this call likely has no effect on `pipe.vae` — confirm
    # whether compilation was actually intended before changing it.
    torch.compile(pipe.vae, mode="max-autotune")
    pipe.enable_sequential_cpu_offload()

    warmup_kwargs = {
        "prompt": "onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
        "width": 1024,
        "height": 1024,
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
    }
    for _ in range(2):
        empty_cache()
        pipe(**warmup_kwargs)

    return pipe


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request``, falling back to a stock image on error.

    Memory is flushed with ``empty_cache()`` first. The pipeline is then called
    with the request's prompt, height and width, a CUDA generator seeded from
    ``request.seed``, guidance scale 0.0, 4 inference steps, and a maximum
    sequence length of 256; the first returned image is used.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: Callable pipeline whose result exposes ``.images``.

    Returns:
        The first generated image, or the image opened from
        ``./RobertML.png`` if any exception is raised during generation.
    """
    empty_cache()
    try:
        seeded_rng = Generator("cuda").manual_seed(request.seed)
        output = pipeline(
            request.prompt,
            generator=seeded_rng,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
            output_type="pil",
        )
        return output.images[0]
    except Exception as e:
        # Best-effort: report the failure and hand back a placeholder image
        # rather than propagating the error to the caller.
        print(e)
        print("BLAAAAAAAAAAAAAAAAAAAAAAH")
        return img.open("./RobertML.png")