# Genfocus-Demo / app.py
import os
import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import spaces
from huggingface_hub import hf_hub_download
# ==========================================
# 1. Global Settings & Variables
# ==========================================
MODEL_ID = "black-forest-labs/FLUX.1-dev"
DEBLUR_LORA_PATH = "."
DEBLUR_WEIGHT_NAME = "deblurNet.safetensors"
BOKEH_LORA_DIR = "."
BOKEH_WEIGHT_NAME = "bokehNet.safetensors"
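# NOTE: the two LoRA checkpoints above are expected to sit in the working directory
# ("."), i.e. next to app.py in the Space repository.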
# Global variables
pipe_flux = None
depth_model = None
depth_transform = None
# ==========================================
# 2. Depth Pro Loader
# ==========================================
class DepthProLoader:
    def load(self, device):
        print("🔄 Loading Depth Pro model...")
        try:
            # Imports are done lazily here and promoted to module globals so that
            # other functions can reuse them after the first load.
            global Condition, generate, seed_everything, FluxPipeline, depth_pro
            from Genfocus.pipeline.flux import Condition, generate, seed_everything, FluxPipeline
            import depth_pro
            from depth_pro.depth_pro import DEFAULT_MONODEPTH_CONFIG_DICT
            import copy

            WEIGHTS_REPO_ID = "nycu-cplab/Genfocus-Model"
            DEPTH_FILENAME = "checkpoints/depth_pro.pt"
            checkpoint_path = hf_hub_download(
                repo_id=WEIGHTS_REPO_ID,
                filename=DEPTH_FILENAME,
                repo_type="model"
            )
            cfg = copy.deepcopy(DEFAULT_MONODEPTH_CONFIG_DICT)
            cfg.checkpoint_uri = checkpoint_path

            # create_model_and_transforms may live at the package level or in the
            # depth_pro.depth_pro submodule, depending on the installed version.
            try:
                create_fn = depth_pro.create_model_and_transforms
            except AttributeError:
                from depth_pro.depth_pro import create_model_and_transforms
                create_fn = create_model_and_transforms

            model, transform = create_fn(
                config=cfg,
                device=device,
                precision=torch.float32
            )
            model.eval()
            print(f"✅ Depth Pro loaded on {device}.")
            return model, transform
        except Exception as e:
            print(f"❌ Failed to load Depth Pro: {e}")
            raise e
# ==========================================
# 3. Helper Functions
# ==========================================
def resize_and_crop_to_16(img: Image.Image) -> Image.Image:
"""
1. Resize the longer side to 512, maintaining aspect ratio.
2. Crop the dimensions to be multiples of 16.
"""
    w, h = img.size
    target = 512

    # 1. Resize longer side to 512
    if w >= h:
        scale = target / w
    else:
        scale = target / h
    new_w = int(w * scale)
    new_h = int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)

    # 2. Crop to multiples of 16
    final_w = (new_w // 16) * 16
    final_h = (new_h // 16) * 16

    # Center crop calculation
    left = (new_w - final_w) // 2
    top = (new_h - final_h) // 2
    right = left + final_w
    bottom = top + final_h
    img = img.crop((left, top, right, bottom))
    return img
def switch_lora_on_gpu(pipe, target_mode):
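    # Keep only one LoRA adapter resident at a time: unload everything first,
    # then load and activate the requested adapter (deblurring or bokeh).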
print(f"🔄 Switching LoRA to [{target_mode}]...")
pipe.unload_lora_weights()
if target_mode == "deblur":
pipe.load_lora_weights(DEBLUR_LORA_PATH, weight_name=DEBLUR_WEIGHT_NAME, adapter_name="deblurring")
pipe.set_adapters(["deblurring"])
elif target_mode == "bokeh":
pipe.load_lora_weights(BOKEH_LORA_DIR, weight_name=BOKEH_WEIGHT_NAME, adapter_name="bokeh")
pipe.set_adapters(["bokeh"])
def preprocess_input_image(raw_img):
"""
Always enforces resizing to 512 (long edge) and cropping to 16x.
"""
if raw_img is None: return None, None
print(f"🔄 Preprocessing Input... Enforcing Resize.")
# Always resize and crop
final_input = resize_and_crop_to_16(raw_img)
return final_input, final_input
def draw_red_dot_on_preview(clean_img, evt: gr.SelectData):
    # Draw a small red crosshair marker at the clicked point on a copy of the preview.
    if clean_img is None:
        return None, None
    img_copy = clean_img.copy()
    draw = ImageDraw.Draw(img_copy)
    x, y = evt.index
    r = 8
    draw.ellipse((x - r, y - r, x + r, y + r), outline="red", width=2)
    draw.line((x - r, y, x + r, y), fill="red", width=2)
    draw.line((x, y - r, x, y + r), fill="red", width=2)
    return img_copy, evt.index
# ==========================================
# 4. Main Pipeline
# ==========================================
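# spaces.GPU requests a ZeroGPU worker for each call of this function (here capped at
# 120 s), which is why the FLUX pipeline and Depth Pro are loaded lazily inside it.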
@spaces.GPU(duration=120)
def run_genfocus_pipeline(clean_input, click_coords, K_value):
    global pipe_flux, depth_model, depth_transform
    device = "cuda"

    if clean_input is None:
        raise gr.Error("Please complete Step 1 (Upload Image) first.")

    W_dyn, H_dyn = clean_input.size
    print(f"📏 Processing Image Size: {W_dyn}x{H_dyn}")

    if pipe_flux is None:
        print("🚀 Loading FLUX to GPU (First Run)...")
        from Genfocus.pipeline.flux import FluxPipeline
        pipe_flux = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            token=os.getenv("HF_TOKEN")
        ).to(device)
    else:
        try:
            _ = pipe_flux.device.type
            pipe_flux.to(device)
        except Exception:
            print("⚠️ GPU Context changed, reloading FLUX...")
            from Genfocus.pipeline.flux import FluxPipeline
            pipe_flux = FluxPipeline.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.bfloat16,
                token=os.getenv("HF_TOKEN")
            ).to(device)

    # --- Load Depth Pro ---
    depth_loader = DepthProLoader()
    if depth_model is None:
        depth_model, depth_transform = depth_loader.load(device=device)
    else:
        try:
            depth_model = depth_model.to(device)
        except Exception:
            print("⚠️ GPU Context changed, reloading Depth Pro...")
            depth_model, depth_transform = depth_loader.load(device=device)

    from Genfocus.pipeline.flux import Condition, generate, seed_everything
    print("⚡ Running Inference...")

    # STAGE 1: DEBLUR
    switch_lora_on_gpu(pipe_flux, "deblur")
    condition_0_img = Image.new("RGB", (W_dyn, H_dyn), (0, 0, 0))
    cond0 = Condition(condition_0_img, "deblurring", [0, 32], 1.0)
    cond1 = Condition(clean_input, "deblurring", [0, 0], 1.0)
    seed_everything(42)
    deblurred_img = generate(
        pipe_flux, height=H_dyn, width=W_dyn,
        prompt="a sharp photo with everything in focus",
        conditions=[cond0, cond1]
    ).images[0]

    # K = 0 means the user only wants the all-in-focus (deblurred) result.
    if K_value == 0:
        return deblurred_img

    # STAGE 2: BOKEH
    if click_coords is None:
        click_coords = [W_dyn // 2, H_dyn // 2]

    # Depth Estimation
    img_t = depth_transform(deblurred_img).to(device)
    with torch.no_grad():
        pred = depth_model.infer(img_t, f_px=None)
    depth_map = pred["depth"].cpu().numpy().squeeze()
    # Invert metric depth to disparity, guarding against zero/invalid depth values.
    safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max)
    disp_orig = 1.0 / safe_depth

    # Resize disparity to match the current image dimensions
    disp = cv2.resize(disp_orig, (W_dyn, H_dyn), interpolation=cv2.INTER_LINEAR)

    # Defocus Map: per-pixel |K * (disparity - disparity_at_focus)|, normalised to [0, 1]
    tx, ty = click_coords
    tx = min(max(int(tx), 0), W_dyn - 1)
    ty = min(max(int(ty), 0), H_dyn - 1)
    disp_focus = float(disp[ty, tx])
    dmf = disp - np.float32(disp_focus)
    defocus_abs = np.abs(K_value * dmf)
    MAX_COC = 100.0
    defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float()
    cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3, 1, 1).unsqueeze(0)
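    # Example with illustrative numbers: focusing on a subject at 2 m (disparity 0.5)
    # against a background at 20 m (disparity 0.05) gives a disparity difference of 0.45;
    # K = 20 then yields a defocus value of 9, normalised by MAX_COC = 100 to 0.09.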
    # Generate New Latents
    seed_everything(42)
    gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
    current_latents, _ = pipe_flux.prepare_latents(
        batch_size=1, num_channels_latents=16, height=H_dyn, width=W_dyn,
        dtype=pipe_flux.dtype, device=pipe_flux.device, generator=gen, latents=None
    )
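    # With latents=None, prepare_latents samples fresh Gaussian noise from the seeded
    # generator, so the bokeh stage is reproducible for the same input and K.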
    # Generate Bokeh
    switch_lora_on_gpu(pipe_flux, "bokeh")
    cond_img = Condition(deblurred_img, "bokeh")
    cond_dmf = Condition(cond_map, "bokeh", [0, 0], 1.0, No_preprocess=True)
    seed_everything(42)
    gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
    with torch.no_grad():
        res = generate(
            pipe_flux, height=H_dyn, width=W_dyn,
            prompt="an excellent photo with a large aperture",
            conditions=[cond_img, cond_dmf],
            guidance_scale=1.0, kv_cache=False, generator=gen,
            latents=current_latents,
        )
    generated_bokeh = res.images[0]
    return generated_bokeh
# ==========================================
# 5. UI Setup
# ==========================================
css = """
#col-container { margin: 0 auto; max-width: 1400px; }
#output_image { min-height: 400px; }
"""
base_path = os.getcwd()
example_dir = os.path.join(base_path, "example")
valid_examples = []
if os.path.exists(example_dir):
    files = os.listdir(example_dir)
    for f in files:
        if f.lower().endswith(('.jpg', '.jpeg', '.png')):
            valid_examples.append([os.path.join(example_dir, f)])
with gr.Blocks(css=css) as demo:
    clean_processed_state = gr.State(value=None)
    click_coords_state = gr.State(value=None)

    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 📷 Genfocus Pipeline: Interactive Refocusing (HF Demo)")

        # --- Description & Guide ---
        gr.Markdown("""
        ### 📖 User Guide
        **Generative Refocusing** supports two main applications:
        * **All-In-Focus (AIF) Estimation:** Set **K = 0**. The model will restore the AIF image from the blurry input.
        * **Refocusing:**
          1. **Click** on the subject you want to bring into focus in the **Step 2** image preview.
          2. Increase **K** (Blur Strength) to generate realistic bokeh effects based on the scene's depth.

        > ⚠️ **Preprocessing Note:** Due to resource constraints in this demo, input images are **automatically resized** (longer edge = 512px).
        > If you wish to perform inference at the **original resolution**, please refer to our **[GitHub Code](https://github.com/rayray9999/Genfocus)** to run it locally.
        """)

        with gr.Row():
            # --- Top Row: Inputs & Controls ---
            # [Step 1: Upload]
            with gr.Column(scale=1):
                gr.Markdown("### Step 1: Upload Image")
                gr.Markdown("Click an example or upload your own image.")
                input_raw = gr.Image(label="Raw Input Image", type="pil")
                if valid_examples:
                    gr.Examples(examples=valid_examples, inputs=input_raw, label="Examples (Click to Load)")

            # [Step 2: Focus & Run]
            with gr.Column(scale=1):
                gr.Markdown("### Step 2: Set Focus & K")
                gr.Markdown("The image below shows the actual input for the model. **Click on the image** to set the focus point.")
                focus_preview_img = gr.Image(label="Model Input (Processed) - Click Here", type="pil", interactive=False)
                with gr.Row():
                    click_status = gr.Textbox(label="Selected Coordinates", value="Center (Default)", interactive=False, scale=1)
                    k_slider = gr.Slider(minimum=0, maximum=50, value=20, step=1, label="Blur Strength (K)", scale=2)
                run_btn = gr.Button("✨ Run Genfocus", variant="primary", scale=1)

        # --- Bottom Row: Output ---
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Result")
                output_img = gr.Image(label="Final Output", type="pil", interactive=False, elem_id="output_image")

    # ==================== Event Handling ====================
    # 1. Update the preview whenever a new image is set or uploaded
    update_trigger = [input_raw.change, input_raw.upload]
    for trigger in update_trigger:
        trigger(
            fn=preprocess_input_image,
            inputs=[input_raw],
            outputs=[focus_preview_img, clean_processed_state]
        )

    # 2. Draw Red Dot on Click
    focus_preview_img.select(
        fn=draw_red_dot_on_preview,
        inputs=[clean_processed_state],
        outputs=[focus_preview_img, click_coords_state]
    ).then(
        fn=lambda x: f"x={x[0]}, y={x[1]}",
        inputs=[click_coords_state],
        outputs=[click_status]
    )

    # 3. Run Pipeline
    run_btn.click(
        fn=run_genfocus_pipeline,
        inputs=[clean_processed_state, click_coords_state, k_slider],
        outputs=[output_img]
    )
if __name__ == "__main__":
    demo.launch()