File size: 2,688 Bytes
c01aef2
 
 
 
 
64c779b
 
c01aef2
64c779b
 
 
c01aef2
 
64c779b
c01aef2
 
 
b9ee776
 
 
c01aef2
 
 
b9ee776
 
 
 
 
64c779b
 
b9ee776
e127697
c01aef2
 
 
e127697
c01aef2
b9ee776
 
c01aef2
b9ee776
c01aef2
 
64c779b
 
c01aef2
64c779b
 
b9ee776
64c779b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#2
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
import os
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only

# --- Process-wide runtime configuration (executed once, at import time) ---

# Allocator option string handed to PyTorch through the environment.
# NOTE(review): presumably this must be set before the CUDA allocator is first
# used, so keep it above any torch.cuda call in this module — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Environment flag for tokenizer parallelism (presumably read by the
# `tokenizers` library behind the transformers tokenizers — verify).
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# NOTE(review): nothing in the visible code calls torch.compile; confirm this
# dynamo flag is still needed.
torch._dynamo.config.suppress_errors = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Touches the CUDA runtime at import time, limiting this process to a 0.95
# fraction of device memory.
torch.cuda.set_per_process_memory_fraction(0.95)
# Placeholder that is only used as the annotation name in the signatures of
# load_pipeline() and infer(); its runtime value is None.
Pipeline = None
# Hub repository id and pinned revision passed to
# DiffusionPipeline.from_pretrained() in load_pipeline().
ids = "black-forest-labs/FLUX.1-schnell"
Revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
def empty_cache():
    """Collect Python garbage, then clear the CUDA cache and peak-memory stats.

    Called before loading, before warm-up and before every inference so each
    phase starts from a clean allocator state.

    Returns:
        None.
    """
    # Drop unreachable Python objects first so any CUDA tensors they hold can
    # be returned to the caching allocator before it is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    # torch.cuda.reset_max_memory_allocated() is a deprecated wrapper that
    # emits a FutureWarning and forwards to reset_peak_memory_stats(), so the
    # single call below covers both of the original resets.
    torch.cuda.reset_peak_memory_stats()

def load_pipeline() -> Pipeline:
    """Assemble the FLUX.1-schnell pipeline on CUDA and warm it up.

    Components are loaded in this order: an ``AutoencoderTiny`` VAE, a T5
    encoder, and a ``FluxTransformer2DModel`` read from a snapshot directory
    inside the local Hugging Face hub cache. They are then plugged into a
    ``DiffusionPipeline`` built from ``ids`` at ``Revision``. The pipeline is
    moved to ``"cuda"``, VAE tiling and slicing are enabled, and three
    1024x1024 warm-up generations are executed before returning.

    Returns:
        The warmed-up pipeline object.
    """
    empty_cache()

    tiny_vae = AutoencoderTiny.from_pretrained(
        "slobers/tt1",
        revision="ec746bf42d91e3335760895281f070df54f2196a",
        torch_dtype=torch.bfloat16,
    )

    t5_encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=torch.bfloat16,
    )
    t5_encoder = t5_encoder.to(memory_format=torch.channels_last)

    # The transformer weights are read straight from the hub cache directory,
    # so that snapshot has to be present on disk already.
    transformer_dir = os.path.join(
        HF_HUB_CACHE,
        "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a",
    )
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        transformer_dir,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )
    flux_transformer = flux_transformer.to(memory_format=torch.channels_last)

    pipe = DiffusionPipeline.from_pretrained(
        ids,
        revision=Revision,
        vae=tiny_vae,
        transformer=flux_transformer,
        text_encoder_2=t5_encoder,
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()

    empty_cache()

    # Warm-up runs use the same sampling settings as infer(), with a fixed
    # prompt and a fixed 1024x1024 resolution.
    warmup_kwargs = {
        "prompt": "insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
        "width": 1024,
        "height": 1024,
        "guidance_scale": 0.0,
        "num_inference_steps": 4,
        "max_sequence_length": 256,
    }
    for _ in range(3):
        pipe(**warmup_kwargs)

    return pipe

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for ``request`` with the given pipeline.

    Args:
        request: Supplies ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: The callable pipeline object returned by ``load_pipeline``.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # Seed a generator on the pipeline's device from the request's seed.
    rng = Generator(pipeline.device)
    rng.manual_seed(request.seed)

    empty_cache()

    output = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return output.images[0]