multimodalart HF Staff committed on
Commit
cfae2bb
·
verified ·
1 Parent(s): c31b7c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -10
app.py CHANGED
@@ -56,13 +56,24 @@ Rules:
56
 
57
  Output only the final instruction in plain text and nothing else."""
58
 
59
- # Load model - using Flux2KleinPipeline with built-in text encoder
60
- repo_id = "diffusers-internal-dev/dummy-1015-4b" # 4b model
61
- # repo_id = "diffusers-internal-dev/dummy-1015-9b" # 9b model (alternative)
62
-
63
- pipe = Flux2KleinPipeline.from_pretrained(repo_id, torch_dtype=dtype)
64
- pipe.to("cuda")
65
- #pipe.enable_model_cpu_offload()
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  def image_to_data_uri(img):
@@ -143,11 +154,14 @@ def update_dimensions_from_image(image_list):
143
 
144
 
145
  @spaces.GPU(duration=85)
146
- def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=4.0, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
147
 
148
  if randomize_seed:
149
  seed = random.randint(0, MAX_SEED)
150
 
 
 
 
151
  # Prepare image list (convert None or empty gallery to None)
152
  image_list = None
153
  if input_images is not None and len(input_images) > 0:
@@ -164,7 +178,7 @@ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024,
164
  print(f"Upsampled Prompt: {final_prompt}")
165
 
166
  # 2. Image Generation
167
- progress(0.2, desc="Generating image...")
168
 
169
  generator = torch.Generator(device=device).manual_seed(seed)
170
 
@@ -236,6 +250,13 @@ FLUX.2 Klein [dev] is a distilled model capable of generating, editing and combi
236
  )
237
 
238
  with gr.Accordion("Advanced Settings", open=False):
 
 
 
 
 
 
 
239
  prompt_upsampling = gr.Checkbox(
240
  label="Prompt Upsampling",
241
  value=True,
@@ -321,7 +342,7 @@ FLUX.2 Klein [dev] is a distilled model capable of generating, editing and combi
321
  gr.on(
322
  triggers=[run_button.click, prompt.submit],
323
  fn=infer,
324
- inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
325
  outputs=[result, seed]
326
  )
327
 
 
56
 
57
  Output only the final instruction in plain text and nothing else."""
58
 
59
+ # Model repository IDs
60
+ REPO_ID_4B = "diffusers-internal-dev/dummy-1015-4b" # 4b model
61
+ REPO_ID_9B = "diffusers-internal-dev/dummy-1015-9b" # 9b model
62
+
63
+ # Load both models
64
+ print("Loading 4B model...")
65
+ pipe_4b = Flux2KleinPipeline.from_pretrained(REPO_ID_4B, torch_dtype=dtype)
66
+ pipe_4b.to("cuda")
67
+
68
+ print("Loading 9B model...")
69
+ pipe_9b = Flux2KleinPipeline.from_pretrained(REPO_ID_9B, torch_dtype=dtype)
70
+ pipe_9b.to("cuda")
71
+
72
+ # Dictionary for easy access
73
+ pipes = {
74
+ "4B": pipe_4b,
75
+ "9B": pipe_9b,
76
+ }
77
 
78
 
79
  def image_to_data_uri(img):
 
154
 
155
 
156
  @spaces.GPU(duration=85)
157
+ def infer(prompt, input_images=None, model_choice="4B", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=4.0, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
158
 
159
  if randomize_seed:
160
  seed = random.randint(0, MAX_SEED)
161
 
162
+ # Select the appropriate pipeline based on model choice
163
+ pipe = pipes[model_choice]
164
+
165
  # Prepare image list (convert None or empty gallery to None)
166
  image_list = None
167
  if input_images is not None and len(input_images) > 0:
 
178
  print(f"Upsampled Prompt: {final_prompt}")
179
 
180
  # 2. Image Generation
181
+ progress(0.2, desc=f"Generating image with {model_choice} model...")
182
 
183
  generator = torch.Generator(device=device).manual_seed(seed)
184
 
 
250
  )
251
 
252
  with gr.Accordion("Advanced Settings", open=False):
253
+ model_choice = gr.Radio(
254
+ label="Model Size",
255
+ choices=["4B", "9B"],
256
+ value="4B",
257
+ info="Choose between the 4B (faster) or 9B (higher quality) model"
258
+ )
259
+
260
  prompt_upsampling = gr.Checkbox(
261
  label="Prompt Upsampling",
262
  value=True,
 
342
  gr.on(
343
  triggers=[run_button.click, prompt.submit],
344
  fn=infer,
345
+ inputs=[prompt, input_images, model_choice, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
346
  outputs=[result, seed]
347
  )
348