"""Flux text-to-image serving module.

Exposes ``load_pipeline()`` (build and warm up a bfloat16 ``FluxPipeline`` on
CUDA) and ``infer()`` (render one image for a ``TextToImageRequest``).
"""
import os

# Set allocator / tokenizer environment variables before torch (and libraries
# that import torch, such as diffusers) are imported, so they are already in
# place when CUDA and the tokenizers are first initialized.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"

import gc  # noqa: E402

import torch  # noqa: E402
import torch._dynamo  # noqa: E402
from diffusers import FluxPipeline, FluxTransformer2DModel  # noqa: E402
from huggingface_hub.constants import HF_HUB_CACHE  # noqa: E402
from PIL.Image import Image  # noqa: E402
from torch import Generator  # noqa: E402

from pipelines.models import TextToImageRequest  # noqa: E402

# Let dynamo fall back to eager execution instead of raising on compile errors.
torch._dynamo.config.suppress_errors = True

# Type alias used in the public signatures below.
Pipeline = FluxPipeline

_REPO_ID = "fringuant/StreamCascade"
_REVISION = "765016449ab8494685f030a7db03c67600cf4c55"
# Transformer weights are loaded directly from the local HF cache snapshot.
_TRANSFORMER_PATH = os.path.join(
    HF_HUB_CACHE,
    f"models--fringuant--StreamCascade/snapshots/{_REVISION}/transformer",
)
_WARMUP_RUNS = 3

base_prompt = "insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus"


def load_pipeline() -> Pipeline:
    """Load the bfloat16 Flux pipeline onto CUDA and warm it up.

    The transformer is loaded separately from the local HF cache snapshot and
    injected into ``FluxPipeline.from_pretrained`` (``local_files_only=True``,
    so nothing is downloaded). The VAE is wrapped with ``torch.compile`` and
    the pipeline is then run ``_WARMUP_RUNS`` times on ``base_prompt`` at
    1024x1024 before being returned.

    Returns:
        The pipeline, moved to ``"cuda"``.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() also resets the max-allocated counter, so the
    # deprecated torch.cuda.reset_max_memory_allocated() is not needed.
    torch.cuda.reset_peak_memory_stats()

    # NOTE(review): use_safetensors=False loads pickle-based weights via
    # torch.load; acceptable only because the snapshot is a trusted local file.
    transformer = FluxTransformer2DModel.from_pretrained(
        _TRANSFORMER_PATH,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    pipeline = FluxPipeline.from_pretrained(
        _REPO_ID,
        revision=_REVISION,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # NOTE(review): torch.compile on an nn.Module compiles its forward(); if the
    # pipeline decodes via vae.decode(...) (presumably), that call is forwarded
    # to the original, uncompiled module. Compiling pipeline.vae.decode may be
    # what was intended — confirm before changing (it adds load-time cost).
    pipeline.vae = torch.compile(
        pipeline.vae, mode="max-autotune", fullgraph=True, dynamic=True
    )

    # Warm-up runs, presumably so compilation/autotuning happens at load time
    # rather than on the first real request.
    for _ in range(_WARMUP_RUNS):
        pipeline(
            prompt=base_prompt,
            width=1024,
            height=1024,
            guidance_scale=5.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request``.

    Uses ``request.prompt`` when the attribute exists (otherwise
    ``base_prompt``). Seeds a generator on the pipeline's device with
    ``request.seed``. Renders at ``request.width`` x ``request.height`` with
    4 inference steps, guidance scale 6.5 and a 256-token max sequence length.

    Returns:
        The first image of the pipeline output.
    """
    prompt = getattr(request, "prompt", base_prompt)
    # NOTE(review): manual_seed requires an int; confirm request.seed is never None.
    generator = Generator(pipeline.device).manual_seed(request.seed)
    return pipeline(
        prompt,
        generator=generator,
        guidance_scale=6.5,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]