File size: 15,645 Bytes
7f4c99b
 
 
 
 
 
 
 
c566cc0
7f4c99b
c566cc0
540203d
7f4c99b
 
 
 
 
 
 
540203d
7ae6a53
 
 
1edfe74
99dda54
7ae6a53
 
 
99dda54
 
7ae6a53
99dda54
 
 
 
fbaeee3
99dda54
 
 
7ae6a53
 
 
 
 
 
 
 
 
99dda54
7ae6a53
 
99dda54
 
 
7ae6a53
 
 
1edfe74
7ae6a53
 
 
1edfe74
7ae6a53
 
 
 
 
 
 
1edfe74
7ae6a53
 
 
 
1edfe74
7ae6a53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1edfe74
7ae6a53
 
1edfe74
7ae6a53
 
1edfe74
7ae6a53
540203d
950e246
7ae6a53
 
 
 
 
 
 
 
 
 
1edfe74
7ae6a53
540203d
950e246
 
 
 
 
c566cc0
7ae6a53
1edfe74
 
7ae6a53
1edfe74
7ae6a53
 
 
 
 
aa0cb15
 
c566cc0
7f4c99b
c566cc0
7f4c99b
 
 
c566cc0
 
 
 
 
 
 
0a0ac19
7ae6a53
 
 
7f4c99b
02b00c0
4d8d54b
02b00c0
4d8d54b
02b00c0
4d8d54b
02b00c0
 
 
 
4d8d54b
02b00c0
 
4d8d54b
02b00c0
 
 
 
 
 
 
 
 
 
 
c566cc0
 
 
7be88ba
989c44e
c566cc0
 
 
7f4c99b
c566cc0
0a0ac19
 
c566cc0
 
0a0ac19
 
 
 
 
 
 
c566cc0
 
7ae6a53
0a0ac19
c566cc0
 
 
 
 
 
 
 
 
 
 
 
7ae6a53
 
c566cc0
 
 
 
 
 
7ae6a53
 
 
 
 
 
 
c566cc0
 
7f4c99b
 
 
7ae6a53
 
 
 
7f4c99b
75ba08b
c566cc0
75ba08b
c566cc0
7f4c99b
 
 
 
 
 
c566cc0
7f4c99b
 
02b00c0
7f4c99b
c566cc0
989c44e
 
7f4c99b
 
effd30c
7f4c99b
 
c566cc0
75ba08b
7ae6a53
7f4c99b
02b00c0
 
 
 
 
 
 
 
 
 
 
 
7ae6a53
02b00c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f4c99b
02b00c0
 
7f4c99b
c566cc0
 
 
 
 
 
 
 
7ae6a53
c566cc0
 
 
 
 
 
 
 
7ae6a53
02b00c0
 
 
 
 
7ae6a53
7f4c99b
 
 
4d8d54b
7f4c99b
 
 
950e246
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import os
import io
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
import requests
from PIL import Image
import base64
from gradio_client import Client

# Default weight precision. NOTE(review): this name is not referenced elsewhere
# in this file -- the model-loading code passes torch.bfloat16 directly.
dtype = torch.bfloat16
# Target device for the pipeline, the RNG generator and the prompt embeddings.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Largest signed 32-bit integer; upper bound for random and user-chosen seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) of the width/height sliders in the UI.
MAX_IMAGE_SIZE = 1024

# Initialize text encoder client ONCE at module level to avoid thread exhaustion.
# remote_text_encoder() calls this client's "/encode_text" endpoint for any
# prompt that has no cached embedding.
text_encoder_client = Client("Gemini899/mistral-text-encoder")

# ============================================================================
# HARDCODED PROMPTS - EXACT match from depth_logic.py (bypasses Mistral API)
# Now with 4 variants: face/no-face for both relief and details
#
# A user prompt "matches" an entry when the two are equal after stripping
# surrounding whitespace (see get_cached_embedding); any other difference,
# including case, means no match and a remote encoder call.
# Each key doubles as the cache key: load_cached_embeddings() loads
# "<key>.pt" for it. Cached embeddings are looked up by KEY, not by text --
# if a prompt string below is edited, its .pt file must be regenerated or the
# stale embedding will be used for the new text.
# ============================================================================

HARDCODED_PROMPTS = {
    # From call_flux2_dev_relief_generation() in depth_logic.py - WITH FACE
    "relief_face": "Ignore all shadows, Clay bas-relief sculpture. PRESERVE exact facial features and proportions. Uniform matte gray material, NO black areas, NO dark shadows, NO outlines. Soft smooth depth only. Light gray to white tones. Like carved marble or clay relief. Nose area details soft and delicate, no high contrast.",
    
    # From call_flux2_dev_relief_generation() in depth_logic.py - NO FACE
    "relief": "Ignore all shadows, Clay bas-relief sculpture. PRESERVE exact proportions and features. Uniform matte gray material, NO black areas, NO dark shadows, NO outlines. Soft smooth depth only. Light gray to white tones. Like carved marble or clay relief.",
    
    # From call_flux2_dev_detail_generation() in depth_logic.py - WITH FACE
    "details_face": "Preserve exact pixel alignment. Enhance this depth map with polished smooth skin texture (NOT pixelated), subtle fabric weave, fine hair strands, stone grain. Skin must appear smooth and refined polished. Keep EXACT same outline, silhouette, and tonal range. NO shadows, NO reflections, NO new light sources, NO dark areas under nose/eyes/lips. Leave nose area completely unchanged - preserve original nose values, colors, and details exactly as-is with no enhancement, no contrast boost, and no texture added to nose or nostrils. Output must overlay perfectly on original as bump map detail layer.",
    
    # From call_flux2_dev_detail_generation() in depth_logic.py - NO FACE
    "details": "Preserve exact pixel alignment. Enhance this depth map by adding surface micro-details (fabric texture, hair strands, stone grain) using ONLY tonal variations within ±10% of local gray values. Keep EXACT same outline, silhouette, and overall tonal range. NO shadows, NO reflections, NO new light sources, NO dark areas. Output must overlay perfectly on original as bump map detail layer.",
}

# Pre-load embeddings at startup.
# Maps a HARDCODED_PROMPTS key -> embedding loaded from disk (on CPU) by
# load_cached_embeddings(); keys with no usable file are simply absent.
_cached_embeddings = {}

def load_cached_embeddings():
    """Load pre-generated prompt embeddings from disk into ``_cached_embeddings``.

    For each hardcoded prompt key (``relief``, ``relief_face``, ``details``,
    ``details_face``), look for ``<key>.pt`` in this order: the current
    working directory, ``/home/user/app`` (the HF Spaces app directory), and
    the directory containing this script. The first candidate that loads as a
    ``torch.Tensor`` is stored, on CPU, under that key.

    Missing or unusable files are reported and skipped. Load errors are caught
    so startup continues, and prompts without a cached embedding fall back to
    the remote Mistral encoder at request time. An entry is only added to the
    cache after it has been validated as a tensor.
    """
    # Updated to include all 4 embedding files
    embedding_files = {
        "relief": "relief.pt",
        "relief_face": "relief_face.pt",
        "details": "details.pt",
        "details_face": "details_face.pt",
    }
    script_dir = os.path.dirname(os.path.abspath(__file__))

    for key, filename in embedding_files.items():
        # Try multiple possible paths for HuggingFace Spaces
        possible_paths = [
            filename,  # Current directory
            f"/home/user/app/{filename}",  # HF Spaces app directory
            os.path.join(script_dir, filename),  # Same dir as script
        ]

        found_file = False
        for path in possible_paths:
            if not os.path.exists(path):
                continue
            found_file = True
            try:
                embedding = torch.load(path, map_location='cpu')
            except Exception as e:
                # Best effort: a corrupt or incompatible file must not abort
                # startup; report it and try the next candidate path.
                print(f"✗ Error loading {path}: {e}")
                continue
            if not isinstance(embedding, torch.Tensor):
                print(f"✗ Error loading {path}: expected a torch.Tensor, got {type(embedding).__name__}")
                continue
            # Publish only after validation, so a bad file can never leave a
            # non-tensor entry in the cache.
            _cached_embeddings[key] = embedding
            print(f"✓ Loaded cached embedding: {key} from {path}")
            print(f"  Shape: {embedding.shape}, Dtype: {embedding.dtype}")
            break
        else:
            # No candidate produced a usable tensor.
            reason = "could not be loaded" if found_file else "not found"
            print(f"⚠ Warning: {filename} {reason} - will use Mistral API for '{key}' prompt")

def normalize_prompt(prompt: str) -> str:
    """Return ``prompt`` without leading or trailing whitespace.

    Only the ends are trimmed; interior whitespace and letter case are kept,
    so two prompts compare equal only if they differ solely in padding.
    """
    stripped = prompt.strip()
    return stripped

def get_cached_embedding(prompt: str) -> torch.Tensor | None:
    """
    Check if prompt EXACTLY matches a hardcoded prompt.
    Returns cached embedding if exact match, None otherwise.
    """
    normalized_input = normalize_prompt(prompt)
    
    for key, hardcoded_prompt in HARDCODED_PROMPTS.items():
        if normalized_input == normalize_prompt(hardcoded_prompt):
            if key in _cached_embeddings:
                print(f"⚡ Exact match found: using cached '{key}' embedding (NO Mistral API call)")
                return _cached_embeddings[key]
            else:
                print(f"⚠ Exact match for '{key}' but no cached embedding file - using Mistral API")
                return None
    
    # No match found
    return None

def remote_text_encoder(prompts):
    """Return the prompt-embedding tensor for ``prompts``.

    ``prompts`` is a single prompt string (``infer`` passes the raw textbox
    value). If it matches a ``HARDCODED_PROMPTS`` entry after whitespace
    stripping, and that entry's embedding was loaded at startup, the cached
    tensor is returned with no network call. Otherwise the prompt is sent to
    the remote encoder Space (``/encode_text``) and the embedding file it
    returns is loaded.

    Both paths yield a CPU tensor. The on-disk cache is loaded with
    ``map_location='cpu'``, and so is the remote file here.
    ``generate_image`` moves the result to the target device.
    """
    # Check for exact match with hardcoded prompts
    cached = get_cached_embedding(prompts)
    if cached is not None:
        return cached

    # Not an exact match - use Mistral API
    print("🌐 Calling Mistral API for prompt encoding...")
    result = text_encoder_client.predict(
        prompt=prompts,
        api_name="/encode_text"
    )
    # result[0] is presumably the local path of the downloaded .pt file
    # (the endpoint returns a tuple) -- confirm against the encoder Space API.
    # map_location="cpu": the tensor may have been serialized on the remote
    # Space's GPU. Loading to CPU avoids a CUDA deserialization failure in a
    # process without CUDA and matches the cached-embedding path.
    # NOTE(review): torch.load unpickles data; only point this at a trusted Space.
    prompt_embeds = torch.load(result[0], map_location="cpu")
    return prompt_embeds

# Load cached embeddings at startup. This runs once, at import time, before
# the diffusion model is loaded below; it only reads local .pt files.
print("="*60)
print("Loading cached prompt embeddings...")
load_cached_embeddings()
print("="*60)

# ============================================================================
# Model Loading
# ============================================================================

repo_id = "black-forest-labs/FLUX.2-dev"

# Load the diffusion transformer on its own in bfloat16, then hand it to the
# pipeline below.
dit = Flux2Transformer2DModel.from_pretrained(
    repo_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16
)

# text_encoder=None: no local text encoder is loaded. Prompts are encoded
# elsewhere (cached embeddings or the remote encoder Space) and passed to the
# pipeline as prompt_embeds in generate_image().
pipe = Flux2Pipeline.from_pretrained(
    repo_id,
    text_encoder=None,
    transformer=dit,
    torch_dtype=torch.bfloat16
)
pipe.to(device)

# ============================================================================
# Image Generation Functions
# ============================================================================

def update_dimensions_from_image(image_list):
    """Suggest (width, height) slider values that follow the first input image.

    The longer side is set to 1024 and the shorter side is scaled to keep the
    aspect ratio. Both are then rounded to the nearest multiple of 8 and
    clamped to [256, 1024]. With no images, (1024, 1024) is returned.

    ``image_list`` is the Gallery value: a sequence whose items are indexed
    with ``[0]`` to reach the image, which must expose ``.size`` as
    ``(width, height)`` (PIL images do).
    """
    if image_list is None or len(image_list) == 0:
        return 1024, 1024

    width_px, height_px = image_list[0][0].size
    ratio = width_px / height_px

    if ratio < 1:
        # Portrait: height is the long side.
        raw_w, raw_h = int(1024 * ratio), 1024
    else:
        # Landscape or square: width is the long side.
        raw_w, raw_h = 1024, int(1024 / ratio)

    def snap(side):
        # Nearest multiple of 8, then keep inside the slider range.
        return max(256, min(1024, round(side / 8) * 8))

    return snap(raw_w), snap(raw_h)

def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
    """Estimate the GPU time budget, in seconds, for one ``generate_image`` call.

    The signature mirrors ``generate_image`` because it is passed to
    ``@spaces.GPU(duration=...)``, which (per the ZeroGPU dynamic-duration
    docs) calls it with the same arguments. Only ``image_list`` and
    ``num_inference_steps`` affect the estimate. The budget is 1 s per step,
    plus 0.8 s per step for each input image, plus 10 s overhead, with a
    40 s floor.
    """
    image_count = len(image_list) if image_list is not None else 0
    seconds_per_step = 1 + 0.8 * image_count
    estimate = num_inference_steps * seconds_per_step + 10
    return max(40, estimate)

@spaces.GPU(duration=get_duration)
def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
    """Run the FLUX.2 pipeline on the GPU worker and return the first image.

    ``prompt_embeds`` is moved to ``device``. Sampling is seeded with a fresh
    ``torch.Generator`` on ``device`` so that equal ``seed`` values reproduce
    results. ``image_list`` (possibly None) is forwarded to the pipeline as
    the ``image`` argument. The GPU time budget comes from ``get_duration``.
    """
    embeds_on_device = prompt_embeds.to(device)
    rng = torch.Generator(device=device).manual_seed(seed)

    if progress:
        progress(0, desc="Starting generation...")

    output = pipe(
        prompt_embeds=embeds_on_device,
        image=image_list,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=rng,
        width=width,
        height=height,
    )
    return output.images[0]

def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: encode ``prompt``, generate an image, return ``(image, seed)``.

    When ``randomize_seed`` is true, a new seed in ``[0, MAX_SEED]`` replaces
    ``seed``. The seed actually used is returned so the UI can display it.
    ``input_images`` is the Gallery value; the image is taken from ``[0]`` of
    each item, and an empty or missing gallery means text-only generation.
    Prompt encoding (cached or remote) happens here, outside the GPU-decorated
    ``generate_image``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    has_images = input_images is not None and len(input_images) > 0
    image_list = [entry[0] for entry in input_images] if has_images else None

    # Text Encoding (checks for cached embeddings first)
    progress(0.1, desc="Encoding prompt...")
    prompt_embeds = remote_text_encoder(prompt)

    # Image Generation
    progress(0.3, desc="Waiting for GPU...")
    image = generate_image(
        prompt_embeds,
        image_list,
        width,
        height,
        num_inference_steps,
        guidance_scale,
        seed,
        progress,
    )
    return image, seed

# ============================================================================
# Gradio UI
# ============================================================================

# Text-only examples. Only the prompt input is wired to these (see the
# gr.Examples call in the UI), so infer() runs with its own defaults for every
# other setting (e.g. 50 steps, guidance 2.5), not the slider values.
examples = [
    ["Create a vase on a table in living room, the color of the vase is a gradient of color, starting with #02eb3c color and finishing with #edfa3c. The flowers inside the vase have the color #ff0088"],
    ["Photorealistic infographic showing the complete Berlin TV Tower (Fernsehturm) from ground base to antenna tip, full vertical view with entire structure visible including concrete shaft, metallic sphere, and antenna spire. Slight upward perspective angle looking up toward the iconic sphere, perfectly centered on clean white background. Left side labels with thin horizontal connector lines: the text '368m' in extra large bold dark grey numerals (#2D3748) positioned at exactly the antenna tip with 'TOTAL HEIGHT' in small caps below. The text '207m' in extra large bold with 'TELECAFÉ' in small caps below, with connector line touching the sphere precisely at the window level. Right side label with horizontal connector line touching the sphere's equator: the text '32m' in extra large bold dark grey numerals with 'SPHERE DIAMETER' in small caps below. Bottom section arranged in three balanced columns: Left - Large text '986' in extra bold dark grey with 'STEPS' in caps below. Center - 'BERLIN TV TOWER' in bold caps with 'FERNSEHTURM' in lighter weight below. Right - 'INAUGURATED' in bold caps with 'OCTOBER 3, 1969' below. All typography in modern sans-serif font (such as Inter or Helvetica), color #2D3748, clean minimal technical diagram style. Horizontal connector lines are thin, precise, and clearly visible, touching the tower structure at exact corresponding measurement points. Professional architectural elevation drawing aesthetic with dynamic low angle perspective creating sense of height and grandeur, poster-ready infographic design with perfect visual hierarchy."],
    ["Soaking wet capybara taking shelter under a banana leaf in the rainy jungle, close up photo"],
    ["A kawaii die-cut sticker of a chubby orange cat, featuring big sparkly eyes and a happy smile with paws raised in greeting and a heart-shaped pink nose. The design should have smooth rounded lines with black outlines and soft gradient shading with pink cheeks."],
]

# Multi-image example: [prompt, list of gallery image files].
# NOTE(review): the .webp paths are relative, so they presumably ship next to
# the app -- verify they exist in the Space repo.
examples_images = [
    ["The person from image 1 is petting the cat from image 2, the bird from image 3 is next to them", ["woman1.webp", "cat_window.webp", "bird.webp"]]
]

# #col-container centers the main column (used as elem_id in the UI).
# NOTE(review): .gallery-container is not referenced by any component in this
# file -- confirm whether the rule still applies.
css="""
#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
.gallery-container img{
    object-fit: contain;
}
"""

# Gradio interface. Left column: prompt, optional input images and advanced
# settings. Right column: the generated image.
# Component creation order below defines the page layout.
with gr.Blocks() as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.2 [dev]
FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and combining images based on text instructions model [[model](https://huggingface.co/black-forest-labs/FLUX.2-dev)], [[blog](https://bfl.ai/blog/flux-2)]
""")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=2,
                        placeholder="Enter your prompt",
                        container=False,
                        scale=3
                    )
                    
                    run_button = gr.Button("Run", scale=1)
                
                # Optional reference images. With type="pil" the handlers get
                # PIL images; infer() and update_dimensions_from_image() read
                # each gallery item's image at index [0].
                with gr.Accordion("Input image(s) (optional)", open=True):
                    input_images = gr.Gallery(
                        label="Input Image(s)",
                        type="pil",
                        columns=3,
                        rows=1,
                    )
                
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    
                    # When checked, infer() ignores the seed slider and draws a new
                    # seed. The seed actually used is written back to the slider.
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    
                    # Output size in pixels, in steps of 8 between 256 and MAX_IMAGE_SIZE.
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                        )
                        
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=8,
                            value=1024,
                        )
                    
                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=30,
                        )
                        
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=4,
                        )
                
            with gr.Column():
                result = gr.Image(label="Result", show_label=False)
        
        # Text-only examples. Only `prompt` is passed, so infer() uses its own
        # defaults for the other settings. Per Gradio docs, cache_mode="lazy"
        # computes and caches an example's output the first time it is used.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples=True,
            cache_mode="lazy"
        )
        
        # Prompt + input-image examples, cached the same way.
        gr.Examples(
            examples=examples_images,
            fn=infer,
            inputs=[prompt, input_images],
            outputs=[result, seed],
            cache_examples=True,
            cache_mode="lazy"
        )
    
    # On upload, set the width/height sliders to follow the first image's
    # aspect ratio.
    input_images.upload(
        fn=update_dimensions_from_image,
        inputs=[input_images],
        outputs=[width, height]
    )
    
    # Generate on button click or Enter in the prompt box. The returned seed
    # is written back into the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed]
    )

# NOTE(review): `css` is passed to launch() here. In Gradio releases before
# 6.0, `css` was an argument of gr.Blocks(...) instead -- confirm against the
# pinned Gradio version.
demo.launch(css=css)