# bidr-relight/src/main.py
import numpy as np
import torch
from skimage.io import imread, imsave
import matplotlib.pyplot as plt
import os
import logging
from src.models.mock import MockISDModel
from src.models.unet import ResNet50UNet
from src.clustering import cluster_log_chromaticity
from src.image_util import (
resize_with_same_aspect,
linear_to_log,
)
from src.bidr_util import (
project_to_log_chromaticity_plane,
get_global_isd,
rotation_matrix_from_vectors,
)
from src.plotting import (
plot_img_rgb_logrgb,
plot_content_log_chroma,
plot_plane,
calculate_shared_limits,
plane_view_from_normal,
plot_log_chroma_plane_pre_clustering,
plot_log_chroma_plane_post_clustering,
plot_cluster_spatial_distribution,
plot_transformed_img_logrgb,
)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)
def relight_content_image(
    content_path,
    style_path,
    isd_model,
    isd_model_path,
    output_path,
    resize_scale=1 / 4,
    clustering_method="greedy",
    bin_radius=1.0,
    n_clusters=4,
    shading_only=False,
    compression_factor=0.7,
    view_isd=False,
    length_scale=1.0,
    log_transl=None,
    rot_percent=100.0,  # you either use rot_percent or rot_angle
    rot_angle=None,
    always_use_global_illum_norm=True,
):
    """
    Vectorized relighting pipeline using ISDs and optional illuminant transfer.

    Pipeline stages:
      1. Load/resize both images and convert to (normalized) log RGB.
      2. Estimate a per-pixel ISD (illumination spectral direction) map.
      3. Cluster content pixels by material in the log-chromaticity plane.
      4. Estimate the global illumination vector norm from cluster extents.
      5. Estimate a fully-dark / fully-bright point per material cluster.
      6. Pivot each cluster about its dark point from content ISD to style ISD.
      7. Optionally translate all pixels in log RGB (ambient illuminant shift).
      8. Plot diagnostics.

    Parameters
    ----------
    content_path : str
        Path to the content image.
    style_path : str
        Path to the style image.
    isd_model : str
        Selects the ISD estimator: "unet" loads a ResNet50UNet from
        `isd_model_path`; "vit" is reserved (raises NotImplementedError);
        any other value falls back to MockISDModel.
    isd_model_path : str
        Path to ISD model weights (used by the "unet" estimator).
    output_path : str
        Path to save the transformed image.
        NOTE(review): currently unused in this function body — confirm
        whether saving was meant to happen here.
    resize_scale : float
        Scale each image by this factor maintaining same aspect ratio.
        e.g., resize_scale = 1/2 means downsample by 2. Useful for low RAM,
        as model alone consumes >25GB.
    clustering_method : str
        Passed to `cluster_log_chromaticity` ("greedy" enables `bin_radius`).
    bin_radius : float
        Pixels are clustered into bins of `bin_radius` size in log chroma plane.
    n_clusters : int
        Number of clusters for non-greedy clustering methods.
    shading_only : bool, default False
        If True, only compress along the ISD without changing illuminant color.
        NOTE(review): not referenced in this function body — confirm.
    compression_factor : float, default 0.7
        Factor to compress intensity along the ISD.
        NOTE(review): not referenced in this function body — confirm.
    view_isd : bool, default False
        If True, orient the 3D log-RGB plots to look down the ISD normal.
    length_scale : float, default 1.0
        Scales each pixel's distance from its cluster's dark point during the
        rotation step (1.0 = pure rotation).
    log_transl : array-like or None
        Optional global translation added to all pixels in log RGB (step 7).
    rot_percent : float, default 100.0
        Percentage of the content->style ISD rotation to apply.
        Mutually exclusive with `rot_angle`.
    rot_angle : float or None
        Explicit rotation angle; overrides `rot_percent` semantics
        (see `rotation_matrix_from_vectors`).
    always_use_global_illum_norm : bool, default True
        If True, treat every cluster as degenerate and reconstruct its missing
        endpoint from the global illumination vector norm.

    Returns
    -------
    tuple
        (log_chroma_content, log_imgs, isd_maps, imgs) — debug-only return,
        see TODO at the bottom.

    Raises
    ------
    NotImplementedError
        If `isd_model == "vit"` (estimator not implemented yet).
    """
    CONTENT = 0
    STYLE = 1
    if isd_model == "unet":
        model = ResNet50UNet(
            in_channels=3,
            out_channels=3,
            pretrained=True,
            checkpoint=isd_model_path,
            se_block=True,
            dropout=0.0,
        )
    elif isd_model == "vit":
        # BUGFIX: previously `pass`, which left `model` undefined and caused a
        # confusing NameError at model.eval(). Fail fast and explicitly.
        raise NotImplementedError("ViT ISD estimator is not implemented yet")
    else:
        model = MockISDModel()
    model.eval()
    # --- 1. Load and preprocess images ---
    img_paths = [content_path, style_path]
    imgs = []
    imgs_bit_depth = []
    log_imgs = []
    log_norm_imgs = []
    for i in range(len(img_paths)):
        img = imread(img_paths[i])
        img_bit_depth = np.iinfo(img.dtype).bits
        img = resize_with_same_aspect(img, scale=resize_scale)
        # Drop alpha if present
        img = img[:, :, :3]
        # Convert to log RGB and normalize to unit range
        log_img = linear_to_log(img)
        log_norm_img = log_img / np.log(2**img_bit_depth - 1)
        log_norm_img = log_norm_img.astype(np.float32)
        imgs.append(img)
        imgs_bit_depth.append(img_bit_depth)
        log_imgs.append(log_img)
        log_norm_imgs.append(log_norm_img)
    # --- 2. Use pretrained ISD estimator to get ISD maps ---
    isd_maps = []
    for log_norm_img in log_norm_imgs:
        # Estimate ISD map. Inference only: no_grad avoids building an
        # autograd graph over a full-image forward pass (large memory saving).
        log_norm_img_tensor = (
            torch.from_numpy(log_norm_img).permute(2, 0, 1).unsqueeze(0)
        )
        with torch.no_grad():
            isd_map = model(log_norm_img_tensor)
        # Convert back to np.array
        isd_map = isd_map.detach().squeeze(0).numpy()  # (3, H, W)
        isd_map = np.transpose(isd_map, (1, 2, 0))  # (H, W, 3)
        # Normalize output to unit vector (guard zero-norm pixels).
        isd_norm = np.linalg.norm(isd_map, axis=2, keepdims=True)
        isd_norm[isd_norm == 0] = 1
        isd_map = isd_map / isd_norm
        isd_maps.append(isd_map)
    # --- 3. Segment pixels by material: We group pixels whose projections are close in the 2D log-chromaticity plane (the plane orthogonal to the ISD). ---
    plane_offset = np.array((10.4, 10.4, 10.4))
    log_chroma_content = project_to_log_chromaticity_plane(
        log_imgs[CONTENT],
        isd_maps[CONTENT],
        plane_offset=plane_offset,
        use_average_isd=False,
    )  # (H, W, 3)
    # Visualize before clustering
    plot_log_chroma_plane_pre_clustering(
        log_chroma_content, isd_maps[CONTENT], imgs[CONTENT], imgs_bit_depth[CONTENT]
    )
    # Perform clustering
    bin_masks, bin_map = cluster_log_chromaticity(
        log_chroma_content,
        method=clustering_method,
        bin_radius=bin_radius,
        n_clusters=n_clusters,
    )
    # Visualize after clustering
    plot_log_chroma_plane_post_clustering(
        log_chroma_content,
        isd_maps[CONTENT],
        bin_masks,
        bin_radius if clustering_method == "greedy" else None,
    )
    plot_cluster_spatial_distribution(bin_masks, imgs[CONTENT], imgs_bit_depth[CONTENT])
    # --- 4. Find the global illumination vector of content image. ---
    # Details:
    # Under the assumption of uniform spectral ratio (i.e., same ambient and direct),
    # each material's vector between fully lit and fully dark in log RGB will have the same direction (ISD)
    # and same norm. We will denote this as "illumination vector" (referred as N in BIDR paper) and estimate this
    # as the rightmost mode of length distribution.
    # Compute signed dist along isd for each pixel.
    diff_vec = log_imgs[CONTENT] - log_chroma_content
    signed_dist_map = (diff_vec * isd_maps[CONTENT]).sum(
        axis=2
    )  # dot product into (H,W)
    # Get the 5th and 95th percentile of signed dist distribution for each bin of pixels.
    lengths = []
    for bin_mask in bin_masks:
        signed_dists = signed_dist_map[bin_mask].ravel()
        p5 = np.percentile(signed_dists, 5)
        p95 = np.percentile(signed_dists, 95)
        length = p95 - p5
        lengths.append(length)
    # Create a histogram from this length array
    bin_counts, bin_edges = np.histogram(np.array(lengths))
    bin_x = 0.5 * (bin_edges[:-1] + bin_edges[1:])  # Use center as their position
    # Extract peaks/modes from this histogram.
    # Modes are defined as those histogram bins with relatively high counts.
    # The count threshold is dynamically set to 30% of max count.
    count_threshold = 0.3 * bin_counts.max()
    mode_x = bin_x[bin_counts > count_threshold]
    # Use the rightmost mode as the illum vector norm.
    illum_vector_norm = mode_x[-1]
    logger.info(f"Estimated illumination vector norm {illum_vector_norm}")
    # --- 5. Estimate fully (dark, bright) pairs for each material. ---
    # Identify clusters with only lit or only shaded pixels and estimate missing points.
    # First, compute the global range (95th - 5th percentile) for the whole image
    global_signed_dists = signed_dist_map.ravel()
    global_p5 = np.percentile(global_signed_dists, 5)
    global_p95 = np.percentile(global_signed_dists, 95)
    global_range = global_p95 - global_p5
    global_median = np.percentile(global_signed_dists, 50)
    # For each cluster, determine if it's fully lit, fully shaded, or mixed
    dark_points = []
    bright_points = []
    for bin_idx, bin_mask in enumerate(bin_masks):
        bin_isd = isd_maps[CONTENT][bin_mask].mean(axis=0)
        bin_isd = bin_isd / np.linalg.norm(bin_isd)
        length = lengths[bin_idx]
        signed_dists_bin = signed_dist_map[bin_mask].ravel()
        p5 = np.percentile(signed_dists_bin, 5)
        p95 = np.percentile(signed_dists_bin, 95)
        bin_indices = np.array(np.where(bin_mask)).T
        # Pick the real pixels closest to the 5th/95th percentile distances.
        p5_idx = np.argmin(np.abs(signed_dists_bin - p5))
        p95_idx = np.argmin(np.abs(signed_dists_bin - p95))
        p5_point = log_imgs[CONTENT][tuple(bin_indices[p5_idx])]
        p95_point = log_imgs[CONTENT][tuple(bin_indices[p95_idx])]
        if always_use_global_illum_norm:
            is_degenerate = True
        else:
            # A cluster spanning < 30% of the global range is assumed to be
            # either fully lit or fully shaded (one endpoint missing).
            is_degenerate = length < 0.3 * global_range
        if is_degenerate:
            median_dist = np.median(signed_dists_bin)
            if median_dist > global_median:
                # Fully lit: use real p95 as bright, estimate dark
                bright_point = p95_point
                dark_point = bright_point - illum_vector_norm * bin_isd
            else:
                # Fully dark: use real p5 as dark, estimate bright
                dark_point = p5_point
                bright_point = dark_point + illum_vector_norm * bin_isd
        else:
            # Mixed: use real p5/p95 as endpoints
            dark_point = p5_point
            bright_point = p95_point
        dark_points.append(dark_point)
        bright_points.append(bright_point)
    dark_points = np.array(dark_points)
    bright_points = np.array(bright_points)
    logger.info(
        f"Estimated dark and bright points for {len(bin_masks)} material clusters"
    )
    # --- 6. Pivot each material around their dark point from content ISD to the average style ISD. ---
    # For each cylinder, we rotate its pixels about the cylinder's dark point from content ISD to style ISD.
    # By default, this pure rotation maintains length of px from their corresponding dark point.
    # If a proportional `length_scale` is provided (not =1.0), we rotate + scale.
    global_style_isd = get_global_isd(isd_maps[STYLE])
    global_content_isd = get_global_isd(isd_maps[CONTENT])
    tf_log_content = np.copy(log_imgs[CONTENT])
    # Compute rotation matrix that rotates content ISD to style ISD,
    R = rotation_matrix_from_vectors(
        global_content_isd,
        global_style_isd,
        rot_percent=rot_percent,
        rot_angle=rot_angle,
    )
    logger.info(
        f"Average Style ISD: {global_style_isd}. Average Content ISD: {global_content_isd}"
    )
    for cyl_idx, cyl_mask in enumerate(bin_masks):
        # Get cylinder's dark point and rotate all of its pixels at once.
        # PERF: replaces a per-pixel Python loop (np.unravel_index per pixel)
        # with boolean-mask vectorization; R @ v per pixel == (rel @ R.T) batched.
        cyl_dark_point = dark_points[cyl_idx]
        rel = log_imgs[CONTENT][cyl_mask] - cyl_dark_point  # (n_px, 3)
        tf_log_content[cyl_mask] = cyl_dark_point + length_scale * (rel @ R.T)
    logger.info("Pivoted all pixels for each material cluster.")
    # --- 7. Optional global translation in log RGB for all pixels to change ambient illuminant.---
    if log_transl is not None:
        tf_log_content = tf_log_content + log_transl
    # --- 8. Plots: log chroma, illum norm distribution, sRGB, logRGB. ---
    # Prepare data for plotting
    content_img, style_img = imgs
    content_bit_depth, style_bit_depth = imgs_bit_depth
    norm_content_img = content_img / (2**content_bit_depth - 1)
    norm_style_img = style_img / (2**style_bit_depth - 1)
    log_content_img, log_style_img = log_imgs
    log_chroma_normal = get_global_isd(isd_maps[CONTENT])
    log_chroma_offset = plane_offset
    # Compute bounds/xyz limits for log rgb.
    # Useful to see projections correctness when all log RGB plots share same limits.
    log_chroma_content_flat = log_chroma_content.reshape(-1, 3)
    log_content_flat = log_content_img.reshape(-1, 3)
    log_style_flat = log_style_img.reshape(-1, 3)
    tf_log_content_flat = tf_log_content.reshape(-1, 3)
    bounds = calculate_shared_limits(
        [
            log_style_flat,
            log_content_flat,
            log_chroma_content_flat,
            tf_log_content_flat,
        ],
        padding=0.2,
    )
    x_limits, y_limits, z_limits = bounds
    # Setting up axs
    fig = plt.figure(figsize=(20, 40))
    axs = dict()
    axs["style_img"] = fig.add_subplot(8, 2, 1)
    axs["content_img"] = fig.add_subplot(8, 2, 2)
    axs["style_rgb"] = fig.add_subplot(8, 2, 3, projection="3d")
    axs["content_rgb"] = fig.add_subplot(8, 2, 4, projection="3d")
    axs["style_log_rgb"] = fig.add_subplot(8, 2, 5, projection="3d")
    axs["content_log_rgb"] = fig.add_subplot(8, 2, 6, projection="3d")
    axs["mixed_rgb"] = fig.add_subplot(8, 2, 7, projection="3d")
    axs["mixed_log_rgb"] = fig.add_subplot(8, 2, 8, projection="3d")
    axs["content_projected_img"] = fig.add_subplot(8, 2, 9)
    axs["content_projected_log_rgb"] = fig.add_subplot(8, 2, 10, projection="3d")
    axs["clustered_content_log_rgb"] = fig.add_subplot(8, 2, 11, projection="3d")
    axs["tf_content_img"] = fig.add_subplot(8, 2, 13)
    axs["tf_content_log_rgb"] = fig.add_subplot(8, 2, 14, projection="3d")
    axs["mixed_tf_log_rgb"] = fig.add_subplot(8, 2, 15, projection="3d")
    # Make log RGB plots same limits, aspect ratio
    log_rgb_plots_idx = [
        "style_log_rgb",
        "content_log_rgb",
        "mixed_log_rgb",
        "content_projected_log_rgb",
        "clustered_content_log_rgb",
        "tf_content_log_rgb",
        "mixed_tf_log_rgb",
    ]
    for i in log_rgb_plots_idx:
        axs[i].set_box_aspect([1, 1, 1])
        axs[i].set_xlim(x_limits)
        axs[i].set_ylim(y_limits)
        axs[i].set_zlim(z_limits)
    # Plots
    plot_img_rgb_logrgb(
        axs,
        norm_content_img,
        norm_style_img,
        log_content_img,
        log_style_img,
        bin_masks,
        dark_points,  # Uncomment if you want to see them plotted.
        bright_points,
    )
    plot_content_log_chroma(
        axs,
        log_chroma_content,
        content_bit_depth,
        norm_content_img,
    )
    plot_plane(
        [axs["content_log_rgb"], axs["content_projected_log_rgb"]],
        normal=log_chroma_normal,
        point=log_chroma_offset,
        bounds=bounds,
    )
    plot_transformed_img_logrgb(
        axs,
        tf_log_content,
        log_content_img,
        content_bit_depth,
    )
    # Make log RGB plots same view
    if view_isd:
        elev, azim = plane_view_from_normal(log_chroma_normal)
    else:
        elev = axs[log_rgb_plots_idx[-1]].elev
        azim = axs[log_rgb_plots_idx[-1]].azim
    for i in log_rgb_plots_idx:
        axs[i].view_init(elev, azim)
    plt.tight_layout()
    plt.show()
    # TODO (DEBUG): im only returning these for debug. remove later
    return log_chroma_content, log_imgs, isd_maps, imgs