Ray committed on
Commit
26cfe11
·
1 Parent(s): d189df7

feat: add pipeline with LFS images

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.jpg filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -5,7 +5,6 @@ import torch
5
  import numpy as np
6
  from PIL import Image, ImageDraw
7
 
8
- # Hugging Face Spaces 特有的 GPU 裝飾器
9
  import spaces
10
  from huggingface_hub import hf_hub_download
11
 
@@ -68,21 +67,41 @@ class DepthProLoader:
68
  raise e
69
 
70
  # ==========================================
71
- # 3. Helper Functions
72
  # ==========================================
73
- def center_crop_512(img: Image.Image) -> Image.Image:
 
 
 
 
 
74
  w, h = img.size
75
  target = 512
76
- if min(w, h) < target:
77
- scale = target / min(w, h)
78
- new_w, new_h = int(w * scale), int(h * scale)
79
- img = img.resize((new_w, new_h), Image.LANCZOS)
80
- w, h = new_w, new_h
81
- left = (w - target) // 2
82
- top = (h - target) // 2
83
- right = left + target
84
- bottom = top + target
85
- return img.crop((left, top, right, bottom))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  def switch_lora_on_gpu(pipe, target_mode):
88
  print(f"๐Ÿ”„ Switching LoRA to [{target_mode}]...")
@@ -96,19 +115,36 @@ def switch_lora_on_gpu(pipe, target_mode):
96
  pipe.set_adapters(["bokeh"])
97
 
98
  def preprocess_input_image(raw_img, do_resize):
 
 
 
 
99
  if raw_img is None: return None, None
100
  print(f"๐Ÿ”„ Preprocessing Input... Resize={do_resize}")
 
 
101
  img_to_process = raw_img
 
102
  if do_resize:
 
 
 
103
  w, h = img_to_process.size
104
- scale = 512 / min(w, h)
105
- new_w, new_h = int(w * scale), int(h * scale)
106
- img_to_process = img_to_process.resize((new_w, new_h), Image.LANCZOS)
107
-
108
- final_input = center_crop_512(img_to_process)
109
- # 這裡只回傳兩個值 (不再回傳 latents state)
 
110
  return final_input, final_input
111
 
 
 
 
 
 
 
112
  def draw_red_dot_on_preview(clean_img, evt: gr.SelectData):
113
  if clean_img is None: return None, None
114
  img_copy = clean_img.copy()
@@ -120,17 +156,21 @@ def draw_red_dot_on_preview(clean_img, evt: gr.SelectData):
120
  draw.line((x, y-r, x, y+r), fill="red", width=2)
121
  return img_copy, evt.index
122
 
123
- # ==========================================
124
- # 4. Main Pipeline
125
- # ==========================================
126
  @spaces.GPU(duration=120)
127
- def run_genfocus_pipeline(clean_input_512, click_coords, K_value):
128
- # 移除 cached_latents 參數
129
  global pipe_flux, depth_model, depth_transform
130
 
131
  device = "cuda"
132
 
133
- # --- 1. Load Flux ---
 
 
 
 
 
 
 
134
  if pipe_flux is None:
135
  print("๐Ÿš€ Loading FLUX to GPU (First Run)...")
136
  from Genfocus.pipeline.flux import FluxPipeline
@@ -163,10 +203,6 @@ def run_genfocus_pipeline(clean_input_512, click_coords, K_value):
163
  print("โš ๏ธ GPU Context changed, reloading Depth Pro...")
164
  depth_model, depth_transform = depth_loader.load(device=device)
165
 
166
- # --- 3. Execution ---
167
- if clean_input_512 is None:
168
- raise gr.Error("Please complete Step 1 (Upload Image) first.")
169
-
170
  from Genfocus.pipeline.flux import Condition, generate, seed_everything
171
 
172
  print("โšก Running Inference...")
@@ -174,13 +210,14 @@ def run_genfocus_pipeline(clean_input_512, click_coords, K_value):
174
  # STAGE 1: DEBLUR
175
  switch_lora_on_gpu(pipe_flux, "deblur")
176
 
177
- condition_0_img = Image.new("RGB", (512, 512), (0, 0, 0))
 
178
  cond0 = Condition(condition_0_img, "deblurring", [0, 32], 1.0)
179
- cond1 = Condition(clean_input_512, "deblurring", [0, 0], 1.0)
180
 
181
  seed_everything(42)
182
  deblurred_img = generate(
183
- pipe_flux, height=512, width=512,
184
  prompt="a sharp photo with everything in focus",
185
  conditions=[cond0, cond1]
186
  ).images[0]
@@ -190,7 +227,8 @@ def run_genfocus_pipeline(clean_input_512, click_coords, K_value):
190
 
191
  # STAGE 2: BOKEH
192
  if click_coords is None:
193
- click_coords = [256, 256]
 
194
 
195
  # Depth Estimation
196
  img_t = depth_transform(deblurred_img).to(device)
@@ -200,12 +238,14 @@ def run_genfocus_pipeline(clean_input_512, click_coords, K_value):
200
  depth_map = pred["depth"].cpu().numpy().squeeze()
201
  safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max)
202
  disp_orig = 1.0 / safe_depth
203
- disp = cv2.resize(disp_orig, (512, 512), interpolation=cv2.INTER_LINEAR)
 
204
 
205
  # Defocus Map
206
  tx, ty = click_coords
207
- tx = min(max(int(tx), 0), 511)
208
- ty = min(max(int(ty), 0), 511)
 
209
 
210
  disp_focus = float(disp[ty, tx])
211
  dmf = disp - np.float32(disp_focus)
@@ -214,11 +254,12 @@ def run_genfocus_pipeline(clean_input_512, click_coords, K_value):
214
  defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float()
215
  cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3,1,1).unsqueeze(0)
216
 
217
- # Generate New Latents (Always fresh)
218
  seed_everything(42)
219
  gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
 
220
  current_latents, _ = pipe_flux.prepare_latents(
221
- batch_size=1, num_channels_latents=16, height=512, width=512,
222
  dtype=pipe_flux.dtype, device=pipe_flux.device, generator=gen, latents=None
223
  )
224
 
@@ -232,7 +273,7 @@ def run_genfocus_pipeline(clean_input_512, click_coords, K_value):
232
 
233
  with torch.no_grad():
234
  res = generate(
235
- pipe_flux, height=512, width=512,
236
  prompt="an excellent photo with a large aperture",
237
  conditions=[cond_img, cond_dmf],
238
  guidance_scale=1.0, kv_cache=False, generator=gen,
@@ -242,9 +283,7 @@ def run_genfocus_pipeline(clean_input_512, click_coords, K_value):
242
 
243
  return generated_bokeh
244
 
245
- # ==========================================
246
- # 5. UI Setup
247
- # ==========================================
248
  css = """
249
  #col-container { margin: 0 auto; max-width: 1400px; }
250
  """
@@ -260,7 +299,6 @@ if os.path.exists(example_dir):
260
  with gr.Blocks(css=css) as demo:
261
  clean_processed_state = gr.State(value=None)
262
  click_coords_state = gr.State(value=None)
263
- # 移除了 latents_state
264
 
265
  with gr.Column(elem_id="col-container"):
266
  gr.Markdown("# ๐Ÿ“ท Genfocus Pipeline: Interactive Refocusing (HF Demo)")
@@ -269,7 +307,8 @@ with gr.Blocks(css=css) as demo:
269
  with gr.Column(scale=1):
270
  gr.Markdown("### Step 1: Upload & Preprocess")
271
  input_raw = gr.Image(label="Raw Input Image", type="pil")
272
- resize_chk = gr.Checkbox(label="Resize min edge to 512", value=False)
 
273
  if valid_examples:
274
  gr.Examples(examples=valid_examples, inputs=input_raw, label="Examples")
275
 
@@ -289,7 +328,7 @@ with gr.Blocks(css=css) as demo:
289
  trigger(
290
  fn=preprocess_input_image,
291
  inputs=[input_raw, resize_chk],
292
- outputs=[focus_preview_img, clean_processed_state] # 移除 latents_state
293
  )
294
 
295
  focus_preview_img.select(
@@ -304,8 +343,8 @@ with gr.Blocks(css=css) as demo:
304
 
305
  run_btn.click(
306
  fn=run_genfocus_pipeline,
307
- inputs=[clean_processed_state, click_coords_state, k_slider], # 移除 latents_state
308
- outputs=[output_img] # 移除 latents_state
309
  )
310
 
311
  if __name__ == "__main__":
 
5
  import numpy as np
6
  from PIL import Image, ImageDraw
7
 
 
8
  import spaces
9
  from huggingface_hub import hf_hub_download
10
 
 
67
  raise e
68
 
69
  # ==========================================
70
+ # 3. Helper Functions (Modified)
71
  # ==========================================
72
+
73
+ def resize_and_crop_to_16(img: Image.Image) -> Image.Image:
74
+ """
75
+ 1. Resize the longer side to 512, maintaining aspect ratio.
76
+ 2. Crop the dimensions to be multiples of 16.
77
+ """
78
  w, h = img.size
79
  target = 512
80
+
81
+ # 1. Resize longer side to 512
82
+ if w >= h:
83
+ scale = target / w
84
+ else:
85
+ scale = target / h
86
+
87
+ new_w = int(w * scale)
88
+ new_h = int(h * scale)
89
+
90
+
91
+ img = img.resize((new_w, new_h), Image.LANCZOS)
92
+
93
+ # 2. Crop to multiples of 16
94
+ final_w = (new_w // 16) * 16
95
+ final_h = (new_h // 16) * 16
96
+
97
+ # Center crop calculation
98
+ left = (new_w - final_w) // 2
99
+ top = (new_h - final_h) // 2
100
+ right = left + final_w
101
+ bottom = top + final_h
102
+
103
+ img = img.crop((left, top, right, bottom))
104
+ return img
105
 
106
  def switch_lora_on_gpu(pipe, target_mode):
107
  print(f"๐Ÿ”„ Switching LoRA to [{target_mode}]...")
 
115
  pipe.set_adapters(["bokeh"])
116
 
117
  def preprocess_input_image(raw_img, do_resize):
118
+ """
119
+ 修改後的預處理：
120
+ 如果勾選 do_resize (或預設行為)，則執行長邊512+裁切16倍數。
121
+ """
122
  if raw_img is None: return None, None
123
  print(f"๐Ÿ”„ Preprocessing Input... Resize={do_resize}")
124
+
125
+
126
  img_to_process = raw_img
127
+
128
  if do_resize:
129
+ final_input = resize_and_crop_to_16(img_to_process)
130
+ else:
131
+
132
  w, h = img_to_process.size
133
+ new_w = (w // 16) * 16
134
+ new_h = (h // 16) * 16
135
+ if new_w != w or new_h != h:
136
+ final_input = center_crop_helper(img_to_process, new_w, new_h)
137
+ else:
138
+ final_input = img_to_process
139
+
140
  return final_input, final_input
141
 
142
+ def center_crop_helper(img, target_w, target_h):
143
+ w, h = img.size
144
+ left = (w - target_w) // 2
145
+ top = (h - target_h) // 2
146
+ return img.crop((left, top, left + target_w, top + target_h))
147
+
148
  def draw_red_dot_on_preview(clean_img, evt: gr.SelectData):
149
  if clean_img is None: return None, None
150
  img_copy = clean_img.copy()
 
156
  draw.line((x, y-r, x, y+r), fill="red", width=2)
157
  return img_copy, evt.index
158
 
159
+
 
 
160
  @spaces.GPU(duration=120)
161
+ def run_genfocus_pipeline(clean_input, click_coords, K_value):
 
162
  global pipe_flux, depth_model, depth_transform
163
 
164
  device = "cuda"
165
 
166
+
167
+ if clean_input is None:
168
+ raise gr.Error("Please complete Step 1 (Upload Image) first.")
169
+
170
+ W_dyn, H_dyn = clean_input.size
171
+ print(f"๐Ÿ“ Processing Image Size: {W_dyn}x{H_dyn}")
172
+
173
+
174
  if pipe_flux is None:
175
  print("๐Ÿš€ Loading FLUX to GPU (First Run)...")
176
  from Genfocus.pipeline.flux import FluxPipeline
 
203
  print("โš ๏ธ GPU Context changed, reloading Depth Pro...")
204
  depth_model, depth_transform = depth_loader.load(device=device)
205
 
 
 
 
 
206
  from Genfocus.pipeline.flux import Condition, generate, seed_everything
207
 
208
  print("โšก Running Inference...")
 
210
  # STAGE 1: DEBLUR
211
  switch_lora_on_gpu(pipe_flux, "deblur")
212
 
213
+
214
+ condition_0_img = Image.new("RGB", (W_dyn, H_dyn), (0, 0, 0))
215
  cond0 = Condition(condition_0_img, "deblurring", [0, 32], 1.0)
216
+ cond1 = Condition(clean_input, "deblurring", [0, 0], 1.0)
217
 
218
  seed_everything(42)
219
  deblurred_img = generate(
220
+ pipe_flux, height=H_dyn, width=W_dyn,
221
  prompt="a sharp photo with everything in focus",
222
  conditions=[cond0, cond1]
223
  ).images[0]
 
227
 
228
  # STAGE 2: BOKEH
229
  if click_coords is None:
230
+ # Default to center if no click
231
+ click_coords = [W_dyn // 2, H_dyn // 2]
232
 
233
  # Depth Estimation
234
  img_t = depth_transform(deblurred_img).to(device)
 
238
  depth_map = pred["depth"].cpu().numpy().squeeze()
239
  safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max)
240
  disp_orig = 1.0 / safe_depth
241
+ # Resize disp to match current image dimensions
242
+ disp = cv2.resize(disp_orig, (W_dyn, H_dyn), interpolation=cv2.INTER_LINEAR)
243
 
244
  # Defocus Map
245
  tx, ty = click_coords
246
+ # Clamp coordinates to new dimensions
247
+ tx = min(max(int(tx), 0), W_dyn - 1)
248
+ ty = min(max(int(ty), 0), H_dyn - 1)
249
 
250
  disp_focus = float(disp[ty, tx])
251
  dmf = disp - np.float32(disp_focus)
 
254
  defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float()
255
  cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3,1,1).unsqueeze(0)
256
 
257
+ # Generate New Latents
258
  seed_everything(42)
259
  gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
260
+ # Prepare latents with dynamic H, W
261
  current_latents, _ = pipe_flux.prepare_latents(
262
+ batch_size=1, num_channels_latents=16, height=H_dyn, width=W_dyn,
263
  dtype=pipe_flux.dtype, device=pipe_flux.device, generator=gen, latents=None
264
  )
265
 
 
273
 
274
  with torch.no_grad():
275
  res = generate(
276
+ pipe_flux, height=H_dyn, width=W_dyn,
277
  prompt="an excellent photo with a large aperture",
278
  conditions=[cond_img, cond_dmf],
279
  guidance_scale=1.0, kv_cache=False, generator=gen,
 
283
 
284
  return generated_bokeh
285
 
286
+
 
 
287
  css = """
288
  #col-container { margin: 0 auto; max-width: 1400px; }
289
  """
 
299
  with gr.Blocks(css=css) as demo:
300
  clean_processed_state = gr.State(value=None)
301
  click_coords_state = gr.State(value=None)
 
302
 
303
  with gr.Column(elem_id="col-container"):
304
  gr.Markdown("# ๐Ÿ“ท Genfocus Pipeline: Interactive Refocusing (HF Demo)")
 
307
  with gr.Column(scale=1):
308
  gr.Markdown("### Step 1: Upload & Preprocess")
309
  input_raw = gr.Image(label="Raw Input Image", type="pil")
310
+
311
+ resize_chk = gr.Checkbox(label="Resize longer edge to 512 (crops to 16x)", value=True)
312
  if valid_examples:
313
  gr.Examples(examples=valid_examples, inputs=input_raw, label="Examples")
314
 
 
328
  trigger(
329
  fn=preprocess_input_image,
330
  inputs=[input_raw, resize_chk],
331
+ outputs=[focus_preview_img, clean_processed_state]
332
  )
333
 
334
  focus_preview_img.select(
 
343
 
344
  run_btn.click(
345
  fn=run_genfocus_pipeline,
346
+ inputs=[clean_processed_state, click_coords_state, k_slider],
347
+ outputs=[output_img]
348
  )
349
 
350
  if __name__ == "__main__":
example/{get-out.jpg → 0.jpg} RENAMED
File without changes
example/{wweii_nurse.jpg → group_1.png} RENAMED
File without changes
example/kid.png ADDED

Git LFS Details

  • SHA256: 07d19b465ef526bee07495087bf187c4995027239a19c73d470dfca003a2f5e3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.91 MB