import os

import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import spaces
from huggingface_hub import hf_hub_download
# ==========================================
# 1. Global Settings & Variables
# ==========================================
MODEL_ID = "black-forest-labs/FLUX.1-dev"
DEBLUR_LORA_PATH = "."
DEBLUR_WEIGHT_NAME = "deblurNet.safetensors"
BOKEH_LORA_DIR = "."
BOKEH_WEIGHT_NAME = "bokehNet.safetensors"

# Global variables (populated lazily on first run)
pipe_flux = None
depth_model = None
depth_transform = None
# ==========================================
# 2. Depth Pro Loader
# ==========================================
class DepthProLoader:
    def load(self, device):
        print("🔄 Loading Depth Pro model...")
        try:
            global Condition, generate, seed_everything, FluxPipeline, depth_pro
            from Genfocus.pipeline.flux import Condition, generate, seed_everything, FluxPipeline
            import depth_pro
            from depth_pro.depth_pro import DEFAULT_MONODEPTH_CONFIG_DICT
            import copy

            WEIGHTS_REPO_ID = "nycu-cplab/Genfocus-Model"
            DEPTH_FILENAME = "checkpoints/depth_pro.pt"
            checkpoint_path = hf_hub_download(
                repo_id=WEIGHTS_REPO_ID,
                filename=DEPTH_FILENAME,
                repo_type="model"
            )
            cfg = copy.deepcopy(DEFAULT_MONODEPTH_CONFIG_DICT)
            cfg.checkpoint_uri = checkpoint_path

            # Support both the package-level and module-level locations of
            # the factory function across depth_pro versions.
            try:
                create_fn = depth_pro.create_model_and_transforms
            except AttributeError:
                from depth_pro.depth_pro import create_model_and_transforms
                create_fn = create_model_and_transforms

            model, transform = create_fn(
                config=cfg,
                device=device,
                precision=torch.float32
            )
            model.eval()
            print(f"✅ Depth Pro loaded on {device}.")
            return model, transform
        except Exception as e:
            print(f"❌ Failed to load Depth Pro: {e}")
            raise
# ==========================================
# 3. Helper Functions
# ==========================================
def resize_and_crop_to_16(img: Image.Image) -> Image.Image:
    """
    1. Resize the longer side to 512, maintaining aspect ratio.
    2. Center-crop the dimensions to multiples of 16.
    """
    w, h = img.size
    target = 512

    # 1. Resize longer side to 512
    scale = target / w if w >= h else target / h
    new_w = int(w * scale)
    new_h = int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)

    # 2. Center-crop to multiples of 16
    final_w = (new_w // 16) * 16
    final_h = (new_h // 16) * 16
    left = (new_w - final_w) // 2
    top = (new_h - final_h) // 2
    img = img.crop((left, top, left + final_w, top + final_h))
    return img
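
# Worked example: a 1024x683 upload is scaled to 512x341 (long edge 512),
# then center-cropped to 512x336 so both dimensions are multiples of 16,
# which FLUX's latent packing requires.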

def switch_lora_on_gpu(pipe, target_mode):
    print(f"🔄 Switching LoRA to [{target_mode}]...")
    pipe.unload_lora_weights()
    if target_mode == "deblur":
        pipe.load_lora_weights(DEBLUR_LORA_PATH, weight_name=DEBLUR_WEIGHT_NAME, adapter_name="deblurring")
        pipe.set_adapters(["deblurring"])
    elif target_mode == "bokeh":
        pipe.load_lora_weights(BOKEH_LORA_DIR, weight_name=BOKEH_WEIGHT_NAME, adapter_name="bokeh")
        pipe.set_adapters(["bokeh"])

def preprocess_input_image(raw_img):
    """
    Always resizes to 512 on the long edge and crops to multiples of 16.
    """
    if raw_img is None:
        return None, None
    print("🔄 Preprocessing input... enforcing resize.")
    final_input = resize_and_crop_to_16(raw_img)
    return final_input, final_input

def draw_red_dot_on_preview(clean_img, evt: gr.SelectData):
    if clean_img is None:
        return None, None
    img_copy = clean_img.copy()
    draw = ImageDraw.Draw(img_copy)
    x, y = evt.index
    r = 8
    draw.ellipse((x - r, y - r, x + r, y + r), outline="red", width=2)
    draw.line((x - r, y, x + r, y), fill="red", width=2)
    draw.line((x, y - r, x, y + r), fill="red", width=2)
    return img_copy, evt.index
# ==========================================
# 4. Main Pipeline
# ==========================================
@spaces.GPU  # ZeroGPU: allocate a GPU while this function runs
def run_genfocus_pipeline(clean_input, click_coords, K_value):
    global pipe_flux, depth_model, depth_transform
    device = "cuda"

    if clean_input is None:
        raise gr.Error("Please complete Step 1 (Upload Image) first.")

    W_dyn, H_dyn = clean_input.size
    print(f"📏 Processing Image Size: {W_dyn}x{H_dyn}")

    # --- Load FLUX lazily; reload if the GPU context was reset ---
    if pipe_flux is None:
        print("🚀 Loading FLUX to GPU (First Run)...")
        from Genfocus.pipeline.flux import FluxPipeline
        pipe_flux = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            token=os.getenv("HF_TOKEN")
        ).to(device)
    else:
        try:
            _ = pipe_flux.device.type  # probe whether the cached pipeline is still valid
            pipe_flux.to(device)
        except Exception:
            print("⚠️ GPU context changed, reloading FLUX...")
            from Genfocus.pipeline.flux import FluxPipeline
            pipe_flux = FluxPipeline.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.bfloat16,
                token=os.getenv("HF_TOKEN")
            ).to(device)

    # --- Load Depth Pro ---
    depth_loader = DepthProLoader()
    if depth_model is None:
        depth_model, depth_transform = depth_loader.load(device=device)
    else:
        try:
            depth_model = depth_model.to(device)
        except Exception:
            print("⚠️ GPU context changed, reloading Depth Pro...")
            depth_model, depth_transform = depth_loader.load(device=device)

    from Genfocus.pipeline.flux import Condition, generate, seed_everything
    print("⚡ Running Inference...")

    # --- STAGE 1: DEBLUR ---
    switch_lora_on_gpu(pipe_flux, "deblur")
    condition_0_img = Image.new("RGB", (W_dyn, H_dyn), (0, 0, 0))  # black placeholder condition
    cond0 = Condition(condition_0_img, "deblurring", [0, 32], 1.0)
    cond1 = Condition(clean_input, "deblurring", [0, 0], 1.0)
    seed_everything(42)
    deblurred_img = generate(
        pipe_flux, height=H_dyn, width=W_dyn,
        prompt="a sharp photo with everything in focus",
        conditions=[cond0, cond1]
    ).images[0]

    # K = 0 means all-in-focus: return the deblurred result directly.
    if K_value == 0:
        return deblurred_img

    # --- STAGE 2: BOKEH ---
    if click_coords is None:
        click_coords = [W_dyn // 2, H_dyn // 2]  # default focus: image center

    # Depth estimation (Depth Pro returns metric depth; invert to disparity)
    img_t = depth_transform(deblurred_img).to(device)
    with torch.no_grad():
        pred = depth_model.infer(img_t, f_px=None)
    depth_map = pred["depth"].cpu().numpy().squeeze()
    safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max)
    disp_orig = 1.0 / safe_depth

    # Resize disparity to match the current image dimensions
    disp = cv2.resize(disp_orig, (W_dyn, H_dyn), interpolation=cv2.INTER_LINEAR)

    # Defocus map, centered on the clicked pixel (clamped to image bounds)
    tx, ty = click_coords
    tx = min(max(int(tx), 0), W_dyn - 1)
    ty = min(max(int(ty), 0), H_dyn - 1)
    disp_focus = float(disp[ty, tx])
    dmf = disp - np.float32(disp_focus)
    defocus_abs = np.abs(K_value * dmf)
    MAX_COC = 100.0
    defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float()
    cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3, 1, 1).unsqueeze(0)
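    # The conditioning map encodes per-pixel defocus strength:
    #   defocus(x, y) = clamp(|K * (disp(x, y) - disp_focus)| / MAX_COC, 0, 1)
    # Pixels at the clicked disparity stay sharp (0); blur grows linearly with
    # disparity distance and saturates at MAX_COC.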

    # Generate new latents
    seed_everything(42)
    gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
    current_latents, _ = pipe_flux.prepare_latents(
        batch_size=1, num_channels_latents=16, height=H_dyn, width=W_dyn,
        dtype=pipe_flux.dtype, device=pipe_flux.device, generator=gen, latents=None
    )
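    # Seeding both the global RNG and a dedicated generator keeps this pass
    # deterministic: re-running with a different K starts from identical
    # initial noise, so only the defocus conditioning changes.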

    # Generate bokeh
    switch_lora_on_gpu(pipe_flux, "bokeh")
    cond_img = Condition(deblurred_img, "bokeh")
    cond_dmf = Condition(cond_map, "bokeh", [0, 0], 1.0, No_preprocess=True)
    seed_everything(42)
    gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
    with torch.no_grad():
        res = generate(
            pipe_flux, height=H_dyn, width=W_dyn,
            prompt="an excellent photo with a large aperture",
            conditions=[cond_img, cond_dmf],
            guidance_scale=1.0, kv_cache=False, generator=gen,
            latents=current_latents,
        )
    generated_bokeh = res.images[0]
    return generated_bokeh
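
# Minimal usage sketch (hypothetical; bypasses the Gradio UI, and the file
# path and click point below are placeholders):
#   raw = Image.open("example/photo.jpg").convert("RGB")
#   processed, _ = preprocess_input_image(raw)
#   result = run_genfocus_pipeline(processed, [256, 192], K_value=20)
#   result.save("refocused.png")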

# ==========================================
# 5. UI Setup
# ==========================================
css = """
#col-container { margin: 0 auto; max-width: 1400px; }
#output_image { min-height: 400px; }
"""

base_path = os.getcwd()
example_dir = os.path.join(base_path, "example")
valid_examples = []
if os.path.exists(example_dir):
    for f in os.listdir(example_dir):
        if f.lower().endswith((".jpg", ".jpeg", ".png")):
            valid_examples.append([os.path.join(example_dir, f)])

with gr.Blocks(css=css) as demo:
    clean_processed_state = gr.State(value=None)
    click_coords_state = gr.State(value=None)

    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 📷 Genfocus Pipeline: Interactive Refocusing (HF Demo)")

        # --- Description & Guide ---
        gr.Markdown("""
### 📖 User Guide
**Generative Refocusing** supports two main applications:

* **All-In-Focus (AIF) Estimation:** Set **K = 0**. The model restores the AIF image from the blurry input.
* **Refocusing:**
  1. **Click** the subject you want in focus in the **Step 2** image preview.
  2. Increase **K** (Blur Strength) to generate realistic bokeh based on the scene's depth.

> ⚠️ **Preprocessing Note:** Due to resource constraints in this demo, input images are **automatically resized** (longer edge = 512px).
> To run inference at the **original resolution**, please refer to our **[GitHub Code](https://github.com/rayray9999/Genfocus)** and run it locally.
""")

        # --- Top Row: Inputs & Controls ---
        with gr.Row():
            # [Step 1: Upload]
            with gr.Column(scale=1):
                gr.Markdown("### Step 1: Upload Image")
                gr.Markdown("Click an example or upload your own image.")
                input_raw = gr.Image(label="Raw Input Image", type="pil")
                if valid_examples:
                    gr.Examples(examples=valid_examples, inputs=input_raw, label="Examples (Click to Load)")

            # [Step 2: Focus & Run]
            with gr.Column(scale=1):
                gr.Markdown("### Step 2: Set Focus & K")
                gr.Markdown("The image below shows the actual input for the model. **Click on the image** to set the focus point.")
                focus_preview_img = gr.Image(label="Model Input (Processed) - Click Here", type="pil", interactive=False)
                with gr.Row():
                    click_status = gr.Textbox(label="Selected Coordinates", value="Center (Default)", interactive=False, scale=1)
                    k_slider = gr.Slider(minimum=0, maximum=50, value=20, step=1, label="Blur Strength (K)", scale=2)
                run_btn = gr.Button("✨ Run Genfocus", variant="primary", scale=1)

        # --- Bottom Row: Output ---
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Result")
                output_img = gr.Image(label="Final Output", type="pil", interactive=False, elem_id="output_image")

    # ==================== Event Handling ====================
    # 1. Refresh the processed preview on change or upload
    update_triggers = [input_raw.change, input_raw.upload]
    for trigger in update_triggers:
        trigger(
            fn=preprocess_input_image,
            inputs=[input_raw],
            outputs=[focus_preview_img, clean_processed_state]
        )

    # 2. Draw a red marker at the clicked focus point
    focus_preview_img.select(
        fn=draw_red_dot_on_preview,
        inputs=[clean_processed_state],
        outputs=[focus_preview_img, click_coords_state]
    ).then(
        fn=lambda x: f"x={x[0]}, y={x[1]}" if x else "Center (Default)",
        inputs=[click_coords_state],
        outputs=[click_status]
    )

    # 3. Run the pipeline
    run_btn.click(
        fn=run_genfocus_pipeline,
        inputs=[clean_processed_state, click_coords_state, k_slider],
        outputs=[output_img]
    )

if __name__ == "__main__":
    demo.launch()