File size: 3,677 Bytes
4f92f0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from PIL.Image import Image
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel
from PIL.Image import Image
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from PIL.Image import Image
from diffusers import AutoencoderTiny
from pipelines.models import TextToImageRequest
import os
import torch
import torch._dynamo

# Runtime knobs read from the environment: the CUDA caching-allocator
# settings and the HF tokenizers parallelism switch. They are set here at
# import time, before any model is loaded by this module.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
# Ask TorchDynamo to fall back to eager execution instead of raising on
# compilation errors.
torch._dynamo.config.suppress_errors = True

# Placeholder object; used only in type annotations in this module.
Pipeline = None
# Fixed prompt used for warm-up generations and as the fallback prompt.
basePT = "forswearer, skullcap, Juglandales, bluelegs, cunila, carbro, Ammonites"

class Quantization:
    """Snap selected model parameters onto a uniform grid, then rescale them.

    ``layer_configs`` maps fully-qualified parameter names, exactly as yielded
    by ``model.named_parameters()`` (e.g.
    ``"single_transformer_blocks.0.attn.norm_q.weight"``), to a
    ``(num_bins, scale_factor)`` pair. ``apply`` rewrites every matching
    parameter that has ``requires_grad`` set, in place.
    """

    # Default per-parameter settings: (number of grid levels, output scale).
    _DEFAULT_LAYER_CONFIGS = {
        "single_transformer_blocks.0.attn.norm_k.weight": (128, 0.96),
        "single_transformer_blocks.0.attn.norm_q.weight": (128, 0.96),
        "single_transformer_blocks.0.attn.norm_v.weight": (128, 0.96),
    }

    def __init__(self, model, layer_configs=None):
        """Wrap *model* for quantization.

        Args:
            model: Object exposing ``named_parameters()`` (e.g. a
                ``torch.nn.Module``).
            layer_configs: Optional mapping ``{param_name: (num_bins,
                scale_factor)}`` replacing the default table.
        """
        self.model = model
        source = self._DEFAULT_LAYER_CONFIGS if layer_configs is None else layer_configs
        # Copy so neither the caller's dict nor the class default is aliased.
        self.layer_configs = dict(source)

    @staticmethod
    def _quantize(param, num_bins):
        """Return *param* rounded to *num_bins* evenly spaced levels in [min, max].

        Raises:
            ValueError: if ``num_bins < 2`` (the grid step would divide by zero).
        """
        if num_bins < 2:
            raise ValueError(f"num_bins must be >= 2, got {num_bins}")
        param_min = param.min()
        param_range = param.max() - param_min
        if param_range > 0:
            steps = num_bins - 1
            normalized = (param - param_min) / param_range
            binned = torch.round(normalized * steps) / steps
            return binned * param_range + param_min
        # All elements are equal: the tensor already sits on its own one-point
        # grid, so keep it rather than zeroing (which would wipe e.g. an
        # all-ones norm weight).
        return param

    def apply(self):
        """Quantize and rescale every configured parameter in place.

        Parameters absent from ``layer_configs`` or with ``requires_grad``
        unset are left untouched.

        Returns:
            The wrapped model (the same object, modified in place).

        Raises:
            ValueError: if a matching entry has ``num_bins < 2``.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                # Match on the full dotted name; the configured keys are full
                # parameter paths, not top-level module names.
                config = self.layer_configs.get(name)
                if config is None or not param.requires_grad:
                    continue
                num_bins, scale_factor = config
                param.copy_(self._quantize(param, num_bins) * scale_factor)
        return self.model

def load_pipeline() -> Pipeline:
    """Load the FLUX.1-schnell pipeline with custom components and warm it up.

    Loads a bfloat16 T5 text encoder and a tiny autoencoder from pinned Hub
    revisions. Loads the transformer from a pinned snapshot in the local
    Hugging Face cache and runs ``Quantization`` over it, falling back to the
    loaded transformer if that raises. Then assembles the pipeline, moves it
    to CUDA and runs three warm-up generations with the fixed ``basePT``
    prompt.

    Returns:
        The ready-to-use diffusers pipeline.
    """
    import warnings

    text_encoder_2 = T5EncoderModel.from_pretrained(
        "db900/neural-lattice",
        revision="31581dabff21433df68d22d5539d07de6a87380a",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)
    vae = AutoencoderTiny.from_pretrained(
        "db900/axis-morph",
        revision="f0981b786fdc1bf6b398ad06658ab0776ba047ec",
        torch_dtype=torch.bfloat16,
    )
    default = FluxTransformer2DModel.from_pretrained(
        os.path.join(
            HF_HUB_CACHE,
            "models--db900--trans-flux/snapshots/2632cc4202aa3e7f459031cc45804e3693da6722",
        ),
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)

    try:
        # Quantize the transformer that was just loaded. ``transformer`` has
        # no value yet at this point, so the loaded model must be passed.
        # NOTE(review): apply() is meant to mutate ``default`` in place; if it
        # raises partway, the fallback below may hold partially modified weights.
        transformer = Quantization(default).apply()
    except Exception as exc:
        # Best effort: keep serving, but make the failure visible.
        warnings.warn(
            f"Transformer quantization failed, falling back to the loaded transformer: {exc!r}",
            RuntimeWarning,
        )
        transformer = default

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        revision="741f7c3ce8b383c54771c7003378a50191e9efe9",
        vae=vae,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to("cuda")

    # Warm-up runs at the serving resolution/settings.
    for _ in range(3):
        pipeline(
            prompt=basePT,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for *request* using *pipeline*.

    Uses ``request.prompt``, falling back to ``basePT`` if the request object
    has no ``prompt`` attribute. Runs 4 steps with guidance disabled, at the
    request's height and width, seeded from ``request.seed``.

    Returns:
        The first image produced by the pipeline.
    """
    prompt = getattr(request, "prompt", basePT)

    # NOTE(review): if the request schema allows ``seed`` to be None, run
    # unseeded instead of crashing in ``manual_seed(None)`` -- confirm intent.
    seed = request.seed
    generator = None if seed is None else Generator(pipeline.device).manual_seed(seed)

    return pipeline(
        prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
    ).images[0]