Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -56,13 +56,24 @@ Rules:
|
|
| 56 |
|
| 57 |
Output only the final instruction in plain text and nothing else."""
|
| 58 |
|
| 59 |
-
#
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
def image_to_data_uri(img):
|
|
@@ -143,11 +154,14 @@ def update_dimensions_from_image(image_list):
|
|
| 143 |
|
| 144 |
|
| 145 |
@spaces.GPU(duration=85)
|
| 146 |
-
def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=4.0, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
|
| 147 |
|
| 148 |
if randomize_seed:
|
| 149 |
seed = random.randint(0, MAX_SEED)
|
| 150 |
|
|
|
|
|
|
|
|
|
|
| 151 |
# Prepare image list (convert None or empty gallery to None)
|
| 152 |
image_list = None
|
| 153 |
if input_images is not None and len(input_images) > 0:
|
|
@@ -164,7 +178,7 @@ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024,
|
|
| 164 |
print(f"Upsampled Prompt: {final_prompt}")
|
| 165 |
|
| 166 |
# 2. Image Generation
|
| 167 |
-
progress(0.2, desc="Generating image...")
|
| 168 |
|
| 169 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 170 |
|
|
@@ -236,6 +250,13 @@ FLUX.2 Klein [dev] is a distilled model capable of generating, editing and combi
|
|
| 236 |
)
|
| 237 |
|
| 238 |
with gr.Accordion("Advanced Settings", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
prompt_upsampling = gr.Checkbox(
|
| 240 |
label="Prompt Upsampling",
|
| 241 |
value=True,
|
|
@@ -321,7 +342,7 @@ FLUX.2 Klein [dev] is a distilled model capable of generating, editing and combi
|
|
| 321 |
gr.on(
|
| 322 |
triggers=[run_button.click, prompt.submit],
|
| 323 |
fn=infer,
|
| 324 |
-
inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
|
| 325 |
outputs=[result, seed]
|
| 326 |
)
|
| 327 |
|
|
|
|
| 56 |
|
| 57 |
Output only the final instruction in plain text and nothing else."""
|
| 58 |
|
| 59 |
+
# Model repository IDs
|
| 60 |
+
REPO_ID_4B = "diffusers-internal-dev/dummy-1015-4b" # 4b model
|
| 61 |
+
REPO_ID_9B = "diffusers-internal-dev/dummy-1015-9b" # 9b model
|
| 62 |
+
|
| 63 |
+
# Load both models
|
| 64 |
+
print("Loading 4B model...")
|
| 65 |
+
pipe_4b = Flux2KleinPipeline.from_pretrained(REPO_ID_4B, torch_dtype=dtype)
|
| 66 |
+
pipe_4b.to("cuda")
|
| 67 |
+
|
| 68 |
+
print("Loading 9B model...")
|
| 69 |
+
pipe_9b = Flux2KleinPipeline.from_pretrained(REPO_ID_9B, torch_dtype=dtype)
|
| 70 |
+
pipe_9b.to("cuda")
|
| 71 |
+
|
| 72 |
+
# Dictionary for easy access
|
| 73 |
+
pipes = {
|
| 74 |
+
"4B": pipe_4b,
|
| 75 |
+
"9B": pipe_9b,
|
| 76 |
+
}
|
| 77 |
|
| 78 |
|
| 79 |
def image_to_data_uri(img):
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
@spaces.GPU(duration=85)
|
| 157 |
+
def infer(prompt, input_images=None, model_choice="4B", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=4.0, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
|
| 158 |
|
| 159 |
if randomize_seed:
|
| 160 |
seed = random.randint(0, MAX_SEED)
|
| 161 |
|
| 162 |
+
# Select the appropriate pipeline based on model choice
|
| 163 |
+
pipe = pipes[model_choice]
|
| 164 |
+
|
| 165 |
# Prepare image list (convert None or empty gallery to None)
|
| 166 |
image_list = None
|
| 167 |
if input_images is not None and len(input_images) > 0:
|
|
|
|
| 178 |
print(f"Upsampled Prompt: {final_prompt}")
|
| 179 |
|
| 180 |
# 2. Image Generation
|
| 181 |
+
progress(0.2, desc=f"Generating image with {model_choice} model...")
|
| 182 |
|
| 183 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 184 |
|
|
|
|
| 250 |
)
|
| 251 |
|
| 252 |
with gr.Accordion("Advanced Settings", open=False):
|
| 253 |
+
model_choice = gr.Radio(
|
| 254 |
+
label="Model Size",
|
| 255 |
+
choices=["4B", "9B"],
|
| 256 |
+
value="4B",
|
| 257 |
+
info="Choose between the 4B (faster) or 9B (higher quality) model"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
prompt_upsampling = gr.Checkbox(
|
| 261 |
label="Prompt Upsampling",
|
| 262 |
value=True,
|
|
|
|
| 342 |
gr.on(
|
| 343 |
triggers=[run_button.click, prompt.submit],
|
| 344 |
fn=infer,
|
| 345 |
+
inputs=[prompt, input_images, model_choice, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, prompt_upsampling],
|
| 346 |
outputs=[result, seed]
|
| 347 |
)
|
| 348 |
|