Gemini899 committed on
Commit
4d8d54b
·
verified ·
1 Parent(s): 175171d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -126
app.py CHANGED
@@ -6,12 +6,9 @@ import random
6
  import spaces
7
  import torch
8
  from diffusers import Flux2Pipeline, Flux2Transformer2DModel
9
- from diffusers import BitsAndBytesConfig as DiffBitsAndBytesConfig
10
  import requests
11
  from PIL import Image
12
- import json
13
  import base64
14
- from huggingface_hub import InferenceClient
15
 
16
  dtype = torch.bfloat16
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -19,45 +16,25 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
19
  MAX_SEED = np.iinfo(np.int32).max
20
  MAX_IMAGE_SIZE = 1024
21
 
22
- hf_client = InferenceClient(
23
- api_key=os.environ.get("HF_TOKEN"),
24
- )
25
- VLM_MODEL = "baidu/ERNIE-4.5-VL-424B-A47B-Base-PT"
26
-
27
- SYSTEM_PROMPT_TEXT_ONLY = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
28
-
29
- Guidelines:
30
- 1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
31
- 2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
32
- 3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
33
-
34
- Output only the revised prompt and nothing else."""
35
-
36
- SYSTEM_PROMPT_WITH_IMAGES = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
37
-
38
- Rules:
39
- - Single instruction only, no commentary
40
- - Use clear, analytical language (avoid "whimsical," "cascading," etc.)
41
- - Specify what changes AND what stays the same (face, lighting, composition)
42
- - Reference actual image elements
43
- - Turn negatives into positives ("don't change X" → "keep X")
44
- - Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
45
- - Keep content PG-13
46
-
47
- Output only the final instruction in plain text and nothing else."""
48
-
49
- def remote_text_encoder(prompts):
50
  from gradio_client import Client
 
51
 
52
- client = Client("multimodalart/mistral-text-encoder")
53
- result = client.predict(
54
- prompt=prompts,
55
- api_name="/encode_text"
56
- )
57
-
58
- # Load returns a tensor, usually on CPU by default
59
- prompt_embeds = torch.load(result[0])
60
- return prompt_embeds
 
 
 
 
 
 
61
 
62
  # Load model
63
  repo_id = "black-forest-labs/FLUX.2-dev"
@@ -76,85 +53,34 @@ pipe = Flux2Pipeline.from_pretrained(
76
  )
77
  pipe.to(device)
78
 
79
- # AOTI blocks temporarily disabled - HuggingFace needs to recompile for new ZeroGPU environment (PyTorch 2.9 + CUDA 12.8)
80
- # Re-enable once zerogpu-aoti/FLUX.2 is updated with compatible compiled blocks
81
  # spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
82
 
83
- def image_to_data_uri(img):
84
- buffered = io.BytesIO()
85
- img.save(buffered, format="PNG")
86
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
87
- return f"data:image/png;base64,{img_str}"
88
-
89
- def upsample_prompt_logic(prompt, image_list):
90
- try:
91
- if image_list and len(image_list) > 0:
92
- # Image + Text Editing Mode
93
- system_content = SYSTEM_PROMPT_WITH_IMAGES
94
-
95
- # Construct user message with text and images
96
- user_content = [{"type": "text", "text": prompt}]
97
-
98
- for img in image_list:
99
- data_uri = image_to_data_uri(img)
100
- user_content.append({
101
- "type": "image_url",
102
- "image_url": {"url": data_uri}
103
- })
104
-
105
- messages = [
106
- {"role": "system", "content": system_content},
107
- {"role": "user", "content": user_content}
108
- ]
109
- else:
110
- # Text Only Mode
111
- system_content = SYSTEM_PROMPT_TEXT_ONLY
112
- messages = [
113
- {"role": "system", "content": system_content},
114
- {"role": "user", "content": prompt}
115
- ]
116
-
117
- completion = hf_client.chat.completions.create(
118
- model=VLM_MODEL,
119
- messages=messages,
120
- max_tokens=1024
121
- )
122
-
123
- return completion.choices[0].message.content
124
- except Exception as e:
125
- print(f"Upsampling failed: {e}")
126
- return prompt
127
-
128
  def update_dimensions_from_image(image_list):
129
- """Update width/height sliders based on uploaded image aspect ratio.
130
- Keeps one side at 1024 and scales the other proportionally, with both sides as multiples of 8."""
131
  if image_list is None or len(image_list) == 0:
132
- return 1024, 1024 # Default dimensions
133
 
134
- # Get the first image to determine dimensions
135
- img = image_list[0][0] # Gallery returns list of tuples (image, caption)
136
  img_width, img_height = img.size
137
 
138
  aspect_ratio = img_width / img_height
139
 
140
- if aspect_ratio >= 1: # Landscape or square
141
  new_width = 1024
142
  new_height = int(1024 / aspect_ratio)
143
- else: # Portrait
144
  new_height = 1024
145
  new_width = int(1024 * aspect_ratio)
146
 
147
- # Round to nearest multiple of 8
148
  new_width = round(new_width / 8) * 8
149
  new_height = round(new_height / 8) * 8
150
 
151
- # Ensure within valid range (minimum 256, maximum 1024)
152
  new_width = max(256, min(1024, new_width))
153
  new_height = max(256, min(1024, new_height))
154
 
155
  return new_width, new_height
156
 
157
- # Updated duration function to match generate_image arguments (including progress)
158
  def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
159
  num_images = 0 if image_list is None else len(image_list)
160
  step_duration = 1 + 0.8 * num_images
@@ -162,7 +88,6 @@ def get_duration(prompt_embeds, image_list, width, height, num_inference_steps,
162
 
163
  @spaces.GPU(duration=get_duration)
164
  def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
165
- # Move embeddings to GPU only when inside the GPU decorated function
166
  prompt_embeds = prompt_embeds.to(device)
167
 
168
  generator = torch.Generator(device=device).manual_seed(seed)
@@ -177,39 +102,28 @@ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps
177
  "height": height,
178
  }
179
 
180
- # Progress bar for the actual generation steps
181
  if progress:
182
  progress(0, desc="Starting generation...")
183
 
184
  image = pipe(**pipe_kwargs).images[0]
185
  return image
186
 
187
- def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
188
 
189
  if randomize_seed:
190
  seed = random.randint(0, MAX_SEED)
191
 
192
- # Prepare image list (convert None or empty gallery to None)
193
  image_list = None
194
  if input_images is not None and len(input_images) > 0:
195
  image_list = []
196
  for item in input_images:
197
  image_list.append(item[0])
198
 
199
- # 1. Upsampling (Network bound - No GPU needed)
200
- final_prompt = prompt
201
- if prompt_upsampling:
202
- progress(0.05, desc="Upsampling prompt...")
203
- final_prompt = upsample_prompt_logic(prompt, image_list)
204
- print(f"Original Prompt: {prompt}")
205
- print(f"Upsampled Prompt: {final_prompt}")
206
-
207
- # 2. Text Encoding (Network bound - No GPU needed)
208
  progress(0.1, desc="Encoding prompt...")
209
- # This returns CPU tensors
210
- prompt_embeds = remote_text_encoder(final_prompt)
211
 
212
- # 3. Image Generation (GPU bound)
213
  progress(0.3, desc="Waiting for GPU...")
214
  image = generate_image(
215
  prompt_embeds,
@@ -232,7 +146,6 @@ examples = [
232
  ]
233
 
234
  examples_images = [
235
- # ["Replace the top of the person from image 1 with the one from image 2", ["person1.webp", "woman2.webp"]],
236
  ["The person from image 1 is petting the cat from image 2, the bird from image 3 is next to them", ["woman1.webp", "cat_window.webp", "bird.webp"]]
237
  ]
238
 
@@ -275,12 +188,6 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
275
  )
276
 
277
  with gr.Accordion("Advanced Settings", open=False):
278
- prompt_upsampling = gr.Checkbox(
279
- label="Prompt Upsampling",
280
- value=True,
281
- info="Automatically enhance the prompt using a VLM"
282
- )
283
-
284
  seed = gr.Slider(
285
  label="Seed",
286
  minimum=0,
@@ -292,7 +199,6 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
292
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
293
 
294
  with gr.Row():
295
-
296
  width = gr.Slider(
297
  label="Width",
298
  minimum=256,
@@ -310,7 +216,6 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
310
  )
311
 
312
  with gr.Row():
313
-
314
  num_inference_steps = gr.Slider(
315
  label="Number of inference steps",
316
  minimum=1,
@@ -327,10 +232,8 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
327
  value=4,
328
  )
329
 
330
-
331
  with gr.Column():
332
  result = gr.Image(label="Result", show_label=False)
333
-
334
 
335
  gr.Examples(
336
  examples=examples,
@@ -350,7 +253,6 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
350
  cache_mode="lazy"
351
  )
352
 
353
- # Auto-update dimensions when images are uploaded
354
  input_images.upload(
355
  fn=update_dimensions_from_image,
356
  inputs=[input_images],
@@ -360,7 +262,7 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
360
  gr.on(
361
  triggers=[run_button.click, prompt.submit],
362
  fn=infer,
363
- inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
364
  outputs=[result, seed]
365
  )
366
 
 
6
  import spaces
7
  import torch
8
  from diffusers import Flux2Pipeline, Flux2Transformer2DModel
 
9
  import requests
10
  from PIL import Image
 
11
  import base64
 
12
 
13
  dtype = torch.bfloat16
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
16
  MAX_SEED = np.iinfo(np.int32).max
17
  MAX_IMAGE_SIZE = 1024
18
 
19
def remote_text_encoder(prompts, max_retries=3):
    """Encode prompts into embeddings via the remote mistral-text-encoder Space.

    Retries transient failures (network errors, Space cold-starts) up to
    ``max_retries`` times, sleeping 2 s between attempts.

    Args:
        prompts: Prompt string (or list of strings) forwarded to the Space's
            ``/encode_text`` endpoint.
        max_retries: Number of attempts before giving up.

    Returns:
        torch.Tensor: prompt embeddings loaded onto CPU (the GPU-decorated
        caller is responsible for moving them to the target device).

    Raises:
        RuntimeError: if every attempt fails; chained to the last underlying
            error so the original traceback is preserved.
    """
    from gradio_client import Client
    import time

    last_error = None
    for attempt in range(max_retries):
        try:
            client = Client("multimodalart/mistral-text-encoder")
            result = client.predict(
                prompt=prompts,
                api_name="/encode_text"
            )
            # The Space returns a path to a serialized tensor; load it onto
            # CPU explicitly so device placement stays with the caller.
            prompt_embeds = torch.load(result[0], map_location="cpu")
            return prompt_embeds
        except Exception as e:
            last_error = e
            print(f"Text encoder attempt {attempt + 1}/{max_retries} failed: {e}")
            if attempt < max_retries - 1:
                time.sleep(2)
    # Chain the final failure so callers see the real cause, not just a
    # flattened message string. RuntimeError remains catchable as Exception.
    raise RuntimeError(
        f"Text encoder failed after {max_retries} attempts: {last_error}"
    ) from last_error
 
39
  # Load model
40
  repo_id = "black-forest-labs/FLUX.2-dev"
 
53
  )
54
  pipe.to(device)
55
 
56
+ # AOTI blocks temporarily disabled - HuggingFace needs to recompile for new ZeroGPU environment
 
57
  # spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
def update_dimensions_from_image(image_list):
    """Update width/height sliders based on uploaded image aspect ratio.

    The longer side is pinned at 1024 and the shorter side is scaled to
    preserve the aspect ratio, snapped to a multiple of 8, and clamped to
    the slider range [256, 1024]. With no image, returns the defaults.
    """
    if not image_list:
        return 1024, 1024

    # Gallery entries are (image, caption) tuples; only the first image
    # drives the slider values.
    first_image = image_list[0][0]
    w, h = first_image.size
    ratio = w / h

    if ratio >= 1:  # landscape or square: width leads
        width, height = 1024, int(1024 / ratio)
    else:  # portrait: height leads
        width, height = int(1024 * ratio), 1024

    def _snap(value):
        # Nearest multiple of 8, then clamp into the valid slider range.
        return max(256, min(1024, round(value / 8) * 8))

    return _snap(width), _snap(height)
83
 
 
84
  def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
85
  num_images = 0 if image_list is None else len(image_list)
86
  step_duration = 1 + 0.8 * num_images
 
88
 
89
  @spaces.GPU(duration=get_duration)
90
  def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
 
91
  prompt_embeds = prompt_embeds.to(device)
92
 
93
  generator = torch.Generator(device=device).manual_seed(seed)
 
102
  "height": height,
103
  }
104
 
 
105
  if progress:
106
  progress(0, desc="Starting generation...")
107
 
108
  image = pipe(**pipe_kwargs).images[0]
109
  return image
110
 
111
+ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
112
 
113
  if randomize_seed:
114
  seed = random.randint(0, MAX_SEED)
115
 
 
116
  image_list = None
117
  if input_images is not None and len(input_images) > 0:
118
  image_list = []
119
  for item in input_images:
120
  image_list.append(item[0])
121
 
122
+ # Text Encoding
 
 
 
 
 
 
 
 
123
  progress(0.1, desc="Encoding prompt...")
124
+ prompt_embeds = remote_text_encoder(prompt)
 
125
 
126
+ # Image Generation
127
  progress(0.3, desc="Waiting for GPU...")
128
  image = generate_image(
129
  prompt_embeds,
 
146
  ]
147
 
148
  examples_images = [
 
149
  ["The person from image 1 is petting the cat from image 2, the bird from image 3 is next to them", ["woman1.webp", "cat_window.webp", "bird.webp"]]
150
  ]
151
 
 
188
  )
189
 
190
  with gr.Accordion("Advanced Settings", open=False):
 
 
 
 
 
 
191
  seed = gr.Slider(
192
  label="Seed",
193
  minimum=0,
 
199
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
200
 
201
  with gr.Row():
 
202
  width = gr.Slider(
203
  label="Width",
204
  minimum=256,
 
216
  )
217
 
218
  with gr.Row():
 
219
  num_inference_steps = gr.Slider(
220
  label="Number of inference steps",
221
  minimum=1,
 
232
  value=4,
233
  )
234
 
 
235
  with gr.Column():
236
  result = gr.Image(label="Result", show_label=False)
 
237
 
238
  gr.Examples(
239
  examples=examples,
 
253
  cache_mode="lazy"
254
  )
255
 
 
256
  input_images.upload(
257
  fn=update_dimensions_from_image,
258
  inputs=[input_images],
 
262
  gr.on(
263
  triggers=[run_button.click, prompt.submit],
264
  fn=infer,
265
+ inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale],
266
  outputs=[result, seed]
267
  )
268