# W_quanto3 / src/pipeline.py (commit 8f1fd53, author tb-upce, message "pp")
# Flux Optimization Pipeline
import os
import torch
import torch._dynamo
import gc
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from pipelines.models import TextToImageRequest
from optimum.quanto import requantize
import json
import transformers
import torch
import gc
import os
import json
import transformers
# Fall back to eager execution on torch.compile/dynamo errors instead of raising.
torch._dynamo.config.suppress_errors = True
# Enable expandable CUDA allocator segments (mitigates memory fragmentation).
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Allow HF tokenizers to parallelize across threads.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Base model repo and pinned revision for FLUX.1-schnell.
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
# Placeholder used only as the annotation target in load_pipeline/infer below.
Pipeline = None
class CleanAndOptimization:
    """Runtime-optimization helpers: cuDNN backend tuning, int8 weight-only
    quantization of a wrapped model, and CUDA/host memory cleanup."""

    def __init__(self, model, device="cuda"):
        self.model = model    # model to be quantized in place
        self.device = device  # target device label (not read by current methods)
        self.cache = {}       # scratch cache, emptied by optimize_memory()

    @staticmethod
    def enhance_performance():
        """Enable the cuDNN autotuner, trading determinism for speed.

        Returns a short status string for logging.
        """
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        return "Torch backend opt"

    def preprocess(self, data):
        """Return a list with every sequence in *data* reversed."""
        return [d[::-1] for d in data]

    def quantize_model(self):
        """Apply int8 weight-only quantization in place and return the model.

        Bug fix: torchao's quantize_() mutates the model and returns None,
        and int8_weight_only() constructs the quantization config — it is
        not a transform applied to a model. The original code assigned None
        to self.model and then called int8_weight_only(self.model).
        """
        quantize_(self.model, int8_weight_only())
        return self.model

    def optimize_memory(self):
        """Release cached CUDA blocks, run the GC, and clear the local cache."""
        torch.cuda.empty_cache()
        gc.collect()
        self.cache.clear()

    def apply_all(self, data):
        """Run the full sequence: memory cleanup, preprocessing, quantization,
        then backend tuning. Returns the tuning status string.

        NOTE(review): the preprocessed data is discarded — preserved from the
        original contract, which never used it either.
        """
        self.optimize_memory()
        self.preprocess(data)  # result intentionally discarded (see NOTE)
        self.quantize_model()
        return self.enhance_performance()
def t5_mapping_loader(repo_path):
    """Build a bf16 T5 encoder from a local snapshot and re-apply its saved
    quantization map.

    Parameters
    ----------
    repo_path : str
        Path to a snapshot directory containing ``config.json``.

    Returns
    -------
    transformers.T5EncoderModel or None
        The quantized encoder, or None when CUDA is unavailable.
    """
    def _load_json(filepath):
        # Fix: the original opened the file without ever closing the handle.
        with open(filepath, "r") as fh:
            return json.load(fh)

    def _load_config():
        return transformers.T5Config(**_load_json(os.path.join(repo_path, "config.json")))

    def _apply_quantization(model):
        # NOTE(review): the mapping file is read from the current working
        # directory, not from repo_path — preserved as-is, but confirm this
        # matches the deployment layout.
        quant_map = _load_json("mapping_encoder_2.json")
        # NOTE(review): state_dict=None relies on requantize tolerating a
        # missing state dict; if it raises, load_pipeline's fallback path
        # downloads a plain bf16 encoder instead. Confirm against
        # optimum.quanto's requantize API.
        requantize(
            model=model,
            state_dict=None,
            quantization_map=quant_map,
            device=torch.device("cuda"),
        )

    model = None
    if torch.cuda.is_available():
        model = transformers.T5EncoderModel(_load_config()).to(torch.bfloat16)
    if model is not None:
        _apply_quantization(model)
    return model
def load_pipeline() -> Pipeline:
    """Assemble and warm up the FLUX.1-schnell pipeline.

    Loads a custom transformer, a tiny VAE, and a quantized T5 text encoder
    from the local HF hub cache, moves the pipeline to CUDA, applies
    best-effort backend optimizations, and runs three throwaway prompts to
    trigger kernel autotuning/compilation before returning.

    Returns the ready-to-use DiffusionPipeline.
    """
    trans_path = os.path.join(
        HF_HUB_CACHE,
        "models--RichardWilliam--XULF_Transfomer/snapshots/6860c51af40329808f270e159a0d018559a1204f",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        trans_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    origin_vae = AutoencoderTiny.from_pretrained(
        "RichardWilliam/XULF_Vae",
        revision="3ee225c539465c27adadec45c6e8af50a7397b7d",
        torch_dtype=torch.bfloat16,
    )

    # Prefer the locally requantized T5 snapshot; fall back to a plain bf16
    # download when the quantized load fails for any reason.
    try:
        base_encoder_2 = os.path.join(
            HF_HUB_CACHE,
            "models--RichardWilliam--XULF_T5_bf16/snapshots/63a3d9ef7b586655600ac9bd4e4747d038237761",
        )
        text_encoder_2 = t5_mapping_loader(repo_path=base_encoder_2)
    except Exception:  # fix: was a bare except (would also swallow KeyboardInterrupt)
        text_encoder_2 = T5EncoderModel.from_pretrained(
            "RichardWilliam/XULF_T5_bf16",
            revision="63a3d9ef7b586655600ac9bd4e4747d038237761",
            torch_dtype=torch.bfloat16,
        ).to(memory_format=torch.channels_last)

    flux_pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=origin_vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    flux_pipeline.to("cuda")

    # Best-effort extras (CUDA graphs + cuDNN tuning); failures are non-fatal,
    # so only Exception subclasses are swallowed here.
    try:
        torch.cuda.empty_cache()
        gc.collect()
        flux_pipeline.transformer.enable_cuda_graph()
        torch_opt = CleanAndOptimization.enhance_performance()
        print(torch_opt)
    except Exception:  # fix: was a bare except
        pass

    # Warm-up prompts (nonsense words) to exercise the full generation path.
    prompt_test = ["commensality, eurycephalous, cellulipetal, chiefish, Leskeaceae",
                   "skedlock, palatopterygoid, bacteriogenic",
                   "tariric, corrobboree, Sanetch, return non-duplicate"]
    for prompt in prompt_test:
        flux_pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    # Final cleanup so the caller starts with an empty CUDA cache.
    torch.cuda.empty_cache()
    return flux_pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Render one image for *request* using a seeded generator.

    Frees cached CUDA memory and runs the GC first, then invokes the
    pipeline with the fixed schnell settings (4 steps, guidance 0.0,
    max 256 prompt tokens) at the requested resolution.
    """
    torch.cuda.empty_cache()
    gc.collect()
    rng = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return result.images[0]