prithivMLmods commited on
Commit
76c1a5e
Β·
verified Β·
1 Parent(s): fbc137a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -218
app.py CHANGED
@@ -1,270 +1,192 @@
1
  import os
2
- import gc
3
  import gradio as gr
4
  import numpy as np
5
- import random
6
  import spaces
7
  import torch
8
- from diffusers import Flux2KleinPipeline
9
  from PIL import Image
 
10
 
11
- # --- Setup ---
12
- dtype = torch.bfloat16
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
- MAX_IMAGE_SIZE = 1024
17
 
18
  # --- Model Loading ---
19
- # Only load the 9B distilled model as requested.
20
- REPO_ID_DISTILLED = "black-forest-labs/FLUX.2-klein-9B"
21
-
22
- print(f"Loading FLUX.2 Klein Distilled Model: {REPO_ID_DISTILLED}...")
23
- pipe = Flux2KleinPipeline.from_pretrained(REPO_ID_DISTILLED, torch_dtype=dtype)
24
- pipe.to(device)
25
  print("Model loaded successfully.")
26
 
27
- # --- LoRA Adapter Configuration ---
28
- ADAPTER_SPECS = {
29
- "Arcane": {
30
- "repo": "DeverStyle/Flux.2-Klein-Loras",
31
- "weights": "dever_arcane_flux2_klein_9b.safetensors",
32
- "adapter_name": "Arcane"
33
- },
34
- "Zoom": {
35
- "repo": "fal/flux-2-klein-4B-zoom-lora",
36
- "weights": "flux-red-zoom-lora.safetensors",
37
- "adapter_name": "zoom"
38
- },
39
- "Background-Remove": {
40
- "repo": "fal/flux-2-klein-4B-background-remove-lora",
41
- "weights": "flux-background-remove-lora.safetensors",
42
- "adapter_name": "rmbg"
43
- },
44
- "Object-Remove": {
45
- "repo": "fal/flux-2-klein-4B-object-remove-lora",
46
- "weights": "flux-object-remove-lora.safetensors",
47
- "adapter_name": "object-remove"
48
- },
49
- "Sprite-Sheet": {
50
- "repo": "fal/flux-2-klein-4b-spritesheet-lora",
51
- "weights": "flux-spritesheet-lora.safetensors",
52
- "adapter_name": "spritesheet"
53
- },
54
  }
55
 
56
- # Keep track of which adapters have been downloaded and loaded
57
- LOADED_ADAPTERS = set()
58
-
59
- # --- Helper Functions ---
60
- def update_dimensions_from_image(image_list):
61
- """
62
- Update width/height sliders based on the first uploaded image's aspect ratio.
63
- """
64
- if not image_list:
65
- return 1024, 1024
66
-
67
- # Gallery returns a list of tuples: [(<PIL.Image.Image>,), ...]
68
- img = image_list[0][0]
69
- img_width, img_height = img.size
70
-
71
- aspect_ratio = img_width / img_height
72
-
73
- if aspect_ratio >= 1:
74
- new_width = MAX_IMAGE_SIZE
75
- new_height = int(MAX_IMAGE_SIZE / aspect_ratio)
76
- else:
77
- new_height = MAX_IMAGE_SIZE
78
- new_width = int(MAX_IMAGE_SIZE * aspect_ratio)
79
-
80
- # Ensure dimensions are multiples of 8
81
- new_width = (new_width // 8) * 8
82
- new_height = (new_height // 8) * 8
83
-
84
- return new_width, new_height
85
-
86
- # --- Core Inference Function ---
87
- @spaces.GPU(duration=90)
88
- def infer(
89
- prompt: str,
90
- lora_adapter: str,
91
- input_images=None,
92
- seed: int = 42,
93
- randomize_seed: bool = True,
94
- width: int = 1024,
95
- height: int = 1024,
96
- num_inference_steps: int = 4,
97
- guidance_scale: float = 1.0,
98
- progress=gr.Progress(track_tqdm=True)
99
- ):
100
- """
101
- Main function to generate or edit images using the FLUX.2 model and selected LoRA adapters.
102
- """
103
- gc.collect()
104
- torch.cuda.empty_cache()
105
-
106
- # --- LoRA Handling ---
107
- if lora_adapter != "None":
108
- spec = ADAPTER_SPECS.get(lora_adapter)
109
- if not spec:
110
- raise gr.Error(f"Configuration not found for adapter: {lora_adapter}")
111
-
112
- adapter_name = spec["adapter_name"]
113
 
114
- # Download and load the adapter if it's the first time being used
115
- if adapter_name not in LOADED_ADAPTERS:
116
- print(f"--- Downloading and Loading Adapter: {lora_adapter} ---")
117
- progress(0.1, desc=f"Loading LoRA: {lora_adapter}")
118
- try:
119
- pipe.load_lora_weights(
120
- spec["repo"],
121
- weight_name=spec["weights"],
122
- adapter_name=adapter_name
123
- )
124
- LOADED_ADAPTERS.add(adapter_name)
125
- print(f"--- Adapter {lora_adapter} loaded successfully. ---")
126
- except Exception as e:
127
- raise gr.Error(f"Failed to load adapter {lora_adapter}: {e}")
128
-
129
- # Set the active adapter
130
  pipe.set_adapters([adapter_name], adapter_weights=[1.0])
131
-
132
  else:
133
- # If "None" is selected, ensure no LoRAs are active
 
134
  pipe.disable_lora()
135
-
136
- # --- Seed ---
137
  if randomize_seed:
138
  seed = random.randint(0, MAX_SEED)
139
- generator = torch.Generator(device=device).manual_seed(seed)
140
-
141
- # --- Image Processing ---
142
- # Prepare image list from Gradio Gallery input
143
- image_list = None
144
- if input_images:
145
- image_list = [item[0] for item in input_images] # Extract PIL images from tuples
146
-
147
- # --- Generation ---
148
- progress(0.5, desc=f"Generating with seed {seed}...")
149
 
150
- pipe_kwargs = {
151
- "prompt": prompt,
152
- "height": height,
153
- "width": width,
154
- "num_inference_steps": num_inference_steps,
155
- "guidance_scale": guidance_scale,
156
- "generator": generator,
157
- }
158
 
159
- # Add images to the pipeline arguments only if they are provided
160
- if image_list:
161
- pipe_kwargs["image"] = image_list
 
 
 
 
 
 
162
 
163
- try:
164
- image = pipe(**pipe_kwargs).images[0]
165
- except Exception as e:
166
- # Unload the active LoRA on failure to allow retries or switching
167
- pipe.disable_lora()
168
- raise gr.Error(f"Inference failed: {e}")
169
-
170
- # --- Cleanup ---
171
- gc.collect()
172
- torch.cuda.empty_cache()
173
-
174
  return image, seed
175
 
176
  # --- UI Layout ---
177
- css = """
178
- #col-container { margin: 0 auto; max-width: 1200px; }
179
- .gallery-container img { object-fit: contain; }
180
  """
181
 
182
  with gr.Blocks() as demo:
183
  with gr.Column(elem_id="col-container"):
184
- gr.Markdown("# **FLUX.2-klein-LoRA-Studio**")
185
  gr.Markdown(
186
- "Generate and edit images using the distilled **FLUX.2-klein-9B** model. "
187
- "Select a specialized LoRA adapter from the dropdown for advanced editing tasks."
188
  )
189
 
190
- with gr.Row():
191
- with gr.Column(scale=2):
192
- with gr.Row():
193
- prompt = gr.Text(
194
- label="Prompt",
195
- show_label=False,
196
- max_lines=2,
197
- placeholder="Enter your prompt here...",
198
- container=False,
199
- scale=3
200
- )
201
- run_button = gr.Button("Run", scale=1, variant="primary")
202
-
203
- with gr.Accordion("Input Image(s) (for editing LoRAs)", open=False):
204
- input_images = gr.Gallery(
205
- label="Input Image(s)",
206
- type="pil",
207
- columns=3,
208
- rows=1,
209
- )
210
 
211
  lora_adapter = gr.Dropdown(
212
- label="Select LoRA Adapter (or None for base generation)",
213
- choices=["None"] + list(ADAPTER_SPECS.keys()),
214
- value="None"
 
215
  )
216
-
 
 
217
  with gr.Accordion("Advanced Settings", open=False):
218
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
219
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
220
 
221
- with gr.Row():
222
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=8, value=1024)
223
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=8, value=1024)
224
-
225
- with gr.Row():
226
- num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=4)
227
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=1.0)
228
-
229
- with gr.Column(scale=1):
230
- result = gr.Image(label="Result", show_label=False, height=512)
231
  used_seed = gr.Textbox(label="Used Seed", interactive=False)
232
-
 
233
  gr.Examples(
234
  examples=[
235
- ["A photorealistic, soaking wet capybara taking shelter under a large banana leaf in a rainy jungle, close up photo", "None"],
236
- ["A kawaii die-cut sticker of a chubby orange cat, big sparkly eyes, happy smile, paws raised, heart-shaped pink nose, smooth rounded lines, black outlines, soft gradient shading, pink cheeks.", "None"],
237
- ["A beautiful, majestic white horse running on a beach, cinematic lighting", "Zoom"],
238
- ["A corgi wearing a tiny backpack, hiking in the mountains", "Background-Remove"],
239
- ["A cute, round, fluffy creature with big eyes", "Sprite-Sheet"],
240
  ],
241
- fn=infer,
242
- inputs=[prompt, lora_adapter],
243
- outputs=[result, used_seed],
244
  cache_examples=False,
245
- label="Examples"
246
  )
247
-
248
- # --- Event Listeners ---
249
- # Trigger inference on button click or prompt submission
250
  run_button.click(
251
  fn=infer,
252
- inputs=[prompt, lora_adapter, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale],
253
- outputs=[result, used_seed]
254
- )
255
- prompt.submit(
256
- fn=infer,
257
- inputs=[prompt, lora_adapter, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale],
258
- outputs=[result, used_seed]
259
- )
260
-
261
- # Auto-update dimensions when an image is uploaded for editing
262
- input_images.upload(
263
- fn=update_dimensions_from_image,
264
- inputs=[input_images],
265
- outputs=[width, height]
266
  )
267
 
268
- # --- Launch the App ---
269
  if __name__ == "__main__":
270
- demo.queue(max_size=20).launch(css=css)
 
1
  import os
 
2
  import gradio as gr
3
  import numpy as np
 
4
  import spaces
5
  import torch
6
+ import random
7
  from PIL import Image
8
+ from typing import Iterable
9
 
10
+ # Pipeline for FLUX.2 Klein
11
+ from diffusers import Flux2KleinPipeline
12
+ from diffusers.utils import load_image
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # --- Hardware and Theme Setup ---
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ from gradio.themes import Soft
19
+ from gradio.themes.utils import colors, fonts, sizes
20
+
21
+ colors.orange_red = colors.Color(
22
+ name="orange_red", c50="#FFF0E5", c100="#FFE0CC", c200="#FFC299", c300="#FFA366",
23
+ c400="#FF8533", c500="#FF4500", c600="#E63E00", c700="#CC3700", c800="#B33000",
24
+ c900="#992900", c950="#802200",
25
+ )
26
+
27
+ class OrangeRedTheme(Soft):
28
+ def __init__(
29
+ self, *, primary_hue: colors.Color | str = colors.gray,
30
+ secondary_hue: colors.Color | str = colors.orange_red,
31
+ neutral_hue: colors.Color | str = colors.slate, text_size: sizes.Size | str = sizes.text_lg,
32
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
33
+ fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
34
+ ),
35
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
36
+ fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
37
+ ),
38
+ ):
39
+ super().__init__(
40
+ primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue,
41
+ text_size=text_size, font=font, font_mono=font_mono,
42
+ )
43
+ super().set(
44
+ background_fill_primary="*primary_50",
45
+ background_fill_primary_dark="*primary_900",
46
+ body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
47
+ body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
48
+ button_primary_text_color="white",
49
+ button_primary_text_color_hover="white",
50
+ button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
51
+ button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
52
+ button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
53
+ button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
54
+ slider_color="*secondary_500",
55
+ slider_color_dark="*secondary_600",
56
+ block_title_text_weight="600", block_border_width="3px",
57
+ block_shadow="*shadow_drop_lg", button_primary_shadow="*shadow_drop_lg",
58
+ button_large_padding="11px", color_accent_soft="*primary_100",
59
+ block_label_background_fill="*primary_200",
60
+ )
61
 
62
+ orange_red_theme = OrangeRedTheme()
63
  MAX_SEED = np.iinfo(np.int32).max
 
64
 
65
  # --- Model Loading ---
66
+ print("Loading FLUX.2 Klein 9B model...")
67
+ pipe = Flux2KleinPipeline.from_pretrained(
68
+ "black-forest-labs/FLUX.2-klein-9B",
69
+ torch_dtype=torch.bfloat16
70
+ ).to(device)
 
71
  print("Model loaded successfully.")
72
 
73
+ # --- LoRA Loading (Updated) ---
74
+ print("Loading new LoRA adapters...")
75
+ pipe.load_lora_weights(
76
+ "starsfriday/FLUX.2-klein-AC-Style-LORA",
77
+ weight_name="flux2_klein_lowres.safetensors",
78
+ adapter_name="american_comic_style"
79
+ )
80
+ pipe.load_lora_weights(
81
+ "linoyts/Flux2-Klein-Delight-LoRA",
82
+ weight_name="pytorch_lora_weights_v2.safetensors",
83
+ adapter_name="klein-delight"
84
+ )
85
+ print("All LoRA adapters loaded.")
86
+
87
+ # Updated map for the new adapters
88
+ ADAPTER_MAP = {
89
+ "American Comic Style": "american_comic_style",
90
+ "Klein Delight Style": "klein-delight",
 
 
 
 
 
 
 
 
 
91
  }
92
 
93
+ @spaces.GPU
94
+ def infer(input_image, prompt, lora_adapter, seed=42, randomize_seed=True, guidance_scale=4.0, steps=4, progress=gr.Progress(track_tqdm=True)):
95
+ # Input image is required for image-to-image tasks
96
+ if not input_image:
97
+ raise gr.Error("Please upload an image to apply a style to.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ # Dynamically set the adapter based on the dropdown choice
100
+ adapter_name = ADAPTER_MAP.get(lora_adapter)
101
+ if adapter_name:
102
+ print(f"Activating LoRA: {lora_adapter} ({adapter_name})")
 
 
 
 
 
 
 
 
 
 
 
 
103
  pipe.set_adapters([adapter_name], adapter_weights=[1.0])
 
104
  else:
105
+ # If "None" is selected (or an invalid choice), disable LoRAs
106
+ print("No LoRA selected. Disabling adapters.")
107
  pipe.disable_lora()
108
+
 
109
  if randomize_seed:
110
  seed = random.randint(0, MAX_SEED)
 
 
 
 
 
 
 
 
 
 
111
 
112
+ original_image = input_image.copy().convert("RGB")
 
 
 
 
 
 
 
113
 
114
+ image = pipe(
115
+ image=original_image,
116
+ prompt=prompt,
117
+ guidance_scale=guidance_scale,
118
+ width=original_image.size[0],
119
+ height=original_image.size[1],
120
+ num_inference_steps=steps,
121
+ generator=torch.Generator(device=device).manual_seed(seed),
122
+ ).images[0]
123
 
124
+ return image, seed
125
+
126
+ @spaces.GPU
127
+ def infer_example(input_image, prompt, lora_adapter):
128
+ # Use a fixed seed for reproducible examples
129
+ image, seed = infer(input_image, prompt, lora_adapter, seed=12345, randomize_seed=False)
 
 
 
 
 
130
  return image, seed
131
 
132
  # --- UI Layout ---
133
+ css="""
134
+ #col-container { margin: 0 auto; max-width: 960px; }
135
+ #main-title h1 { font-size: 2.2em !important; }
136
  """
137
 
138
  with gr.Blocks() as demo:
139
  with gr.Column(elem_id="col-container"):
140
+ gr.Markdown("# **FLUX.2 Klein LoRA Stylizer**", elem_id="main-title")
141
  gr.Markdown(
142
+ "Apply creative styles to your images using **FLUX.2-klein-9B** and specialized LoRA adapters. "
143
+ "Upload an image, select a style, and write a prompt to guide the transformation."
144
  )
145
 
146
+ with gr.Row(equal_height=True):
147
+ with gr.Column():
148
+ input_image = gr.Image(label="Upload Image", type="pil", height=290, sources=["upload", "webcam", "clipboard"])
149
+ prompt = gr.Text(label="Guiding Prompt", show_label=True, placeholder="e.g., a man with a red superhero mask")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  lora_adapter = gr.Dropdown(
152
+ label="Choose a Creative Style",
153
+ # Updated choices for the new adapters
154
+ choices=["American Comic Style", "Klein Delight Style"],
155
+ value="American Comic Style"
156
  )
157
+
158
+ run_button = gr.Button("Apply Style", variant="primary")
159
+
160
  with gr.Accordion("Advanced Settings", open=False):
161
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
162
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
163
+ # Updated defaults suitable for FLUX.2 Klein
164
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=4.0)
165
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=4, step=1)
166
 
167
+ with gr.Column():
168
+ output_image = gr.Image(label="Stylized Image", interactive=False, format="png", height=450)
 
 
 
 
 
 
 
 
169
  used_seed = gr.Textbox(label="Used Seed", interactive=False)
170
+
171
+ # Updated examples for the new LoRAs
172
  gr.Examples(
173
  examples=[
174
+ ["examples/portrait_man.jpg", "a man with a rugged beard, pop art style, bold lines, heavy shading", "American Comic Style"],
175
+ ["examples/cityscape.jpg", "a futuristic city, vibrant colors, clean lines, delightful style", "Klein Delight Style"],
176
+ ["examples/portrait_woman.jpg", "a woman with glasses, comic book art, detailed ink work, speech bubble", "American Comic Style"],
177
+ ["examples/animal.jpg", "a cute red panda, charming and delightful illustration, soft lighting", "Klein Delight Style"],
 
178
  ],
179
+ inputs=[input_image, prompt, lora_adapter],
180
+ outputs=[output_image, used_seed],
181
+ fn=infer_example,
182
  cache_examples=False,
 
183
  )
184
+
 
 
185
  run_button.click(
186
  fn=infer,
187
+ inputs=[input_image, prompt, lora_adapter, seed, randomize_seed, guidance_scale, steps],
188
+ outputs=[output_image, used_seed]
 
 
 
 
 
 
 
 
 
 
 
 
189
  )
190
 
 
191
  if __name__ == "__main__":
192
+ demo.queue().launch(css=css, theme=orange_red_theme, show_error=True)