multimodalart HF Staff committed on
Commit
ca4c7b6
·
verified ·
1 Parent(s): 44e242d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -6
app.py CHANGED
@@ -78,6 +78,15 @@ DEFAULT_CFG = {
78
  }
79
 
80
  def image_to_data_uri(img):
 
 
 
 
 
 
 
 
 
81
  buffered = io.BytesIO()
82
  img.save(buffered, format="PNG")
83
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -85,6 +94,16 @@ def image_to_data_uri(img):
85
 
86
 
87
  def upsample_prompt_logic(prompt, image_list):
 
 
 
 
 
 
 
 
 
 
88
  try:
89
  if image_list and len(image_list) > 0:
90
  # Image + Text Editing Mode
@@ -125,8 +144,18 @@ def upsample_prompt_logic(prompt, image_list):
125
 
126
 
127
  def update_dimensions_from_image(image_list):
128
- """Update width/height sliders based on uploaded image aspect ratio.
129
- Keeps one side at 1024 and scales the other proportionally, with both sides as multiples of 8."""
 
 
 
 
 
 
 
 
 
 
130
  if image_list is None or len(image_list) == 0:
131
  return 1024, 1024 # Default dimensions
132
 
@@ -155,12 +184,69 @@ def update_dimensions_from_image(image_list):
155
 
156
 
157
  def update_steps_from_mode(mode_choice):
158
- """Update the number of inference steps based on the selected mode."""
 
 
 
 
 
 
 
 
159
  return DEFAULT_STEPS[mode_choice], DEFAULT_CFG[mode_choice]
160
 
161
 
162
  @spaces.GPU(duration=85)
163
- def infer(prompt, input_images=None, mode_choice="Distilled (4 steps)", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, guidance_scale=4.0, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  if randomize_seed:
166
  seed = random.randint(0, MAX_SEED)
@@ -356,7 +442,9 @@ FLUX.2 [Klein] is a distilled model capable of generating, editing and combining
356
  triggers=[run_button.click, prompt.submit],
357
  fn=infer,
358
  inputs=[prompt, input_images, mode_choice, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
359
- outputs=[result, seed]
 
360
  )
361
 
362
- demo.launch()
 
 
78
  }
79
 
80
  def image_to_data_uri(img):
81
+ """
82
+ Convert a PIL Image to a base64 data URI.
83
+
84
+ Args:
85
+ img: The PIL Image to convert.
86
+
87
+ Returns:
88
+ str: A data URI string containing the base64-encoded PNG image.
89
+ """
90
  buffered = io.BytesIO()
91
  img.save(buffered, format="PNG")
92
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
94
 
95
 
96
  def upsample_prompt_logic(prompt, image_list):
97
+ """
98
+ Enhance a text prompt using a Vision-Language Model.
99
+
100
+ Args:
101
+ prompt (str): The original text prompt to enhance.
102
+ image_list: Optional list of PIL Images for context-aware enhancement.
103
+
104
+ Returns:
105
+ str: The enhanced prompt, or the original prompt if enhancement fails.
106
+ """
107
  try:
108
  if image_list and len(image_list) > 0:
109
  # Image + Text Editing Mode
 
144
 
145
 
146
  def update_dimensions_from_image(image_list):
147
+ """
148
+ Update width/height based on uploaded image aspect ratio.
149
+
150
+ Keeps one side at 1024 and scales the other proportionally,
151
+ with both sides as multiples of 8.
152
+
153
+ Args:
154
+ image_list: Gallery list of tuples (image, caption) from Gradio.
155
+
156
+ Returns:
157
+ tuple: A tuple of (width, height) integers, both multiples of 8.
158
+ """
159
  if image_list is None or len(image_list) == 0:
160
  return 1024, 1024 # Default dimensions
161
 
 
184
 
185
 
186
def update_steps_from_mode(mode_choice):
    """
    Look up the default sampling settings for the selected model mode.

    Args:
        mode_choice (str): Selected mode label — e.g. "Distilled (4 steps)"
            or "Base (50 steps)". Must be a key of both DEFAULT_STEPS and
            DEFAULT_CFG.

    Returns:
        tuple: (num_inference_steps, guidance_scale) defaults for that mode.
    """
    steps = DEFAULT_STEPS[mode_choice]
    cfg = DEFAULT_CFG[mode_choice]
    return steps, cfg
197
 
198
 
199
  @spaces.GPU(duration=85)
200
+ def infer(
201
+ prompt: str,
202
+ input_images=None,
203
+ mode_choice: str = "Distilled (4 steps)",
204
+ seed: int = 42,
205
+ randomize_seed: bool = False,
206
+ width: int = 1024,
207
+ height: int = 1024,
208
+ num_inference_steps: int = 4,
209
+ guidance_scale: float = 4.0,
210
+ prompt_upsampling: bool = False,
211
+ progress=gr.Progress(track_tqdm=True)
212
+ ):
213
+ """
214
+ Generate or edit images using FLUX.2 Klein 9B model.
215
+
216
+ This tool can generate images from text prompts, or edit/combine existing images
217
+ based on text instructions. Use the distilled mode for fast 4-step generation,
218
+ or base mode for higher quality 50-step generation.
219
+
220
+ Args:
221
+ prompt (str): Text description of the image to generate, or editing instructions when input images are provided.
222
+ input_images: Optional list of input images for editing or combining. Provide image URLs.
223
+ mode_choice (str): Model mode - "Distilled (4 steps)" for fast generation or "Base (50 steps)" for higher quality.
224
+ seed (str): Random seed for reproducible generation. Use "0" with randomize_seed=True for random results.
225
+ randomize_seed (str): Set to "true" to use a random seed, "false" to use the specified seed.
226
+ width (str): Output image width in pixels (256-1024, must be multiple of 8).
227
+ height (str): Output image height in pixels (256-1024, must be multiple of 8).
228
+ num_inference_steps (str): Number of denoising steps. Use "4" for distilled mode, "50" for base mode.
229
+ guidance_scale (str): How closely to follow the prompt. Use "1.0" for distilled, "4.0" for base mode.
230
+ prompt_upsampling (str): Set to "true" to automatically enhance the prompt using a VLM.
231
+
232
+ Returns:
233
+ tuple: A tuple containing the generated PIL Image and the seed used.
234
+ """
235
+ # Convert string inputs to proper types for MCP compatibility
236
+ if isinstance(seed, str):
237
+ seed = int(seed)
238
+ if isinstance(randomize_seed, str):
239
+ randomize_seed = randomize_seed.lower() == "true"
240
+ if isinstance(width, str):
241
+ width = int(width)
242
+ if isinstance(height, str):
243
+ height = int(height)
244
+ if isinstance(num_inference_steps, str):
245
+ num_inference_steps = int(num_inference_steps)
246
+ if isinstance(guidance_scale, str):
247
+ guidance_scale = float(guidance_scale)
248
+ if isinstance(prompt_upsampling, str):
249
+ prompt_upsampling = prompt_upsampling.lower() == "true"
250
 
251
  if randomize_seed:
252
  seed = random.randint(0, MAX_SEED)
 
442
  triggers=[run_button.click, prompt.submit],
443
  fn=infer,
444
  inputs=[prompt, input_images, mode_choice, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
445
+ outputs=[result, seed],
446
+ api_name="generate" # Explicit API name for MCP tool
447
  )
448
 
449
+ # Launch with MCP server enabled
450
+ demo.launch(mcp_server=True)