# GO_SPORTS_ / src/pipeline.py
# Author: manbeast3b
# Last commit: "Update src/pipeline.py" (caf155f, verified)
from diffusers import (
DiffusionPipeline,
AutoencoderKL,
AutoencoderTiny,
FluxPipeline,
FluxTransformer2DModel
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import (
T5EncoderModel,
T5TokenizerFast,
CLIPTokenizer,
CLIPTextModel
)
import torch
import torch._dynamo
import gc
from PIL import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
import math
from typing import Type, Dict, Any, Tuple, Callable, Optional, Union
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchao.quantization import quantize_, swap_linear_with_smooth_fq_linear, float8_weight_only, uintx_weight_only
from utils import _load
import torchvision
import os
# preconfigs
# PyTorch reads the allocator config when the CUDA caching allocator is
# initialised, so it has to be set before the first GPU allocation.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Ask dynamo to suppress compile errors instead of raising them.
torch._dynamo.config.suppress_errors = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# globals
# Module-level placeholder. Its value is None, so using it as a type
# annotation does not name a real type.
Pipeline = None

# Checkpoint served by this module. An earlier revision of this file also
# assigned ckpt_id = "black-forest-labs/FLUX.1-schnell" @
# "741f7c3ce8b383c54771c7003378a50191e9efe9", but those values were
# immediately overwritten by the lines below (dead stores, removed).
# NOTE(review): presumably that is the base model of the fine-tune — confirm.
ckpt_id = "manbeast3b/flux.1-schnell-full1"
ckpt_revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"

# Tiny autoencoder repo and pinned revision used in place of the full VAE.
TinyVAE = "madebyollin/taef1"
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
def filter_state_dict(model, state_dict_path, prefix: Optional[str] = None):
    """Load a checkpoint and keep only the entries that fit ``model``.

    The exact ``prefix`` is removed from each checkpoint key (when the key
    starts with it). An entry is kept only if the resulting key exists in
    ``model.state_dict()`` and the tensor sizes match.

    Args:
        model: Module whose ``state_dict()`` defines the accepted keys/sizes.
        state_dict_path: Path passed to ``torch.load`` (loaded on CPU with
            ``weights_only=True``).
        prefix: Key prefix to remove. When omitted it is chosen as before:
            ``'encoder.'`` if ``model`` is an instance of the module-level
            class ``E``, otherwise ``'decoder.'``.

    Returns:
        dict mapping de-prefixed keys to tensors, suitable for
        ``model.load_state_dict``.

    Raises:
        NameError: if ``prefix`` is omitted and no module-level ``E`` exists.
    """
    if prefix is None:
        encoder_cls = globals().get("E")
        if encoder_cls is None:
            raise NameError(
                "name 'E' is not defined; pass prefix= explicitly to filter_state_dict"
            )
        prefix = "encoder." if isinstance(model, encoder_cls) else "decoder."

    state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
    # Build the model's state dict once instead of once per lookup.
    target = model.state_dict()

    filtered = {}
    for key, tensor in state_dict.items():
        # Do NOT use str.strip(prefix): it removes any of the prefix's
        # *characters* from both ends (e.g. "decoder.conv.weight" -> "nv.weight").
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in target and tensor.size() == target[name].size():
            filtered[name] = tensor
    return filtered
def load_pipeline() -> FluxPipeline:
    """Build the FLUX pipeline on CUDA, compile it, and warm it up.

    Steps:
      1. Load the transformer from the local HF-cache snapshot of ``ckpt_id``
         at ``ckpt_revision`` (sub-folder ``transformer``).
      2. Load the tiny autoencoder ``TinyVAE`` @ ``TinyVAE_REV`` and replace
         its encoder/decoder with the modules returned by ``utils._load``.
      3. Assemble a ``FluxPipeline`` from ``ckpt_id`` with those components,
         move it to CUDA, compile the transformer and quantize the VAE
         weights to float8.
      4. Run two 1024x1024, 4-step warm-up generations so compilation and
         autotuning happen here rather than on the first request.

    Nothing is downloaded. The transformer is read from a local path, and the
    other loads pass ``local_files_only=True``.

    Returns:
        The ready-to-use pipeline.
    """
    # HF hub cache layout is models--{org}--{name}/snapshots/{revision}/.
    # Derive it from the module-level checkpoint globals instead of
    # hand-copying the repo id and revision into the path.
    repo_folder = "models--" + ckpt_id.replace("/", "--")
    path = os.path.join(HF_HUB_CACHE, repo_folder, "snapshots", ckpt_revision, "transformer")
    transformer = FluxTransformer2DModel.from_pretrained(
        path, torch_dtype=torch.bfloat16, use_safetensors=False
    )

    vae = AutoencoderTiny.from_pretrained(
        TinyVAE,
        revision=TinyVAE_REV,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    # Replace the VAE halves with whatever the project helper utils._load
    # returns for "E" (encoder) / "D" (decoder).
    # NOTE(review): _load is not visible here — confirm what it loads.
    vae.encoder = _load(vae.encoder, "E", dtype=torch.bfloat16)
    vae.decoder = _load(vae.decoder, "D", dtype=torch.bfloat16)

    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        vae=vae,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    # NOTE(review): confirm DiffusionPipeline.to() honours memory_format;
    # it may be ignored at the pipeline level.
    pipeline.to(memory_format=torch.channels_last)
    # NOTE(review): torch._dynamo.config.suppress_errors is set at import
    # time, so compile errors may fall back to eager instead of raising —
    # confirm this is intended together with fullgraph=True.
    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
    # Weight-only float8 quantization of the VAE (torchao, in place).
    quantize_(pipeline.vae, float8_weight_only())

    # Two warm-up runs so compile/autotune cost is paid during loading.
    warmup_prompt = "controllable varied focus thai warriors entertainment blue golden pink soft tough padthai"
    for _ in range(2):
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline
# Flag guarding the one-shot memory cleanup in infer(). It starts truthy and
# nothing in this module clears it.
sample = 1


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: FluxPipeline, generator: Generator) -> Image.Image:
    """Generate one image for ``request`` using 4-step, guidance-free settings.

    Args:
        request: Supplies ``prompt``, ``height`` and ``width``.
        pipeline: The diffusers pipeline to call (presumably the one built by
            ``load_pipeline``).
        generator: Torch RNG passed straight through to the pipeline.

    Returns:
        The first generated image as a ``PIL.Image.Image``.
    """
    global sample
    # The cleanup only runs if something outside this module sets ``sample``
    # to a falsy value. It then re-arms the flag so the cleanup happens once.
    if not sample:
        sample = 1
        gc.collect()
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() already covers the deprecated
        # reset_max_memory_allocated() (which just forwards to it).
        torch.cuda.reset_peak_memory_stats()

    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pt",
    ).images[0]
    # mul(2).sub(1) maps x -> 2x - 1, i.e. it inverts an x/2 + 0.5
    # denormalization. NOTE(review): presumably this undoes the pipeline's
    # post-processing because the custom decoder from utils._load already
    # emits values in [0, 1] — confirm before changing.
    return torchvision.transforms.functional.to_pil_image(
        image.to(torch.float32).mul_(2).sub_(1)
    )