File size: 2,108 Bytes
514c457
c8cbfa7
 
 
 
efda35d
 
514c457
efda35d
 
 
c8cbfa7
 
efda35d
c8cbfa7
efda35d
 
 
 
58d4a01
 
efda35d
 
58d4a01
 
 
efda35d
58d4a01
 
514c457
c8cbfa7
 
efda35d
 
c8cbfa7
efda35d
 
 
 
 
 
3897c9d
 
efda35d
 
 
c8cbfa7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#7.1
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
import os
from diffusers import FluxPipeline, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only

# Process-wide runtime configuration, applied at import time.
os.environ.update(
    {
        # Option string for PyTorch's CUDA caching allocator.
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        # Enable parallelism in the HF `tokenizers` library.
        "TOKENIZERS_PARALLELISM": "True",
    }
)
# Let torch.compile / dynamo suppress compilation errors instead of raising.
torch._dynamo.config.suppress_errors = True

# Placeholder bound to None; it only appears in the function annotations.
Pipeline = None
# Hugging Face repo id and pinned snapshot revision of the model weights.
ids = "slobers/Flux.1.Schnella"
Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"

def load_pipeline() -> Pipeline:
    """Build the FLUX pipeline from the local HF cache, optimise it, and warm it up.

    The transformer is loaded directly from the cached snapshot directory of
    ``ids`` at ``Revision``. The rest of the pipeline is loaded from the same
    repo/revision with ``local_files_only=True``, so no network access is
    attempted. The VAE is int8 weight-only quantized, and both the transformer
    and the VAE are wrapped with ``torch.compile``. A few warm-up generations
    are then run so compilation happens here rather than on the first request.

    Returns:
        The ready-to-use ``FluxPipeline``, moved to ``"cuda"``.
    """
    # Derive the snapshot directory from the module constants rather than
    # repeating the repo id / revision hash in a literal path. This keeps the
    # transformer and the rest of the pipeline pinned to the same snapshot.
    # HF cache layout: <cache>/models--<org>--<name>/snapshots/<revision>/...
    repo_folder = "models--" + ids.replace("/", "--")
    path = os.path.join(HF_HUB_CACHE, repo_folder, "snapshots", Revision, "transformer")

    # use_safetensors=False: presumably this snapshot ships non-safetensors
    # weights for the transformer -- NOTE(review): confirm against the repo.
    transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
    pipeline = FluxPipeline.from_pretrained(
        ids,
        revision=Revision,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Quantize VAE weights to int8 (weights only) after moving to the GPU.
    quantize_(pipeline.vae, int8_weight_only())

    # The transformer is compiled without dynamic=True, while the VAE is compiled
    # with dynamic=True. Input shapes other than the warm-up shape may therefore
    # trigger transformer recompilation.
    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True, dynamic=True)

    # Warm-up passes trigger compilation/autotuning ahead of real requests.
    # NOTE(review): warm-up uses 1024x1024 and guidance_scale=0.0, whereas
    # `infer` uses the request's size and guidance_scale=3.5. Confirm this
    # covers the shapes seen in production.
    for _ in range(3):
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` with an already-loaded pipeline.

    A ``torch.Generator`` on the pipeline's device is seeded with
    ``request.seed``, so the same seed and prompt give the same sampling noise.

    Args:
        request: Carries ``prompt``, ``seed``, ``height`` and ``width``.
        pipeline: The pipeline returned by ``load_pipeline``.

    Returns:
        The first image of the pipeline output.
    """
    # NOTE(review): manual_seed() needs an int; if TextToImageRequest allows
    # seed=None this line will raise -- confirm against the request schema.
    rng = Generator(pipeline.device).manual_seed(request.seed)

    # Fixed sampling settings shared by every request.
    sampling = dict(
        guidance_scale=3.5,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    output = pipeline(
        request.prompt,
        generator=rng,
        height=request.height,
        width=request.width,
        **sampling,
    )
    return output.images[0]