Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -181,40 +181,32 @@ def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps,
|
|
| 181 |
torch.cuda.empty_cache()
|
| 182 |
return None
|
| 183 |
|
| 184 |
-
def
|
| 185 |
-
"""
|
| 186 |
-
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
| 188 |
styles = list(style_type_lora_dict.keys())
|
| 189 |
|
| 190 |
-
for
|
| 191 |
-
if i >= 24: # Limit to 24 thumbnails for 6x4 grid
|
| 192 |
-
break
|
| 193 |
-
|
| 194 |
thumbnail_file = thumbnail_mapping.get(style, "")
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
'''
|
| 209 |
-
|
| 210 |
-
# Fill empty slots if needed
|
| 211 |
-
remaining_slots = 24 - len(styles)
|
| 212 |
-
if remaining_slots > 0 and len(styles) < 24:
|
| 213 |
-
for _ in range(remaining_slots):
|
| 214 |
-
html += '<div></div>'
|
| 215 |
|
| 216 |
-
|
| 217 |
-
return html
|
| 218 |
|
| 219 |
# Create Gradio interface
|
| 220 |
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
|
|
@@ -228,8 +220,20 @@ with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as
|
|
| 228 |
|
| 229 |
# Thumbnail Grid Section
|
| 230 |
gr.Markdown("### 🖼️ Click a style thumbnail to select it:")
|
|
|
|
| 231 |
with gr.Row():
|
| 232 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
gr.Markdown("---")
|
| 235 |
|
|
@@ -307,6 +311,22 @@ with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as
|
|
| 307 |
- Use additional instructions for fine control
|
| 308 |
""")
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
# Update style description when style changes
|
| 311 |
def update_description(style):
|
| 312 |
return style_descriptions.get(style, "")
|
|
|
|
| 181 |
torch.cuda.empty_cache()
|
| 182 |
return None
|
| 183 |
|
| 184 |
+
def select_style(style_name):
|
| 185 |
+
"""Handler for thumbnail clicks"""
|
| 186 |
+
return style_name, style_descriptions.get(style_name, "")
|
| 187 |
+
|
| 188 |
+
def create_thumbnail_grid():
|
| 189 |
+
"""Create a gallery of style thumbnails"""
|
| 190 |
+
thumbnails = []
|
| 191 |
styles = list(style_type_lora_dict.keys())
|
| 192 |
|
| 193 |
+
for style in styles:
|
|
|
|
|
|
|
|
|
|
| 194 |
thumbnail_file = thumbnail_mapping.get(style, "")
|
| 195 |
+
if thumbnail_file and os.path.exists(thumbnail_file):
|
| 196 |
+
try:
|
| 197 |
+
img = Image.open(thumbnail_file)
|
| 198 |
+
thumbnails.append((img, style.replace('_', ' ')))
|
| 199 |
+
except Exception as e:
|
| 200 |
+
print(f"Error loading thumbnail {thumbnail_file}: {e}")
|
| 201 |
+
# Create placeholder if thumbnail fails to load
|
| 202 |
+
placeholder = Image.new('RGB', (256, 256), color='lightgray')
|
| 203 |
+
thumbnails.append((placeholder, style.replace('_', ' ')))
|
| 204 |
+
else:
|
| 205 |
+
# Create placeholder for missing thumbnails
|
| 206 |
+
placeholder = Image.new('RGB', (256, 256), color='lightgray')
|
| 207 |
+
thumbnails.append((placeholder, style.replace('_', ' ')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
+
return thumbnails
|
|
|
|
| 210 |
|
| 211 |
# Create Gradio interface
|
| 212 |
with gr.Blocks(title="FLUX.1 Kontext Style Transfer", theme=gr.themes.Soft()) as demo:
|
|
|
|
| 220 |
|
| 221 |
# Thumbnail Grid Section
|
| 222 |
gr.Markdown("### 🖼️ Click a style thumbnail to select it:")
|
| 223 |
+
|
| 224 |
with gr.Row():
|
| 225 |
+
style_gallery = gr.Gallery(
|
| 226 |
+
value=create_thumbnail_grid(),
|
| 227 |
+
label="Style Thumbnails",
|
| 228 |
+
show_label=False,
|
| 229 |
+
elem_id="style_gallery",
|
| 230 |
+
columns=6,
|
| 231 |
+
rows=4,
|
| 232 |
+
object_fit="cover",
|
| 233 |
+
height="auto",
|
| 234 |
+
interactive=True,
|
| 235 |
+
show_download_button=False
|
| 236 |
+
)
|
| 237 |
|
| 238 |
gr.Markdown("---")
|
| 239 |
|
|
|
|
| 311 |
- Use additional instructions for fine control
|
| 312 |
""")
|
| 313 |
|
| 314 |
+
# Handle gallery selection
|
| 315 |
+
def on_gallery_select(evt: gr.SelectData):
|
| 316 |
+
"""Handle thumbnail selection from gallery"""
|
| 317 |
+
selected_index = evt.index
|
| 318 |
+
styles = list(style_type_lora_dict.keys())
|
| 319 |
+
if 0 <= selected_index < len(styles):
|
| 320 |
+
selected_style = styles[selected_index]
|
| 321 |
+
return selected_style, style_descriptions.get(selected_style, "")
|
| 322 |
+
return None, None
|
| 323 |
+
|
| 324 |
+
style_gallery.select(
|
| 325 |
+
fn=on_gallery_select,
|
| 326 |
+
inputs=None,
|
| 327 |
+
outputs=[style_dropdown, style_info]
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
# Update style description when style changes
|
| 331 |
def update_description(style):
|
| 332 |
return style_descriptions.get(style, "")
|