yeq6x committed on
Commit
5540810
·
1 Parent(s): 826fa01

gr.Galleryからgr.Imageに変更します。

Browse files
Files changed (1) hide show
  1. app.py +19 -26
app.py CHANGED
@@ -81,7 +81,7 @@ MAX_SEED = np.iinfo(np.int32).max
81
  # --- Main Inference Function (Combined LoRA) ---
82
  @spaces.GPU()
83
  def infer(
84
- images,
85
  seed=42,
86
  randomize_seed=False,
87
  true_guidance_scale=1.0,
@@ -96,7 +96,7 @@ def infer(
96
  Run single inference with combined LoRAs: Lightning + Stage1 + Stage2.
97
 
98
  Parameters:
99
- images (list): Input images from the Gradio gallery (PIL or path-based).
100
  seed (int): Random seed for reproducibility.
101
  randomize_seed (bool): If True, overrides seed with a random value.
102
  true_guidance_scale (float): CFG scale used by Qwen-Image.
@@ -108,7 +108,7 @@ def infer(
108
  progress: Gradio progress callback.
109
 
110
  Returns:
111
- tuple: (result_images, seed_used)
112
  """
113
 
114
  # Hardcode the negative prompt
@@ -120,19 +120,13 @@ def infer(
120
  # Set up the generator for reproducibility
121
  generator = torch.Generator(device=device).manual_seed(seed)
122
 
123
- # Load input images into PIL Images
124
- pil_images = []
125
- if images is not None:
126
- for item in images:
127
- try:
128
- if isinstance(item[0], Image.Image):
129
- pil_images.append(item[0].convert("RGB"))
130
- elif isinstance(item[0], str):
131
- pil_images.append(Image.open(item[0]).convert("RGB"))
132
- elif hasattr(item, "name"):
133
- pil_images.append(Image.open(item.name).convert("RGB"))
134
- except Exception:
135
- continue
136
 
137
  if height==256 and width==256:
138
  height, width = None, None
@@ -147,7 +141,7 @@ def infer(
147
  pipe.set_adapters(["lightning", "stage1", "stage2"], adapter_weights=[1.0, stage1_weight, stage2_weight])
148
 
149
  result_images = pipe(
150
- image=pil_images if len(pil_images) > 0 else None,
151
  prompt=STAGE1_PROMPT,
152
  height=height,
153
  width=width,
@@ -158,8 +152,8 @@ def infer(
158
  num_images_per_prompt=1,
159
  ).images
160
 
161
- # Return result images and seed
162
- return result_images, seed
163
 
164
  # --- Examples and UI Layout ---
165
  examples = []
@@ -185,15 +179,14 @@ with gr.Blocks(css=css) as demo:
185
  with gr.Row():
186
  with gr.Column(scale=1):
187
  gr.Markdown("### 📥 Input")
188
- input_images = gr.Gallery(label="Input Images",
189
- show_label=False,
190
- type="pil",
191
- interactive=True,
192
- object_fit="contain")
193
 
194
  with gr.Column(scale=1):
195
  gr.Markdown("### 📤 Result")
196
- result = gr.Gallery(label="Result", show_label=False, type="pil", interactive=False, object_fit="contain")
197
 
198
  run_button = gr.Button("🚀 Generate", variant="primary", size="lg")
199
 
@@ -263,7 +256,7 @@ with gr.Blocks(css=css) as demo:
263
  run_button.click(
264
  fn=infer,
265
  inputs=[
266
- input_images,
267
  seed,
268
  randomize_seed,
269
  true_guidance_scale,
 
81
  # --- Main Inference Function (Combined LoRA) ---
82
  @spaces.GPU()
83
  def infer(
84
+ image,
85
  seed=42,
86
  randomize_seed=False,
87
  true_guidance_scale=1.0,
 
96
  Run single inference with combined LoRAs: Lightning + Stage1 + Stage2.
97
 
98
  Parameters:
99
+ image: Input image (PIL Image or path string).
100
  seed (int): Random seed for reproducibility.
101
  randomize_seed (bool): If True, overrides seed with a random value.
102
  true_guidance_scale (float): CFG scale used by Qwen-Image.
 
108
  progress: Gradio progress callback.
109
 
110
  Returns:
111
+ tuple: (result_image, seed_used)
112
  """
113
 
114
  # Hardcode the negative prompt
 
120
  # Set up the generator for reproducibility
121
  generator = torch.Generator(device=device).manual_seed(seed)
122
 
123
+ # Load input image into PIL Image
124
+ pil_image = None
125
+ if image is not None:
126
+ if isinstance(image, Image.Image):
127
+ pil_image = image.convert("RGB")
128
+ elif isinstance(image, str):
129
+ pil_image = Image.open(image).convert("RGB")
 
 
 
 
 
 
130
 
131
  if height==256 and width==256:
132
  height, width = None, None
 
141
  pipe.set_adapters(["lightning", "stage1", "stage2"], adapter_weights=[1.0, stage1_weight, stage2_weight])
142
 
143
  result_images = pipe(
144
+ image=[pil_image] if pil_image is not None else None,
145
  prompt=STAGE1_PROMPT,
146
  height=height,
147
  width=width,
 
152
  num_images_per_prompt=1,
153
  ).images
154
 
155
+ # Return first result image and seed
156
+ return result_images[0] if result_images else None, seed
157
 
158
  # --- Examples and UI Layout ---
159
  examples = []
 
179
  with gr.Row():
180
  with gr.Column(scale=1):
181
  gr.Markdown("### 📥 Input")
182
+ input_image = gr.Image(label="Input Image",
183
+ show_label=False,
184
+ type="pil",
185
+ interactive=True)
 
186
 
187
  with gr.Column(scale=1):
188
  gr.Markdown("### 📤 Result")
189
+ result = gr.Image(label="Result", show_label=False, type="pil", interactive=False)
190
 
191
  run_button = gr.Button("🚀 Generate", variant="primary", size="lg")
192
 
 
256
  run_button.click(
257
  fn=infer,
258
  inputs=[
259
+ input_image,
260
  seed,
261
  randomize_seed,
262
  true_guidance_scale,