import os
import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import spaces
from huggingface_hub import hf_hub_download
# ==========================================
# 1. Global Settings & Variables
# ==========================================
MODEL_ID = "black-forest-labs/FLUX.1-dev"
DEBLUR_LORA_PATH = "."
DEBLUR_WEIGHT_NAME = "deblurNet.safetensors"
BOKEH_LORA_DIR = "."
BOKEH_WEIGHT_NAME = "bokehNet.safetensors"
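# Both LoRA weight files are expected in the root of this Space repository (".").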
# Global variables
pipe_flux = None
depth_model = None
depth_transform = None
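# Models are created lazily inside the GPU-decorated handler below, so the Space
# can boot without a GPU: ZeroGPU attaches one only for the duration of a request.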
# ==========================================
# 2. Depth Pro Loader
# ==========================================
class DepthProLoader:
def load(self, device):
print("π Loading Depth Pro model...")
try:
global Condition, generate, seed_everything, FluxPipeline, depth_pro
from Genfocus.pipeline.flux import Condition, generate, seed_everything, FluxPipeline
import depth_pro
from depth_pro.depth_pro import DEFAULT_MONODEPTH_CONFIG_DICT
import copy
WEIGHTS_REPO_ID = "nycu-cplab/Genfocus-Model"
DEPTH_FILENAME = "checkpoints/depth_pro.pt"
checkpoint_path = hf_hub_download(
repo_id=WEIGHTS_REPO_ID,
filename=DEPTH_FILENAME,
repo_type="model"
)
cfg = copy.deepcopy(DEFAULT_MONODEPTH_CONFIG_DICT)
cfg.checkpoint_uri = checkpoint_path
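            # The factory function moved between depth_pro releases, so try the
            # top-level attribute first and fall back to the submodule import.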
try:
create_fn = depth_pro.create_model_and_transforms
except AttributeError:
from depth_pro.depth_pro import create_model_and_transforms
create_fn = create_model_and_transforms
model, transform = create_fn(
config=cfg,
device=device,
precision=torch.float32
)
model.eval()
print(f"β
Depth Pro loaded on {device}.")
return model, transform
except Exception as e:
print(f"β Failed to load Depth Pro: {e}")
raise e
# ==========================================
# 3. Helper Functions
# ==========================================
def resize_and_crop_to_16(img: Image.Image) -> Image.Image:
"""
1. Resize the longer side to 512, maintaining aspect ratio.
2. Crop the dimensions to be multiples of 16.
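    Example: a 1000x600 input is resized to 512x307, then center-cropped to 512x304.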
"""
w, h = img.size
target = 512
# 1. Resize longer side to 512
if w >= h:
scale = target / w
else:
scale = target / h
new_w = int(w * scale)
new_h = int(h * scale)
img = img.resize((new_w, new_h), Image.LANCZOS)
# 2. Crop to multiples of 16
final_w = (new_w // 16) * 16
final_h = (new_h // 16) * 16
# Center crop calculation
left = (new_w - final_w) // 2
top = (new_h - final_h) // 2
right = left + final_w
bottom = top + final_h
img = img.crop((left, top, right, bottom))
return img
def switch_lora_on_gpu(pipe, target_mode):
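    # Unload whatever adapter is active, then attach only the LoRA for the requested stage.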
print(f"π Switching LoRA to [{target_mode}]...")
pipe.unload_lora_weights()
if target_mode == "deblur":
pipe.load_lora_weights(DEBLUR_LORA_PATH, weight_name=DEBLUR_WEIGHT_NAME, adapter_name="deblurring")
pipe.set_adapters(["deblurring"])
elif target_mode == "bokeh":
pipe.load_lora_weights(BOKEH_LORA_DIR, weight_name=BOKEH_WEIGHT_NAME, adapter_name="bokeh")
pipe.set_adapters(["bokeh"])
def preprocess_input_image(raw_img):
"""
Always enforces resizing to 512 (long edge) and cropping to 16x.
"""
if raw_img is None: return None, None
print(f"π Preprocessing Input... Enforcing Resize.")
# Always resize and crop
final_input = resize_and_crop_to_16(raw_img)
return final_input, final_input
def draw_red_dot_on_preview(clean_img, evt: gr.SelectData):
if clean_img is None: return None, None
img_copy = clean_img.copy()
draw = ImageDraw.Draw(img_copy)
x, y = evt.index
r = 8
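    # Draw a crosshair (circle plus two ticks) centered on the clicked pixel.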
draw.ellipse((x-r, y-r, x+r, y+r), outline="red", width=2)
draw.line((x-r, y, x+r, y), fill="red", width=2)
draw.line((x, y-r, x, y+r), fill="red", width=2)
return img_copy, evt.index
# ==========================================
# 4. Main Pipeline
# ==========================================
@spaces.GPU(duration=120)
def run_genfocus_pipeline(clean_input, click_coords, K_value):
global pipe_flux, depth_model, depth_transform
device = "cuda"
if clean_input is None:
raise gr.Error("Please complete Step 1 (Upload Image) first.")
W_dyn, H_dyn = clean_input.size
print(f"π Processing Image Size: {W_dyn}x{H_dyn}")
if pipe_flux is None:
print("π Loading FLUX to GPU (First Run)...")
from Genfocus.pipeline.flux import FluxPipeline
pipe_flux = FluxPipeline.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
token=os.getenv("HF_TOKEN")
).to(device)
else:
try:
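            # Probe the cached pipeline: on ZeroGPU the CUDA context from a previous request may be gone.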
_ = pipe_flux.device.type
pipe_flux.to(device)
except Exception:
print("β οΈ GPU Context changed, reloading FLUX...")
from Genfocus.pipeline.flux import FluxPipeline
pipe_flux = FluxPipeline.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
token=os.getenv("HF_TOKEN")
).to(device)
# --- Load Depth Pro ---
depth_loader = DepthProLoader()
if depth_model is None:
depth_model, depth_transform = depth_loader.load(device=device)
else:
try:
depth_model = depth_model.to(device)
except Exception:
print("β οΈ GPU Context changed, reloading Depth Pro...")
depth_model, depth_transform = depth_loader.load(device=device)
from Genfocus.pipeline.flux import Condition, generate, seed_everything
print("β‘ Running Inference...")
# STAGE 1: DEBLUR
switch_lora_on_gpu(pipe_flux, "deblur")
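    # Stage-1 conditioning: an all-black placeholder image plus the blurry input,
    # both routed through the "deblurring" adapter.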
condition_0_img = Image.new("RGB", (W_dyn, H_dyn), (0, 0, 0))
cond0 = Condition(condition_0_img, "deblurring", [0, 32], 1.0)
cond1 = Condition(clean_input, "deblurring", [0, 0], 1.0)
seed_everything(42)
deblurred_img = generate(
pipe_flux, height=H_dyn, width=W_dyn,
prompt="a sharp photo with everything in focus",
conditions=[cond0, cond1]
).images[0]
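    # K == 0 requests only the all-in-focus restoration, so stop after Stage 1.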
if K_value == 0:
return deblurred_img
# STAGE 2: BOKEH
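    # Fall back to the image center when the user never clicked a focus point.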
if click_coords is None:
click_coords = [W_dyn // 2, H_dyn // 2]
# Depth Estimation
img_t = depth_transform(deblurred_img).to(device)
with torch.no_grad():
pred = depth_model.infer(img_t, f_px=None)
depth_map = pred["depth"].cpu().numpy().squeeze()
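    # Convert to disparity (1/depth); non-positive depths get a huge sentinel so their disparity is ~0.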
safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max)
disp_orig = 1.0 / safe_depth
# Resize disp to match current image dimensions
disp = cv2.resize(disp_orig, (W_dyn, H_dyn), interpolation=cv2.INTER_LINEAR)
# Defocus Map
tx, ty = click_coords
tx = min(max(int(tx), 0), W_dyn - 1)
ty = min(max(int(ty), 0), H_dyn - 1)
disp_focus = float(disp[ty, tx])
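    # Defocus magnitude per pixel: |K * (disparity - disparity at the focus point)|,
    # so the clicked depth stays sharp and blur grows with disparity difference.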
dmf = disp - np.float32(disp_focus)
defocus_abs = np.abs(K_value * dmf)
MAX_COC = 100.0
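    # Normalize by the maximum circle of confusion to [0, 1] and expand to a 3-channel, batched map.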
defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float()
cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3,1,1).unsqueeze(0)
# Generate New Latents
seed_everything(42)
gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
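    # Fixed-seed initial latents keep Stage-2 outputs reproducible across runs.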
current_latents, _ = pipe_flux.prepare_latents(
batch_size=1, num_channels_latents=16, height=H_dyn, width=W_dyn,
dtype=pipe_flux.dtype, device=pipe_flux.device, generator=gen, latents=None
)
# Generate Bokeh
switch_lora_on_gpu(pipe_flux, "bokeh")
cond_img = Condition(deblurred_img, "bokeh")
cond_dmf = Condition(cond_map, "bokeh", [0,0], 1.0, No_preprocess=True)
seed_everything(42)
gen = torch.Generator(device=pipe_flux.device).manual_seed(1234)
with torch.no_grad():
res = generate(
pipe_flux, height=H_dyn, width=W_dyn,
prompt="an excellent photo with a large aperture",
conditions=[cond_img, cond_dmf],
guidance_scale=1.0, kv_cache=False, generator=gen,
latents=current_latents,
)
generated_bokeh = res.images[0]
return generated_bokeh
# ==========================================
# 5. UI Setup
# ==========================================
css = """
#col-container { margin: 0 auto; max-width: 1400px; }
#output_image { min-height: 400px; }
"""
base_path = os.getcwd()
example_dir = os.path.join(base_path, "example")
valid_examples = []
if os.path.exists(example_dir):
files = os.listdir(example_dir)
for f in files:
if f.lower().endswith(('.jpg', '.jpeg', '.png')):
valid_examples.append([os.path.join(example_dir, f)])
with gr.Blocks(css=css) as demo:
clean_processed_state = gr.State(value=None)
click_coords_state = gr.State(value=None)
with gr.Column(elem_id="col-container"):
gr.Markdown("# π· Genfocus Pipeline: Interactive Refocusing (HF Demo)")
# --- Description & Guide ---
gr.Markdown("""
### π User Guide
**Generative Refocusing** supports two main applications:
* **All-In-Focus (AIF) Estimation:** Set **K = 0**. The model will restore the AIF image from the blurry input.
* **Refocusing:** 1. **Click** on the subject you want to bring into focus in the **Step 2** image preview.
2. Increase **K** (Blur Strength) to generate realistic bokeh effects based on the scene's depth.
> β οΈ **Preprocessing Note:** Due to resource constraints in this demo, input images are **automatically resized** (longer edge = 512px).
> If you wish to perform inference at the **original resolution**, please refer to our **[GitHub Code](https://github.com/rayray9999/Genfocus)** to run it locally.
""")
with gr.Row():
# --- Top Row: Inputs & Controls ---
# [Step 1: Upload]
with gr.Column(scale=1):
gr.Markdown("### Step 1: Upload Image")
gr.Markdown("Click an example or upload your own image.")
input_raw = gr.Image(label="Raw Input Image", type="pil")
if valid_examples:
gr.Examples(examples=valid_examples, inputs=input_raw, label="Examples (Click to Load)")
# [Step 2: Focus & Run]
with gr.Column(scale=1):
gr.Markdown("### Step 2: Set Focus & K")
gr.Markdown("The image below shows the actual input for the model. **Click on the image** to set the focus point.")
focus_preview_img = gr.Image(label="Model Input (Processed) - Click Here", type="pil", interactive=False)
with gr.Row():
click_status = gr.Textbox(label="Selected Coordinates", value="Center (Default)", interactive=False, scale=1)
k_slider = gr.Slider(minimum=0, maximum=50, value=20, step=1, label="Blur Strength (K)", scale=2)
run_btn = gr.Button("β¨ Run Genfocus", variant="primary", scale=1)
# --- Bottom Row: Output ---
with gr.Row():
with gr.Column():
gr.Markdown("### Result")
output_img = gr.Image(label="Final Output", type="pil", interactive=False, elem_id="output_image")
# ==================== Event Handling ====================
    # 1. Refresh the processed preview whenever the raw image changes or is uploaded
update_trigger = [input_raw.change, input_raw.upload]
for trigger in update_trigger:
trigger(
fn=preprocess_input_image,
inputs=[input_raw],
outputs=[focus_preview_img, clean_processed_state]
)
# 2. Draw Red Dot on Click
focus_preview_img.select(
fn=draw_red_dot_on_preview,
inputs=[clean_processed_state],
outputs=[focus_preview_img, click_coords_state]
).then(
fn=lambda x: f"x={x[0]}, y={x[1]}",
inputs=[click_coords_state],
outputs=[click_status]
)
# 3. Run Pipeline
run_btn.click(
fn=run_genfocus_pipeline,
inputs=[clean_processed_state, click_coords_state, k_slider],
outputs=[output_img]
)
if __name__ == "__main__":
demo.launch() |