import os
import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw

import spaces
from huggingface_hub import hf_hub_download

# ==========================================
# 1. Global Settings & Variables
# ==========================================
MODEL_ID = "black-forest-labs/FLUX.1-dev"
DEBLUR_LORA_PATH = "."  # LoRA weights are expected alongside app.py in the repo root
DEBLUR_WEIGHT_NAME = "deblurNet.safetensors"
BOKEH_LORA_PATH = "."
BOKEH_WEIGHT_NAME = "bokehNet.safetensors"

# Global variables
pipe_flux = None
depth_model = None
depth_transform = None

# ==========================================
# 2. Depth Pro Loader
# ==========================================
class DepthProLoader:
    def load(self, device):
        print("πŸ”„ Loading Depth Pro model...")
        try:
            global Condition, generate, seed_everything, FluxPipeline, depth_pro
            from Genfocus.pipeline.flux import Condition, generate, seed_everything, FluxPipeline
            import depth_pro
            from depth_pro.depth_pro import DEFAULT_MONODEPTH_CONFIG_DICT
            import copy

            WEIGHTS_REPO_ID = "nycu-cplab/Genfocus-Model"
            DEPTH_FILENAME = "checkpoints/depth_pro.pt"
            
            checkpoint_path = hf_hub_download(
                repo_id=WEIGHTS_REPO_ID,
                filename=DEPTH_FILENAME,
                repo_type="model"
            )

            cfg = copy.deepcopy(DEFAULT_MONODEPTH_CONFIG_DICT)
            cfg.checkpoint_uri = checkpoint_path
            
            try:
                create_fn = depth_pro.create_model_and_transforms
            except AttributeError:
                from depth_pro.depth_pro import create_model_and_transforms
                create_fn = create_model_and_transforms

            model, transform = create_fn(
                config=cfg,
                device=device,
                precision=torch.float32 
            )
            model.eval()
            print(f"βœ… Depth Pro loaded on {device}.")
            return model, transform

        except Exception as e:
            print(f"❌ Failed to load Depth Pro: {e}")
            raise e

# ==========================================
# 3. Helper Functions
# ==========================================

def resize_and_crop_to_16(img: Image.Image) -> Image.Image:
    """
    1. Resize the longer side to 512, maintaining aspect ratio.
    2. Crop the dimensions to be multiples of 16.
    """
    w, h = img.size
    target = 512

    # 1. Resize longer side to 512
    if w >= h:
        scale = target / w
    else:
        scale = target / h
    
    new_w = int(w * scale)
    new_h = int(h * scale)
    
    img = img.resize((new_w, new_h), Image.LANCZOS)
    
    # 2. Crop to multiples of 16
    final_w = (new_w // 16) * 16
    final_h = (new_h // 16) * 16
    
    # Center crop calculation
    left = (new_w - final_w) // 2
    top = (new_h - final_h) // 2
    right = left + final_w
    bottom = top + final_h
    
    img = img.crop((left, top, right, bottom))
    return img

def switch_lora_on_gpu(pipe, target_mode):
    print(f"πŸ”„ Switching LoRA to [{target_mode}]...")
    pipe.unload_lora_weights() 
    
    if target_mode == "deblur":
        pipe.load_lora_weights(DEBLUR_LORA_PATH, weight_name=DEBLUR_WEIGHT_NAME, adapter_name="deblurring")
        pipe.set_adapters(["deblurring"])
    elif target_mode == "bokeh":
        pipe.load_lora_weights(BOKEH_LORA_PATH, weight_name=BOKEH_WEIGHT_NAME, adapter_name="bokeh")
        pipe.set_adapters(["bokeh"])

def preprocess_input_image(raw_img):
    """
    Always enforces resizing to 512 (long edge) and cropping to 16x.
    """
    if raw_img is None: return None, None
    print(f"πŸ”„ Preprocessing Input... Enforcing Resize.")
    
    # Always resize and crop
    final_input = resize_and_crop_to_16(raw_img)

    return final_input, final_input

def draw_red_dot_on_preview(clean_img, evt: gr.SelectData):
    if clean_img is None: return None, None
    img_copy = clean_img.copy()
    draw = ImageDraw.Draw(img_copy)
    x, y = evt.index
    r = 8 
    draw.ellipse((x-r, y-r, x+r, y+r), outline="red", width=2)
    draw.line((x-r, y, x+r, y), fill="red", width=2)
    draw.line((x, y-r, x, y+r), fill="red", width=2)
    return img_copy, evt.index

# ==========================================
# 4. Main Pipeline
# ==========================================
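# Note: on ZeroGPU, each call may run on a freshly allocated GPU worker, so the cached
# globals below can point at a stale CUDA context; hence the reload fallbacks in the body.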
@spaces.GPU(duration=120) 
def run_genfocus_pipeline(clean_input, click_coords, K_value):
    global pipe_flux, depth_model, depth_transform
    
    device = "cuda"
    
    if clean_input is None:
        raise gr.Error("Please complete Step 1 (Upload Image) first.")
    
    W_dyn, H_dyn = clean_input.size
    print(f"πŸ“ Processing Image Size: {W_dyn}x{H_dyn}")

    if pipe_flux is None:
        print("πŸš€ Loading FLUX to GPU (First Run)...")
        from Genfocus.pipeline.flux import FluxPipeline
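        # FLUX.1-dev is a gated repo, so the HF_TOKEN secret is required for the download.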
        pipe_flux = FluxPipeline.from_pretrained(
            MODEL_ID, 
            torch_dtype=torch.bfloat16,
            token=os.getenv("HF_TOKEN")
        ).to(device)
    else:
        try:
            _ = pipe_flux.device.type
            pipe_flux.to(device)
        except Exception:
            print("⚠️ GPU Context changed, reloading FLUX...")
            from Genfocus.pipeline.flux import FluxPipeline
            pipe_flux = FluxPipeline.from_pretrained(
                MODEL_ID, 
                torch_dtype=torch.bfloat16,
                token=os.getenv("HF_TOKEN")
            ).to(device)

    # --- Load Depth Pro ---
    depth_loader = DepthProLoader()
    if depth_model is None:
        depth_model, depth_transform = depth_loader.load(device=device)
    else:
        try:
            depth_model = depth_model.to(device)
        except Exception:
            print("⚠️ GPU Context changed, reloading Depth Pro...")
            depth_model, depth_transform = depth_loader.load(device=device)

    from Genfocus.pipeline.flux import Condition, generate, seed_everything

    print("⚑ Running Inference...")
    
    # STAGE 1: DEBLUR
    switch_lora_on_gpu(pipe_flux, "deblur")
    
    condition_0_img = Image.new("RGB", (W_dyn, H_dyn), (0, 0, 0))
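    # The [0, 32] / [0, 0] arguments appear to be condition-token position offsets
    # in the Genfocus Condition API.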
    cond0 = Condition(condition_0_img, "deblurring", [0, 32], 1.0)
    cond1 = Condition(clean_input, "deblurring", [0, 0], 1.0)
    
    seed_everything(42)
    deblurred_img = generate(
        pipe_flux, height=H_dyn, width=W_dyn,
        prompt="a sharp photo with everything in focus",
        conditions=[cond0, cond1]
    ).images[0]
    
    if K_value == 0:
        return deblurred_img

    # STAGE 2: BOKEH
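    # Fall back to the image center when the user never clicked a focus point.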
    if click_coords is None:
        click_coords = [W_dyn // 2, H_dyn // 2]

    # Depth Estimation
    img_t = depth_transform(deblurred_img).to(device)
    with torch.no_grad():
        pred = depth_model.infer(img_t, f_px=None)
    
    depth_map = pred["depth"].cpu().numpy().squeeze()
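    # Invert depth to disparity; non-positive depths are replaced with float32 max so
    # 1/depth stays finite (those pixels end up with ~zero disparity).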
    safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max)
    disp_orig = 1.0 / safe_depth
    # Resize disp to match current image dimensions
    disp = cv2.resize(disp_orig, (W_dyn, H_dyn), interpolation=cv2.INTER_LINEAR)

    # Defocus Map
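    # Per-pixel blur strength is |K * (disparity - disparity at the focus point)|: zero at
    # the chosen focal plane and growing away from it (a circle-of-confusion-style proxy).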
    tx, ty = click_coords
    tx = min(max(int(tx), 0), W_dyn - 1)
    ty = min(max(int(ty), 0), H_dyn - 1)
    
    disp_focus = float(disp[ty, tx])
    dmf = disp - np.float32(disp_focus)
    defocus_abs = np.abs(K_value * dmf)
    MAX_COC = 100.0  # maximum circle of confusion (in pixels) used to normalize the map
    defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float()               # [1, H, W]
    cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3, 1, 1).unsqueeze(0)    # [1, 3, H, W]

    # Generate New Latents
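    # Seed both the global RNG and a dedicated generator so the prepared latents are
    # reproducible; the same seeds are re-applied before generate() below.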
    seed_everything(42)
    gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
    current_latents, _ = pipe_flux.prepare_latents(
        batch_size=1, num_channels_latents=16, height=H_dyn, width=W_dyn,
        dtype=pipe_flux.dtype, device=pipe_flux.device, generator=gen, latents=None
    )

    # Generate Bokeh
    switch_lora_on_gpu(pipe_flux, "bokeh")
    cond_img = Condition(deblurred_img, "bokeh") 
    cond_dmf = Condition(cond_map, "bokeh", [0,0], 1.0, No_preprocess=True)
    
    seed_everything(42)
    gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
    
    with torch.no_grad():
        res = generate(
            pipe_flux, height=H_dyn, width=W_dyn,
            prompt="an excellent photo with a large aperture",
            conditions=[cond_img, cond_dmf],
            guidance_scale=1.0, kv_cache=False, generator=gen,
            latents=current_latents, 
        )
    generated_bokeh = res.images[0]
    
    return generated_bokeh


# ==========================================
# 5. UI Setup
# ==========================================
css = """
#col-container { margin: 0 auto; max-width: 1400px; }
#output_image { min-height: 400px; }
"""
base_path = os.getcwd()
example_dir = os.path.join(base_path, "example")
valid_examples = []
if os.path.exists(example_dir):
    files = os.listdir(example_dir)
    for f in files:
        if f.lower().endswith(('.jpg', '.jpeg', '.png')):
            valid_examples.append([os.path.join(example_dir, f)])

with gr.Blocks(css=css) as demo:
    clean_processed_state = gr.State(value=None)
    click_coords_state = gr.State(value=None)
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# πŸ“· Genfocus Pipeline: Interactive Refocusing (HF Demo)")
        
        # --- Description & Guide ---
        gr.Markdown("""
        ### 📖 User Guide
        **Generative Refocusing** supports two main applications:
        
        * **All-In-Focus (AIF) Estimation:** Set **K = 0**. The model will restore the AIF image from the blurry input.
          
        * **Refocusing:** **Click** the subject you want in focus in the **Step 2** image preview,  
          then increase **K** (Blur Strength) to generate realistic bokeh based on the scene's depth.
        
        > ⚠️ **Preprocessing Note:** Due to resource constraints in this demo, input images are **automatically resized** (longer edge = 512px).  
        > If you wish to perform inference at the **original resolution**, please refer to our **[GitHub Code](https://github.com/rayray9999/Genfocus)** to run it locally.
        """)
        
        with gr.Row():
            # --- Top Row: Inputs & Controls ---
            
            # [Step 1: Upload]
            with gr.Column(scale=1):
                gr.Markdown("### Step 1: Upload Image")
                gr.Markdown("Click an example or upload your own image.")
                
                input_raw = gr.Image(label="Raw Input Image", type="pil")
                
                if valid_examples:
                    gr.Examples(examples=valid_examples, inputs=input_raw, label="Examples (Click to Load)")

            # [Step 2: Focus & Run]
            with gr.Column(scale=1):
                gr.Markdown("### Step 2: Set Focus & K")
                gr.Markdown("The image below shows the actual input for the model. **Click on the image** to set the focus point.")
                
                focus_preview_img = gr.Image(label="Model Input (Processed) - Click Here", type="pil", interactive=False)
                
                with gr.Row():
                    click_status = gr.Textbox(label="Selected Coordinates", value="Center (Default)", interactive=False, scale=1)
                    k_slider = gr.Slider(minimum=0, maximum=50, value=20, step=1, label="Blur Strength (K)", scale=2)
                
                run_btn = gr.Button("✨ Run Genfocus", variant="primary", scale=1)

        # --- Bottom Row: Output ---
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Result")
                output_img = gr.Image(label="Final Output", type="pil", interactive=False, elem_id="output_image")

        # ==================== Event Handling ====================
        
        # 1. Update the preview on both events: .upload covers direct uploads,
        #    while .change also fires when an example is loaded.
        update_triggers = [input_raw.change, input_raw.upload]
        for trigger in update_triggers:
            trigger(
                fn=preprocess_input_image,
                inputs=[input_raw],
                outputs=[focus_preview_img, clean_processed_state]
            )

        # 2. Draw Red Dot on Click
        focus_preview_img.select(
            fn=draw_red_dot_on_preview,
            inputs=[clean_processed_state],
            outputs=[focus_preview_img, click_coords_state]
        ).then(
            fn=lambda x: "Center (Default)" if x is None else f"x={x[0]}, y={x[1]}",
            inputs=[click_coords_state],
            outputs=[click_status]
        )

        # 3. Run Pipeline
        run_btn.click(
            fn=run_genfocus_pipeline,
            inputs=[clean_processed_state, click_coords_state, k_slider],
            outputs=[output_img]
        )

if __name__ == "__main__":
    demo.launch()