# ---------------------------------------------------------------------------
# Legacy FLUX.1-schnell pipeline, kept commented out for reference. It records
# the torchao quantization experiments (int8 dynamic, smoothquant, int4,
# float8) that were tried before the current implementation below.
# ---------------------------------------------------------------------------
# from diffusers import AutoencoderKL, FluxTransformer2DModel, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
# from diffusers.image_processor import VaeImageProcessor
# import torch
# import torch._dynamo
# import gc
# import os
# from PIL.Image import Image
# from pipelines.models import TextToImageRequest
# from torch import Generator
# from diffusers import DiffusionPipeline
# from torchao.quantization import quant_api
# # from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
# from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
# from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference
# from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
# # from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
# from torchao.quantization.quant_api import PerTensor
# from torchao.quantization import quantize_, float8_weight_only
#
# HOME = os.environ["HOME"]
# Pipeline = None
# MODEL_ID = "black-forest-labs/FLUX.1-schnell"
#
#
# def clear():
#     gc.collect()
#     torch.cuda.empty_cache()
#     torch.cuda.reset_max_memory_allocated()
#     torch.cuda.reset_peak_memory_stats()
#
#
# def conv_filter_fn(mod, *args):
#     return (isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1)
#             and 128 in [mod.in_channels, mod.out_channels])
#
#
# def dynamic_quant_filter_fn(mod, *args):
#     return (isinstance(mod, torch.nn.Linear) and mod.in_features > 16
#             and (mod.in_features, mod.out_features) not in [
#                 (1280, 640), (1920, 1280), (1920, 640), (2048, 1280), (2048, 2560),
#                 (2560, 1280), (256, 128), (2816, 1280), (320, 640), (512, 1536),
#                 (512, 256), (512, 512), (640, 1280), (640, 1920), (640, 320),
#                 (640, 5120), (640, 640), (960, 320), (960, 640),
#             ])
#
#
# @torch.inference_mode()
# def load_pipeline() -> Pipeline:
#     clear()
#     dtype, device = torch.bfloat16, "cuda"
#     pipeline = DiffusionPipeline.from_pretrained(
#         MODEL_ID,
#         torch_dtype=dtype,
#     )
#     # quantize_(pipeline.vae, int8_dynamic_activation_int8_weight())
#     # quant_api.change_linear_weights_to_int8_dqtensors(pipeline.vae, dynamic_quant_filter_fn)  # 2.4 pytorch dep
#     # quantize_(pipeline.vae, int8_dynamic_activation_int8_weight())
#     # smooth_fq_linear_to_inference(pipeline.transformer)
#     # quantizer = Int8DynActInt4WeightQuantizer(groupsize=1024)
#     # pipeline.vae = quantizer.quantize(pipeline.vae)
#     # quantize_(pipeline.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()))
#     # quantize_(pipeline.vae, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
#     quantize_(pipeline.vae, float8_weight_only())
#     # quant_api.swap_conv2d_1x1_to_linear(pipeline.vae, conv_filter_fn)
#     # quant_api.apply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)
#     # quant_api.apply_weight_only_int8_quant(pipeline.vae, dynamic_quant_filter_fn)
#     # clear()
#     # for param in pipeline.vae.parameters():
#     #     param.detach()
#     # for param in pipeline.transformer.parameters():
#     #     param.detach()
#     # for param in pipeline.text_encoder.parameters():
#     #     param.detach()
#     # for param in pipeline.text_encoder_2.parameters():
#     #     param.detach()
#     # pipeline.enable_sequential_cpu_offload()
#     # swap_linear_with_smooth_fq_linear(pipeline.transformer)
#     # pipeline.transformer.train()
#     for _ in range(2):
#         pipeline(prompt="unpervaded, unencumber, froggish, groundneedle, transnatural, fatherhood, outjump, cinerator",
#                  width=1024,
#                  height=1024, guidance_scale=0.1, num_inference_steps=4, max_sequence_length=256)
#     # smooth_fq_linear_to_inference(pipeline.transformer)
#     pipeline.enable_sequential_cpu_offload()
#     clear()
#     return pipeline
#
#
# @torch.inference_mode()
# def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
#     clear()
#     dir(pipeline)
#     generator = Generator("cuda").manual_seed(request.seed)
#     image = pipeline(request.prompt, generator=generator, guidance_scale=0.0, num_inference_steps=4,
#                      max_sequence_length=256, height=request.height, width=request.width,
#                      output_type="pil").images[0]
#     return image

from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only
from torchao.quantization import quant_api
from deps import f
# from torchao.quantization import autoquant

Pipeline = None
ckpt_id = "black-forest-labs/FLUX.1-schnell"


def empty_cache():
    # Collect Python garbage, release cached CUDA blocks, and reset the
    # allocator's peak-memory statistics, timing how long the flush takes.
    start = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    print(f"Flush took: {time.time() - start}")


def conv_filter_fn(mod, *args):
    # Match only 1x1 convolutions with 128 input or output channels.
    return (isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1)
            and 128 in [mod.in_channels, mod.out_channels])


def load_pipeline() -> Pipeline:
    empty_cache()
    dtype, device = torch.bfloat16, "cuda"
    empty_cache()
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        torch_dtype=dtype,
    )
    # quant_api.swap_conv2d_1x1_to_linear(pipeline.vae, f)
    # torch.compile returns the compiled module; assign it back, otherwise the
    # compiled VAE is discarded and the eager one keeps running.
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
    pipeline.enable_sequential_cpu_offload()
    # Two warm-up passes so compilation and autotuning happen before the first
    # real request is served.
    for _ in range(2):
        empty_cache()
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    empty_cache()
    try:
        generator = Generator("cuda").manual_seed(request.seed)
        image = pipeline(
            request.prompt,
            generator=generator,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            height=request.height,
            width=request.width,
            output_type="pil",
        ).images[0]
    except Exception as e:
        # On any failure, log the error and fall back to a placeholder image
        # instead of crashing the caller.
        print(e)
        print("inference failed; returning fallback image")
        image = img.open("./RobertML.png")
    return image
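
# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch only, guarded so it never runs on import.
# It assumes TextToImageRequest accepts prompt/height/width/seed keyword
# arguments (the fields infer() reads); that constructor signature is not
# verified against pipelines.models here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a watercolor fox in a snowy forest",  # hypothetical test prompt
        height=1024,
        width=1024,
        seed=0,
    )
    infer(request, pipe).save("out.png")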