File size: 3,525 Bytes
1ab446e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import AutoencoderTiny
from pipelines.models import TextToImageRequest
import os
import torch
import torch._dynamo


# Let PyTorch's CUDA caching allocator use expandable segments, which reduces
# memory fragmentation; it must be set before the first CUDA allocation.
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
# Explicitly enable Hugging Face tokenizers parallelism; giving an explicit value
# also silences the tokenizers warning emitted when the process forks.
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# If TorchDynamo / torch.compile is used anywhere, fall back to eager execution
# on compilation errors instead of raising.
torch._dynamo.config.suppress_errors = True


# Placeholder used only in type annotations for the pipeline object
# (the load_pipeline / infer signatures); it carries no runtime value.
Pipeline = None
# Hugging Face repo id of the base model and its pinned commit revision.
CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"

class Normalization:
    """Quantize a model's trainable parameters onto a uniform grid, then rescale.

    Each trainable parameter tensor is snapped to ``num_bins`` evenly spaced
    levels spanning that tensor's own ``[min, max]`` range, and the result is
    multiplied by ``scale_factor``. Floating-point buffers are multiplied by
    ``scale_factor`` as well. All updates happen in place on ``model``.
    """

    def __init__(self, model, num_bins=256, scale_factor=1.0):
        """Store the model and quantization settings.

        Args:
            model: ``torch.nn.Module`` whose parameters and buffers are
                modified in place by :meth:`apply`.
            num_bins: Number of quantization levels per tensor; must be >= 2.
            scale_factor: Multiplier applied to quantized parameters and to
                floating-point buffers.

        Raises:
            ValueError: If ``num_bins`` is less than 2 (the bin width
                ``1 / (num_bins - 1)`` would divide by zero).
        """
        if num_bins < 2:
            raise ValueError(f"num_bins must be at least 2, got {num_bins}")
        self.model = model
        self.num_bins = num_bins
        self.scale_factor = scale_factor

    @torch.no_grad()
    def apply(self):
        """Quantize trainable parameters and rescale floating-point buffers in place.

        Returns:
            The same ``model`` instance that was passed to the constructor.
        """
        levels = self.num_bins - 1
        for param in self.model.parameters():
            # Frozen parameters are left untouched; empty tensors have no min/max.
            if not param.requires_grad or param.numel() == 0:
                continue
            param_min = param.min()
            param_range = param.max() - param_min
            if param_range > 0:
                # Normalize to [0, 1], snap to the nearest of `levels + 1` grid
                # points, then map back onto the tensor's original range.
                normalized = (param - param_min) / param_range
                binned = torch.round(normalized * levels) / levels
                quantized = binned * param_range + param_min
            else:
                # A constant tensor already lies exactly on its grid, so keep its
                # values (zeroing would wipe out e.g. all-ones norm weights).
                quantized = param
            param.data.copy_(quantized * self.scale_factor)

        for buffer in self.model.buffers():
            # Integer/bool buffers (e.g. BatchNorm's `num_batches_tracked`) cannot
            # be multiplied in place by a float without a dtype-cast error.
            if buffer.is_floating_point():
                buffer.mul_(self.scale_factor)
        return self.model

def load_pipeline() -> Pipeline:
    """Assemble the FLUX.1-schnell pipeline on CUDA and run warm-up generations.

    The base pipeline is loaded from ``CHECKPOINT`` at ``REVISION``. Three of
    its components are replaced with separately pinned artifacts:
    ``text_encoder_2`` (a T5 encoder), ``transformer``, and ``vae`` (a tiny
    autoencoder). All of them are loaded in bfloat16.

    Three throwaway generations are then run before returning. This is
    presumably meant to pay one-time costs before real requests arrive.

    Returns:
        The loaded pipeline, moved to ``"cuda"``.
    """
    dtype = torch.bfloat16

    text_encoder_2 = T5EncoderModel.from_pretrained(
        "passfh/textenc",
        revision="a44db2ac3d729d6cc1243dcb906903e77ba26c45",
        torch_dtype=dtype,
    ).to(memory_format=torch.channels_last)

    # Loaded from a snapshot directory that must already exist in the local HF cache.
    # NOTE(review): use_safetensors=False loads non-safetensors weights (torch.load);
    # acceptable only because the snapshot is a pinned, trusted artifact.
    transformer = FluxTransformer2DModel.from_pretrained(
        os.path.join(
            HF_HUB_CACHE,
            "models--passfh--flux_transformer/snapshots/3c3bcc511f409569adb6c798da415b3fdc9e927d",
        ),
        torch_dtype=dtype,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    vae = AutoencoderTiny.from_pretrained(
        "passfh/vae",
        revision="edd99d452c03a8b836758bb89bc775f2f3c3849a",
        torch_dtype=dtype,
    )

    # Use the module-level constants rather than repeating the literals here.
    pipeline = DiffusionPipeline.from_pretrained(
        CHECKPOINT,
        revision=REVISION,
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )
    pipeline.to("cuda")

    # Warm-up generations at the serving resolution; the outputs are discarded.
    warmup_prompt = "bluelegs, cunila, carbro, Ammonites, Lollardism, forswearer, skullcap, Juglandales"
    for _ in range(3):
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate one image for ``request`` with an already-loaded ``pipeline``.

    A fresh ``torch.Generator`` on the pipeline's device is seeded with
    ``request.seed``, so sampling is driven by that seed. Settings used:

    - guidance disabled (``guidance_scale=0.0``);
    - four denoising steps;
    - ``max_sequence_length=256``;
    - output height and width taken from the request.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # NOTE(review): assumes request.seed is an int; manual_seed would reject None — confirm.
    rng = Generator(pipeline.device).manual_seed(request.seed)
    output = pipeline(
        request.prompt,
        generator=rng,
        height=request.height,
        width=request.width,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
    )
    return output.images[0]