# Tejaswi Tripathi
# Renderer pth pulled from S3
# 74f8062
# Raw
# History Blame Contribute Delete
# 13.3 kB
"""Gradio demo: fixed shallow depth-of-field rendering.
Mirrors `prototype.py` but as a Hugging Face Spaces app:
1. Depth Anything V2 (via `transformers`) estimates relative depth.
2. A pseudo CoC map (focused near the image center) is built from that depth.
3. RendererNet renders the image at a chosen f-stop, then the render is blended
back onto the original using a CoC weight, so in-focus regions (CoC ~ 0) are
left untouched.
4. A non-NN baseline (flat Gaussian blur on the background, CoC > threshold) is
produced for comparison.
The Depth Anything checkpoint is pulled from the Hub by `transformers`, so the
`external/` code is not required. RendererNet weights are loaded from
`renderer/best_renderer.pth` by default (override with RENDERER_LOCAL_PATH), or
from the Hugging Face Hub when RENDERER_REPO_ID is set.
"""
import os
import numpy as np
from PIL import Image
import matplotlib.cm as cm
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.transform import resize
from skimage.filters import gaussian
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
# ------------------
# Config (env-overridable so the Space can be configured without code edits)
# ------------------
# Depth Anything V2 checkpoint on the Hub. Swap to "...-Small-hf" for a faster
# (lower quality) model on CPU Spaces, or "...-Large-hf" for best quality.
DEPTH_MODEL_ID = os.getenv(
    "DEPTH_MODEL_ID", "depth-anything/Depth-Anything-V2-Base-hf"
)

# Where the RendererNet checkpoint comes from: a local file is tried first,
# and a Hugging Face model repo is the fallback (see load_renderer_weights).
RENDERER_LOCAL_PATH = os.getenv("RENDERER_LOCAL_PATH", "renderer/best_renderer.pth")
RENDERER_REPO_ID = os.getenv("RENDERER_REPO_ID", "")
RENDERER_FILENAME = os.getenv("RENDERER_FILENAME", "best_renderer.pth")

# Normalization constants fixed at training time; they must stay in sync with
# how RendererNet was trained.
F_STOP_MAX = 22.0
FOCAL_LENGTH_MM_MAX = 200.0
COC_PX_NORM = 25.0  # CoC channel was trained as clip(coc_px, 0, 25) / 25
TARGET_SIZE = 512  # spatial size RendererNet runs at

# Upper bound (px) of the depth-derived pseudo CoC, which spans [0, COC_MAX_PX].
COC_MAX_PX = 4.0

# Longest image side allowed through the pipeline, so CPU inference stays
# responsive.
MAX_SIDE = 1024

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")
# ------------------
# RendererNet (inlined so the Space is self-contained)
# ------------------
def double_convolution(in_channels, out_channels):
    """Build a two-layer 3x3 conv block, each conv followed by an in-place ReLU.

    Both convolutions use ``padding=1``, so the spatial size is unchanged; the
    first maps ``in_channels`` to ``out_channels`` and the second keeps
    ``out_channels``.

    The layers sit at ``nn.Sequential`` indices 0..3 (conv, ReLU, conv, ReLU);
    those indices are part of the state-dict key names, so the order must not
    change.
    """
    conv_kwargs = {"kernel_size": 3, "padding": 1}
    layers = [
        nn.Conv2d(in_channels, out_channels, **conv_kwargs),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, **conv_kwargs),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class RendererNet(nn.Module):
    """U-Net style encoder/decoder whose output is squashed by a sigmoid.

    Four double-conv + max-pool stages feed a 1024-channel bottleneck; four
    transpose-conv upsampling stages then each concatenate the matching
    encoder feature map (skip connection) before another double conv. A final
    1x1 convolution maps to ``out_channels``.

    The attribute names below are the checkpoint's state-dict keys, so they
    must stay exactly as they are for ``load_state_dict`` to succeed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def upsample(c_in, c_out):
            # 2x2 transpose conv with stride 2.
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)

        # Shared 2x downsampling between encoder stages.
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channel width doubles at every stage.
        self.down_convolution_1 = double_convolution(in_channels, 64)
        self.down_convolution_2 = double_convolution(64, 128)
        self.down_convolution_3 = double_convolution(128, 256)
        self.down_convolution_4 = double_convolution(256, 512)
        self.down_convolution_5 = double_convolution(512, 1024)

        # Decoder: each double conv takes skip + upsampled channels combined,
        # which is why its input width is twice the upsample output width.
        self.up_transpose_1 = upsample(1024, 512)
        self.up_convolution_1 = double_convolution(1024, 512)
        self.up_transpose_2 = upsample(512, 256)
        self.up_convolution_2 = double_convolution(512, 256)
        self.up_transpose_3 = upsample(256, 128)
        self.up_convolution_3 = double_convolution(256, 128)
        self.up_transpose_4 = upsample(128, 64)
        self.up_convolution_4 = double_convolution(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder on ``x`` and return sigmoid activations."""
        # Encoder: keep each stage's pre-pool output for the skip connections.
        skip_1 = self.down_convolution_1(x)
        skip_2 = self.down_convolution_2(self.max_pool2d(skip_1))
        skip_3 = self.down_convolution_3(self.max_pool2d(skip_2))
        skip_4 = self.down_convolution_4(self.max_pool2d(skip_3))
        features = self.down_convolution_5(self.max_pool2d(skip_4))

        # Decoder: upsample, concatenate with the matching skip along the
        # channel axis (skip first), then fuse with a double conv.
        decoder_stages = (
            (self.up_transpose_1, self.up_convolution_1, skip_4),
            (self.up_transpose_2, self.up_convolution_2, skip_3),
            (self.up_transpose_3, self.up_convolution_3, skip_2),
            (self.up_transpose_4, self.up_convolution_4, skip_1),
        )
        for upsample, fuse, skip in decoder_stages:
            features = fuse(torch.cat([skip, upsample(features)], 1))

        return torch.sigmoid(self.out(features))
# ------------------
# Weight loading
# ------------------
def _strip_state_dict(checkpoint):
    """Return the bare state dict from a checkpoint object.

    If ``checkpoint`` is a dict with a ``"model_state_dict"`` entry, that entry
    is returned; anything else is returned unchanged.
    """
    is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    return checkpoint["model_state_dict"] if is_wrapped else checkpoint
def load_renderer_weights(map_location):
    """Load the RendererNet state dict from disk or the Hugging Face Hub.

    Resolution order:
      1. ``RENDERER_LOCAL_PATH`` if it is set and exists on disk.
      2. ``RENDERER_FILENAME`` downloaded from ``RENDERER_REPO_ID`` if that is
         set (the ``HF_TOKEN`` environment variable, when present, is passed
         as the download token).

    Args:
        map_location: forwarded to ``torch.load``.

    Returns:
        The loaded checkpoint after ``_strip_state_dict``.

    Raises:
        RuntimeError: if neither source is available.
    """
    # NOTE(review): torch.load unpickles the file; whether ``weights_only`` is
    # enforced depends on the installed torch version. Only point this at
    # checkpoints you trust.
    if RENDERER_LOCAL_PATH and os.path.exists(RENDERER_LOCAL_PATH):
        print(f"Loading RendererNet weights from local file {RENDERER_LOCAL_PATH}")
        checkpoint = torch.load(RENDERER_LOCAL_PATH, map_location=map_location)
        return _strip_state_dict(checkpoint)

    if not RENDERER_REPO_ID:
        raise RuntimeError(
            f"Could not locate RendererNet weights at {RENDERER_LOCAL_PATH!r}. "
            "Set RENDERER_LOCAL_PATH or RENDERER_REPO_ID (+RENDERER_FILENAME)."
        )

    # Imported lazily: only needed when falling back to the Hub.
    from huggingface_hub import hf_hub_download

    downloaded = hf_hub_download(
        repo_id=RENDERER_REPO_ID,
        filename=RENDERER_FILENAME,
        token=os.environ.get("HF_TOKEN"),
    )
    print(f"Loading RendererNet weights from hub {RENDERER_REPO_ID}/{RENDERER_FILENAME}")
    checkpoint = torch.load(downloaded, map_location=map_location)
    return _strip_state_dict(checkpoint)
# Both networks are built once at import time and reused by every render call.
print("Loading RendererNet ...")
# 6 input channels: run_renderer stacks RGB (3) + f-stop map + focal-length
# map + CoC map; 3 output channels for the rendered RGB image.
renderer_net = RendererNet(in_channels=6, out_channels=3).to(device)
renderer_net.load_state_dict(load_renderer_weights(device))
# Inference only: put the module in eval mode.
renderer_net.eval()
print(f"Loading Depth Anything V2 ({DEPTH_MODEL_ID}) ...")
# The processor and model are both resolved from the same model id.
depth_processor = AutoImageProcessor.from_pretrained(DEPTH_MODEL_ID)
depth_model = AutoModelForDepthEstimation.from_pretrained(DEPTH_MODEL_ID).to(device)
depth_model.eval()
print("Models ready.")
# ------------------
# Pipeline helpers
# ------------------
def fit_to_max_side(rgb, max_side):
    """Downscale ``rgb`` so its longest side is at most ``max_side`` pixels.

    Args:
        rgb: image array whose first two axes are height and width.
        max_side: maximum allowed length of the longer side, in pixels.

    Returns:
        ``rgb`` itself (same object, untouched) when it already fits;
        otherwise an anti-aliased, aspect-preserving float32 resize.
    """
    height, width = rgb.shape[:2]
    longest = max(height, width)
    if longest <= max_side:
        return rgb

    scale = max_side / float(longest)
    # Clamp each side to >= 1 px: for extreme aspect ratios the short side
    # would otherwise round to 0 and produce an empty image.
    new_shape = (
        max(1, int(round(height * scale))),
        max(1, int(round(width * scale))),
    )
    resized = resize(rgb, new_shape, anti_aliasing=True, preserve_range=True)
    return resized.astype(np.float32)
@torch.no_grad()
def predict_relative_depth(rgb):
    """Estimate a min-max normalized depth map for ``rgb`` with the depth model.

    Args:
        rgb: float image in [0, 1], H x W x 3.

    Returns:
        float32 array of shape (H, W), shifted and scaled so its minimum is 0
        and its maximum is ~1.
    """
    height, width = rgb.shape[:2]

    # Round rather than truncate when quantizing to 8 bits: a plain cast
    # floors values such as 199.99998 to 199, shifting pixels one level down.
    # Rounding also matches what to_uint8 does for the output images.
    pixels = np.rint(np.clip(rgb, 0, 1) * 255).astype(np.uint8)
    inputs = depth_processor(
        images=Image.fromarray(pixels), return_tensors="pt"
    ).to(device)

    predicted = depth_model(**inputs).predicted_depth  # [1, h', w']

    # Add a channel axis for interpolation, resample back to the input
    # resolution, then drop the batch and channel axes again.
    resampled = F.interpolate(
        predicted.unsqueeze(1),
        size=(height, width),
        mode="bicubic",
        align_corners=False,
    )
    depth = resampled[0, 0].cpu().numpy().astype(np.float32)

    # Min-max normalize; the epsilon guards against a constant depth map.
    depth -= depth.min()
    depth /= depth.max() + 1e-8
    return depth
def pseudo_coc_px(rel_depth, focus_y, focus_x):
    """Derive a pseudo circle-of-confusion map (in px) from relative depth.

    The CoC is the absolute depth difference to the pixel at
    ``(focus_y, focus_x)``, min-max normalized and scaled to
    ``[0, COC_MAX_PX]``.

    Returns:
        float32 array with the same shape as ``rel_depth``.
    """
    depth_at_focus = rel_depth[focus_y, focus_x]
    distance = np.absolute(rel_depth - depth_at_focus)

    # Min-max normalize in place (``distance`` is a fresh array, so the
    # caller's depth map is not modified); epsilon avoids division by zero.
    distance -= distance.min()
    distance /= distance.max() + 1e-8

    scaled = distance * COC_MAX_PX
    return np.clip(scaled, 0, COC_MAX_PX).astype(np.float32)
def coc_px_to_norm_512(coc_px):
    """Normalize a CoC map and resample it to the renderer's input size.

    The map is clipped to ``[0, COC_PX_NORM]`` and divided by ``COC_PX_NORM``,
    then resized to ``TARGET_SIZE`` x ``TARGET_SIZE`` with order-1
    interpolation and anti-aliasing.

    Returns:
        float32 array of shape (TARGET_SIZE, TARGET_SIZE).
    """
    normalized = np.clip(coc_px, 0, COC_PX_NORM) / COC_PX_NORM
    target_shape = (TARGET_SIZE, TARGET_SIZE)
    resampled = resize(
        normalized,
        target_shape,
        order=1,
        anti_aliasing=True,
        preserve_range=True,
    )
    return resampled.astype(np.float32)
def coc_blend_weight(coc_px, focus_threshold_px):
    """Map CoC (px) to a blend weight in [0, 1] with a smoothstep ramp.

    The weight is 0 at or below ``focus_threshold_px``, 1 at ``COC_MAX_PX``,
    and follows ``3t^2 - 2t^3`` in between.

    Returns:
        float32 array with the same shape as ``coc_px``.
    """
    # Floor the ramp width so a threshold at (or above) COC_MAX_PX cannot
    # divide by zero.
    ramp_width = max(COC_MAX_PX - focus_threshold_px, 1e-6)
    ramp = np.clip((coc_px - focus_threshold_px) / ramp_width, 0.0, 1.0)
    smooth = ramp * ramp * (3.0 - 2.0 * ramp)
    return smooth.astype(np.float32)
def make_param_maps(size, f_stop, focal_length_mm):
    """Build constant (1, size, size) maps for the two lens parameters.

    Each map is filled with its parameter divided by the corresponding
    normalization constant (``F_STOP_MAX`` / ``FOCAL_LENGTH_MM_MAX``).

    Returns:
        Tuple ``(fstop_map, focal_map)``.
    """
    unit_plane = np.ones((1, size, size), dtype=np.float32)
    normalized_f_stop = f_stop / F_STOP_MAX
    normalized_focal = focal_length_mm / FOCAL_LENGTH_MM_MAX
    return unit_plane * normalized_f_stop, unit_plane * normalized_focal
@torch.no_grad()
def run_renderer(rgb, coc_norm_512, f_stop, focal_length_mm, out_size):
    """Render ``rgb`` through RendererNet and resize the result to ``out_size``.

    The network input is assembled at ``TARGET_SIZE`` x ``TARGET_SIZE`` with
    channels in this order: RGB (3), f-stop map, focal-length map, CoC map.

    Args:
        rgb: H x W x 3 image array.
        coc_norm_512: normalized CoC map already at the network resolution.
        f_stop: f-stop value encoded into a constant channel.
        focal_length_mm: focal length encoded into a constant channel.
        out_size: output shape passed to the final resize.

    Returns:
        float32 image clipped to [0, 1] and resized to ``out_size``.
    """
    network_shape = (TARGET_SIZE, TARGET_SIZE)
    rgb_small = resize(
        rgb, network_shape, anti_aliasing=True, preserve_range=True
    ).astype(np.float32)

    # HWC -> CHW, then append the conditioning channels.
    rgb_planes = rgb_small.transpose(2, 0, 1)
    fstop_plane, focal_plane = make_param_maps(TARGET_SIZE, f_stop, focal_length_mm)
    coc_plane = coc_norm_512[np.newaxis, :, :]
    stacked = np.concatenate(
        [rgb_planes, fstop_plane, focal_plane, coc_plane], axis=0
    )

    # Add the batch axis and scrub non-finite values before inference.
    batch = np.nan_to_num(
        stacked[np.newaxis], nan=0.0, posinf=1.0, neginf=0.0
    ).astype(np.float32)
    prediction = renderer_net(torch.from_numpy(batch).to(device))

    # First (only) batch item, CHW -> HWC, clipped to the valid range.
    rendered = prediction[0].cpu().numpy()
    rendered = np.clip(rendered.transpose(1, 2, 0), 0, 1)
    return resize(
        rendered, out_size, anti_aliasing=True, preserve_range=True
    ).astype(np.float32)
def colorize_coc(coc_px):
    """Colorize a CoC map with the inferno colormap for display.

    ``coc_px`` is divided by ``COC_MAX_PX`` and clipped to [0, 1] before the
    colormap lookup; the alpha channel is dropped.

    Returns:
        uint8 RGB array.
    """
    fraction = np.clip(coc_px / COC_MAX_PX, 0, 1)
    colored = cm.inferno(fraction)
    rgb_only = colored[:, :, :3]
    return (rgb_only * 255).astype(np.uint8)
def to_uint8(rgb):
    """Convert a [0, 1] float image to uint8, clipping then rounding."""
    scaled = np.clip(rgb, 0, 1) * 255
    return np.round(scaled).astype(np.uint8)
# ------------------
# Main inference entrypoint
# ------------------
def render(
    image,
    f_stop,
    focal_length_mm,
    focus_threshold_px,
    gaussian_threshold_px,
    gaussian_sigma_px,
):
    """Gradio callback: produce the NN render, the Gaussian baseline and the CoC map.

    Args:
        image: PIL image from the UI, or ``None`` when nothing is uploaded.
        f_stop: f-stop passed to RendererNet.
        focal_length_mm: focal length passed to RendererNet.
        focus_threshold_px: CoC at or below which the NN output gets zero
            blend weight.
        gaussian_threshold_px: CoC above which the baseline uses the blurred
            image.
        gaussian_sigma_px: sigma of the baseline Gaussian blur.

    Returns:
        Tuple of three PIL images ``(nn_blend, gaussian_baseline, coc_colormap)``,
        or ``(None, None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None

    rgb = fit_to_max_side(
        np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0, MAX_SIDE
    )
    height, width = rgb.shape[:2]

    # The focus point is fixed at the image center (no point selection).
    depth = predict_relative_depth(rgb)
    coc_px = pseudo_coc_px(depth, height // 2, width // 2)

    # Neural render, blended over the original with a CoC-driven weight so
    # that low-CoC (in-focus) pixels keep their original values.
    rendered = run_renderer(
        rgb, coc_px_to_norm_512(coc_px), f_stop, focal_length_mm, (height, width)
    )
    alpha = coc_blend_weight(coc_px, focus_threshold_px)[:, :, np.newaxis]
    nn_output = np.clip((1.0 - alpha) * rgb + alpha * rendered, 0, 1)

    # Baseline without a network: one uniform Gaussian blur, applied only
    # where the CoC exceeds the threshold (hard mask).
    flat_blur = gaussian(
        rgb, sigma=gaussian_sigma_px, channel_axis=-1, preserve_range=True
    ).astype(np.float32)
    background = (coc_px > gaussian_threshold_px)[:, :, np.newaxis].astype(np.float32)
    baseline = np.clip((1.0 - background) * rgb + background * flat_blur, 0, 1)

    panels = (to_uint8(nn_output), to_uint8(baseline), colorize_coc(coc_px))
    return tuple(Image.fromarray(panel) for panel in panels)
# ------------------
# Gradio UI
# ------------------
# NOTE(review): the reviewed copy of this file had lost its indentation; the
# nesting below (controls column / outputs column, click handler and examples
# at Blocks level) is a reconstruction -- confirm against the deployed layout.
with gr.Blocks(title="Shallow Depth-of-Field Renderer") as demo:
    gr.Markdown(
        "# Physically-Based Portrait Mode Engine\n"
        "Upload an image and render it at a chosen f-stop. Depth Anything V2 "
        "builds a pseudo circle-of-confusion (CoC) map (focused at the image "
        "center); RendererNet blurs the out-of-focus regions while in-focus "
        "areas (CoC near 0) stay untouched. A flat-Gaussian baseline is shown "
        "for comparison."
    )

    with gr.Row():
        # Left column: input image and all render controls.
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="Input image")
            f_stop = gr.Slider(
                minimum=0.95, maximum=22.0, value=1.2, step=0.05, label="f-stop"
            )
            focal_length = gr.Slider(
                minimum=4.0,
                maximum=200.0,
                value=24.0,
                step=0.5,
                label="Focal length (mm)",
            )
            focus_threshold = gr.Slider(
                minimum=0.0,
                maximum=COC_MAX_PX,
                value=0.4,
                step=0.05,
                label="In-focus CoC threshold (px) - below this the NN is suppressed",
            )
            # Baseline-only knobs are tucked away in a collapsed accordion.
            with gr.Accordion("Gaussian baseline", open=False):
                gaussian_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=COC_MAX_PX,
                    value=1.0,
                    step=0.05,
                    label="Background CoC threshold (px)",
                )
                gaussian_sigma = gr.Slider(
                    minimum=1.0,
                    maximum=30.0,
                    value=12.0,
                    step=1.0,
                    label="Gaussian sigma (px)",
                )
            run_btn = gr.Button("Render", variant="primary")

        # Right column: NN render on top, baseline and CoC map side by side.
        with gr.Column(scale=2):
            out_render = gr.Image(label="Rendered (NN)")
            with gr.Row():
                out_gaussian = gr.Image(label="Gaussian background baseline")
                out_coc = gr.Image(label="Pseudo CoC map")

    # Input order must match the positional parameters of render().
    run_btn.click(
        render,
        inputs=[
            inp,
            f_stop,
            focal_length,
            focus_threshold,
            gaussian_threshold,
            gaussian_sigma,
        ],
        outputs=[out_render, out_gaussian, out_coc],
    )

    # Offer up to four example images from ./cache when that folder exists.
    _example_dir = "cache"
    if os.path.isdir(_example_dir):
        _examples = [
            os.path.join(_example_dir, name)
            for name in sorted(os.listdir(_example_dir))
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        ]
        _examples = _examples[:4]
        if _examples:
            gr.Examples(examples=_examples, inputs=inp)

if __name__ == "__main__":
    demo.launch()