File size: 2,024 Bytes
0ad923c
09bb5e0
 
 
 
0ad923c
 
09bb5e0
0ad923c
 
 
09bb5e0
 
0ad923c
09bb5e0
 
 
0ad923c
09bb5e0
e396643
2d2557a
0ad923c
 
233b67f
 
b7e1021
09bb5e0
e9160e7
 
09bb5e0
 
0ad923c
 
09bb5e0
0ad923c
 
 
 
 
 
09bb5e0
0ad923c
 
 
 
e396643
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#6
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
import os
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only

# Allocator tuning: expandable segments reduce CUDA memory fragmentation
# for the variable-size allocations made during diffusion inference.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Allow the fast tokenizers to use parallelism (silences the fork warning).
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# If torch.compile hits an unsupported op, fall back to eager execution
# instead of raising — keeps the compiled transformer best-effort.
torch._dynamo.config.suppress_errors = True

# Placeholder used purely as a return/parameter annotation below; at runtime
# the actual object is a diffusers FluxPipeline.
Pipeline = None
# Hugging Face repo id of the model and the pinned snapshot commit to load.
ids = "slobers/Flux.1.Schnella"
Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"

def load_pipeline() -> Pipeline:
    """Load, quantize, compile, and warm up the Flux text-to-image pipeline.

    Returns:
        A CUDA-resident FluxPipeline whose transformer is torch.compile'd and
        whose VAE is int8 weight-only quantized, already warmed up so the
        first real request does not pay compilation/autotune cost.
    """
    # Derive the local snapshot path from the shared `ids`/`Revision`
    # constants instead of duplicating them in a hard-coded string, so a
    # revision bump only needs to change one place.
    path = os.path.join(
        HF_HUB_CACHE,
        f"models--{ids.replace('/', '--')}",
        "snapshots",
        Revision,
        "transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        path, torch_dtype=torch.bfloat16, use_safetensors=False
    )
    pipeline = FluxPipeline.from_pretrained(
        ids,
        revision=Revision,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    # max-autotune compilation dominates startup time but pays off per request;
    # fullgraph=True asserts the transformer traces without graph breaks.
    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
    # int8 weight-only quantization of the VAE to cut GPU memory usage.
    quantize_(pipeline.vae, int8_weight_only())
    # Warm-up passes trigger compilation/autotuning before serving traffic.
    for _ in range(3):
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Render one image for *request* using the pre-loaded pipeline.

    Seeds a per-request generator from request.seed so results are
    reproducible, then runs the 4-step guidance-free sampling pass.
    """
    rng = Generator(pipeline.device).manual_seed(request.seed)
    result = pipeline(
        request.prompt,
        generator=rng,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    )
    return result.images[0]