File size: 3,983 Bytes
3770736
ff95496
3770736
 
d000ac9
 
 
3770736
 
ff95496
d000ac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff95496
d000ac9
 
ff95496
d000ac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3770736
d000ac9
3770736
 
d000ac9
 
 
 
3770736
d000ac9
3770736
d000ac9
3770736
d000ac9
 
 
 
3770736
d000ac9
 
 
 
 
3770736
d000ac9
 
 
 
 
 
 
3770736
d000ac9
 
ff95496
3770736
d000ac9
 
 
 
 
ff95496
 
d000ac9
 
 
 
 
 
 
 
 
ff95496
3770736
d000ac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3770736
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import gc
import torch
from PIL.Image import Image
from torch import Generator
from typing import Optional, Dict, Any
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from pipelines.models import TextToImageRequest
from diffusers import FluxTransformer2DModel

# Environment configuration
# Central knobs shared by every loader below: which HF repo/revision to pull,
# the device and dtype everything runs in, and the CUDA allocator setting.
MODEL_CONFIG = {
    "repository": "black-forest-labs/FLUX.1-schnell",
    "revision": "741f7c3ce8b383c54771c7003378a50191e9efe9",  # pinned commit for reproducibility
    "compute_device": "cuda",
    "precision": torch.bfloat16,
    "memory_allocation": "expandable_segments:True"
}

# Setup CUDA optimizations
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmuls: faster, slightly reduced precision
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True  # autotune cuDNN kernels; helps when input shapes repeat
# Expandable segments reduce fragmentation in the CUDA caching allocator.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = MODEL_CONFIG["memory_allocation"]

def reclaim_memory():
    """Release unused GPU memory resources.

    Runs Python garbage collection first so dead tensors lose their last
    references, then returns cached CUDA blocks to the driver and resets
    the peak-memory counters so later profiling starts from a clean slate.
    No-op on the GPU side when CUDA is unavailable.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() also covers the deprecated
        # reset_max_memory_allocated(), which the original called redundantly.
        torch.cuda.reset_peak_memory_stats()

def acquire_text_encoder() -> T5EncoderModel:
    """Download and return the secondary T5 text encoder in the configured dtype."""
    return T5EncoderModel.from_pretrained(
        "manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        subfolder="text_encoder_2",
        torch_dtype=MODEL_CONFIG["precision"],
    )

def acquire_transformer() -> FluxTransformer2DModel:
    """Load the FLUX transformer weights from the local HF hub cache snapshot."""
    weights_dir = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        "transformer",
    )
    model = FluxTransformer2DModel.from_pretrained(
        weights_dir,
        torch_dtype=MODEL_CONFIG["precision"],
        use_safetensors=False,
    )
    # channels_last memory layout improves CUDA kernel locality downstream.
    return model.to(memory_format=torch.channels_last)

def initialize_pipeline(components: Optional[Dict[str, Any]] = None) -> DiffusionPipeline:
    """Construct and initialize the diffusion pipeline.

    Args:
        components: Optional pre-loaded pipeline components keyed by the
            keyword names ``DiffusionPipeline.from_pretrained`` accepts
            (e.g. ``"text_encoder_2"``, ``"transformer"``). Missing entries
            are loaded on demand.

    Returns:
        A warmed-up DiffusionPipeline resident on the configured device.
    """
    # Work on a copy: the original implementation wrote loaded models back
    # into the caller's dict, silently mutating the argument.
    components = dict(components) if components is not None else {}

    if "text_encoder_2" not in components:
        components["text_encoder_2"] = acquire_text_encoder()

    if "transformer" not in components:
        components["transformer"] = acquire_transformer()

    # Create pipeline with components
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_CONFIG["repository"],
        revision=MODEL_CONFIG["revision"],
        torch_dtype=MODEL_CONFIG["precision"],
        **components
    )

    # Move to GPU and adopt channels-last layout for better kernel throughput.
    pipeline.to(MODEL_CONFIG["compute_device"])
    pipeline.to(memory_format=torch.channels_last)

    # Warm up with throwaway prompts so kernel compilation/caching happens
    # before the first real request; no_grad hoisted around the whole loop.
    with torch.no_grad():
        for _ in range(2):
            pipeline(prompt=" ")

    return pipeline

def load_pipeline() -> DiffusionPipeline:
    """Public entry point: build and return an inference-ready pipeline.

    Returns:
        A configured diffusion pipeline ready for inference.
    """
    return initialize_pipeline()

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
    """Render a single image for the given text-to-image request.

    Args:
        request: Carries the prompt and the requested output dimensions.
        pipeline: A pipeline prepared by load_pipeline().
        generator: Seeded torch RNG so sampling is reproducible.

    Returns:
        The first (and only) generated PIL image.
    """
    output = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,  # schnell is guidance-distilled: CFG disabled
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return output.images[0]

# Alias for backward compatibility: existing callers that import `load`
# keep working and receive the same pipeline factory.
load = load_pipeline