File size: 3,090 Bytes
b4bdd09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
#8
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
import os
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
# Process-wide runtime settings applied at import time.
os.environ.update(
    {
        # Ask the CUDA caching allocator to use expandable segments.
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
# Make TorchDynamo fall back to eager execution instead of raising on errors.
torch._dynamo.config.suppress_errors = True

# Placeholder used as the pipeline type annotation in the functions below.
Pipeline = None

# Base FLUX.1-schnell repository and its pinned revision.
ids = "black-forest-labs/FLUX.1-schnell"
Revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"

# Alternate pipeline checkpoint and its pinned revision.
ckpt_id = "agentbot/Quant-Flux-2"
ckpt_revision = "2996318a090bd3bd4b079413d80bbcfd95e6febc"
# Keyword arguments for the throw-away generations run on a freshly loaded
# pipeline before it is returned (presumably to absorb first-call setup cost).
_WARMUP_KWARGS = {
    "prompt": "satiety, unwitherable, Pygmy, ramlike, Curtis, fingerstone, rewhisper",
    "width": 1024,
    "height": 1024,
    "guidance_scale": 0.0,
    "num_inference_steps": 4,
    "max_sequence_length": 256,
}


def _load_checkpoint_pipeline(transformer_path):
    """Build the ``ckpt_id`` FluxPipeline on CUDA and run one warm-up call.

    Any exception from loading, moving to CUDA, or the warm-up propagates.
    """
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False
    )
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    with torch.inference_mode():
        pipeline(**_WARMUP_KWARGS)
    return pipeline


def _load_base_pipeline(transformer_path):
    """Build the base ``ids`` pipeline on CUDA and run two warm-up calls."""
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "agentbot/t5-v1_1-xxl-encoder-bf16_",
        revision="208e3686b3027985dbd8c9098c273e0155c77ef4",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path, torch_dtype=torch.bfloat16, use_safetensors=False
    ).to(memory_format=torch.channels_last)
    pipeline = DiffusionPipeline.from_pretrained(
        ids,
        revision=Revision,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    for _ in range(2):
        pipeline(**_WARMUP_KWARGS)
    return pipeline


def load_pipeline() -> Pipeline:
    """Load a FLUX.1-schnell pipeline onto CUDA, warm it up, and return it.

    Candidates are tried in order; only one is ever resident on the GPU:

    1. A ``FluxPipeline`` from the locally cached ``ckpt_id`` checkpoint
       (pinned to ``ckpt_revision``) using the int8 weight-only transformer
       snapshot from the HF hub cache. Any exception while loading or
       warming it up is logged, and the next candidate is used.
    2. The base ``ids`` pipeline (pinned to ``Revision``) with the same
       transformer snapshot and the bf16 T5 text encoder.

    Returns:
        The first candidate that loaded and completed its warm-up.

    Raises:
        Whatever the base-pipeline load or warm-up raises; there is no
        further fallback.
    """
    import logging

    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--agentbot--FLUX.1-schnell-int8wo_/snapshots/"
        "aa66177be06aba5a88dbe7265255bec48833a936",
    )
    # NOTE(review): assumes the int8 transformer snapshot is the intended
    # transformer for the ckpt_id pipeline as well — confirm.
    try:
        return _load_checkpoint_pipeline(transformer_path)
    except Exception:
        # Deliberate best-effort: report why, then fall back to the base pipeline.
        logging.getLogger(__name__).warning(
            "Could not load %s; falling back to %s", ckpt_id, ids, exc_info=True
        )
    # Release anything the failed attempt left allocated before the next load.
    gc.collect()
    torch.cuda.empty_cache()
    return _load_base_pipeline(transformer_path)
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for *request* and return it.

    A ``Generator`` on ``pipeline.device`` is seeded with ``request.seed``.
    The prompt, height, and width come from the request. The sampling
    settings are fixed: guidance_scale=0.0, num_inference_steps=4,
    max_sequence_length=256. The first entry of the output's ``images``
    is returned.
    """
    rng = Generator(pipeline.device)
    rng.manual_seed(request.seed)
    fixed_sampling = {
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
    }
    output = pipeline(
        request.prompt,
        generator=rng,
        height=request.height,
        width=request.width,
        **fixed_sampling,
    )
    return output.images[0]
|