"""Flux text-to-image inference: a bf16 FluxPipeline with an int8
weight-only quantized VAE, a warmup generation at load time, and a
4-step sampling entrypoint for the inference harness."""

import gc
import os

import torch
import torch._dynamo
from diffusers import AutoencoderKL, FluxPipeline, FluxTransformer2DModel
from huggingface_hub.constants import HF_HUB_CACHE
from PIL import Image
from torch import Generator
from torchao.quantization import quantize_, int8_weight_only
from transformers import T5EncoderModel

from pipelines.models import TextToImageRequest

# preconfigs
# Let the CUDA caching allocator grow existing segments instead of
# fragmenting memory across many fixed-size ones.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Fall back to eager execution instead of raising if dynamo fails to compile.
torch._dynamo.config.suppress_errors = True
# Allow TF32 matmuls: a small precision trade-off for speed on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
# torch.backends.cudnn.benchmark = True

# globals
Pipeline = FluxPipeline  # concrete pipeline type used in the annotations below
ckpt_id = "freaky231/t5-encoder-bf16"
ckpt_revision = "994f6e4720f69e67bfc8822cbb4063c9149b801b"


def empty_cache():
    # Drop Python-level references, return cached CUDA blocks to the driver,
    # and reset the peak-memory counters for fresh measurements.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

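# Hedged helper sketch (not called by the harness): reports peak CUDA memory,
# useful when checking that empty_cache() actually reclaims the warmup pass.
# The function name and print format here are illustrative assumptions.
def report_peak_memory(tag: str) -> None:
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[{tag}] peak CUDA memory: {peak_gib:.2f} GiB")
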

def load_pipeline() -> Pipeline:
    # VAE in bf16, then int8 weight-only quantization to shrink decoder memory.
    vae = AutoencoderKL.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    quantize_(vae, int8_weight_only())
    # T5 text encoder from a pre-converted bf16 checkpoint.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "freaky231/FluxPipeline",
        revision="c5cf4b2fc96d25c81eb0783d2c362689ea9ccf28",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    )
    # Load the transformer directly from the local HF cache snapshot;
    # use_safetensors=False so the snapshot's non-safetensors weights load.
    path = os.path.join(
        HF_HUB_CACHE,
        "models--freaky231--FluxPipeline/snapshots/c5cf4b2fc96d25c81eb0783d2c362689ea9ccf28/transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        path, torch_dtype=torch.bfloat16, use_safetensors=False
    )
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        revision=ckpt_revision,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")
    pipeline.to(memory_format=torch.channels_last)
    # One warmup generation so kernel selection and any compilation happen
    # before the first timed request; the prompt content is irrelevant.
    pipeline(
        prompt="unaware, kettledrum, clayey, bioenergetic, radiograph, locomotion, subcortical, microtubule",
        width=1024,
        height=1024,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    return pipeline

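# Optional sketch (not enabled in this file): given the dynamo error
# suppression configured above, the transformer could also be compiled
# before the warmup pass. Whether this helps depends on the torch and
# diffusers versions in the environment, so treat it as an assumption:
#
#     pipeline.transformer = torch.compile(
#         pipeline.transformer, mode="max-autotune-no-cudagraphs"
#     )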

# Flag tracking whether the warmup allocations from load_pipeline() have been
# cleared; the first real request flips it and empties the cache once.
warmed_up = False


@torch.no_grad()
def infer(
    request: TextToImageRequest, pipeline: Pipeline, generator: Generator
) -> Image.Image:
    global warmed_up
    if not warmed_up:
        warmed_up = True
        empty_cache()
    # 4-step generation with guidance disabled, matching the warmup settings.
    return pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
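

# Hedged usage sketch: how the two entrypoints above would be driven outside
# the inference harness. It assumes TextToImageRequest accepts the prompt,
# width, and height fields read in infer() as keyword arguments.
if __name__ == "__main__":
    pipe = load_pipeline()
    req = TextToImageRequest(prompt="a lighthouse at dusk", width=1024, height=1024)
    img = infer(req, pipe, Generator("cuda").manual_seed(0))
    img.save("sample.png")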