# Source: bidr-relight/src/illumination.py (uploaded by maxhuber, revision 3336231).
"""Illumination vector and dark/bright point estimation."""
import numpy as np
import logging
# Module-level logger named after this module; the estimation steps below
# report their results through it at INFO level.
logger = logging.getLogger(__name__)
def compute_signed_distances(log_img, log_chroma, isd_map):
    """Project each pixel's offset from its chromaticity onto the ISD.

    Args:
        log_img: Log-space image; the projection sums over axis 2, so this is
            presumably laid out as (H, W, C) — TODO confirm with callers.
        log_chroma: Array broadcast-compatible with ``log_img``; subtracted
            from it before projecting.
        isd_map: Array broadcast-compatible with ``log_img`` giving the
            direction(s) to project onto.

    Returns:
        The dot product of ``log_img - log_chroma`` with ``isd_map`` taken
        over axis 2, i.e. one signed distance per pixel.
    """
    projected = (log_img - log_chroma) * isd_map
    return projected.sum(axis=2)
def estimate_cluster_lengths(signed_dist_map, bin_masks):
    """Measure how far each cluster extends along the ISD.

    A cluster's length is the robust spread of its signed distances: the
    95th percentile minus the 5th percentile, which ignores outlier pixels.

    Args:
        signed_dist_map: Per-pixel signed distances (see
            ``compute_signed_distances``).
        bin_masks: Iterable of boolean masks, each selecting one cluster's
            pixels from ``signed_dist_map``.

    Returns:
        1-D array with one length per mask, in the order of ``bin_masks``.
    """
    def _robust_spread(mask):
        # Spread between the 5th and 95th percentiles of one cluster.
        cluster_dists = signed_dist_map[mask].ravel()
        upper = np.percentile(cluster_dists, 95)
        lower = np.percentile(cluster_dists, 5)
        return upper - lower

    return np.array([_robust_spread(mask) for mask in bin_masks])
def find_illumination_norm(lengths, *, bins=10, count_fraction=0.3):
    """Estimate the global illumination-vector norm from per-cluster lengths.

    The lengths are histogrammed. Every bin whose count strictly exceeds
    ``count_fraction`` times the tallest bin's count is treated as
    well-populated. The centre of the right-most such bin is returned, i.e.
    the largest well-populated length rather than the single most common one.

    Args:
        lengths: Array-like of per-cluster lengths (e.g. the output of
            ``estimate_cluster_lengths``).
        bins: Forwarded to ``np.histogram``. The default of 10 is
            ``np.histogram``'s own default, so behaviour is unchanged.
        count_fraction: Fraction of the peak bin count that a bin must
            exceed to be considered. Must satisfy ``0 <= count_fraction < 1``.

    Returns:
        The centre of the selected histogram bin.

    Raises:
        ValueError: If ``lengths`` is empty, or if ``count_fraction`` is
            outside ``[0, 1)`` (a value >= 1 could never select a bin).
    """
    if not 0.0 <= count_fraction < 1.0:
        raise ValueError(
            f"count_fraction must be in [0, 1), got {count_fraction!r}"
        )
    lengths = np.asarray(lengths)
    if lengths.size == 0:
        # np.histogram would return all-zero counts, leaving no bin to pick
        # and failing later with an opaque IndexError.
        raise ValueError(
            "cannot estimate illumination norm from an empty set of lengths"
        )
    bin_counts, bin_edges = np.histogram(lengths, bins=bins)
    bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    # With non-empty input the tallest bin has a positive count. Since
    # count_fraction < 1, it always passes the strict '>' test, so the
    # selection below is never empty.
    well_populated = bin_centres[bin_counts > count_fraction * bin_counts.max()]
    # Deliberately the right-most (largest-length) qualifying bin.
    illum_norm = well_populated[-1]
    logger.info(f"Estimated illumination vector norm: {illum_norm:.3f}")
    return illum_norm
def estimate_dark_bright_points(log_img, isd_maps, signed_dist_map,
                                bin_masks, lengths, illum_norm,
                                always_use_global=True):
    """Pick a fully-dark and a fully-bright log-space point for every cluster.

    For each cluster, the pixels whose signed distances are closest to the
    cluster's 5th and 95th percentiles are the candidate dark and bright
    anchors.

    A cluster is treated as degenerate when ``always_use_global`` is truthy,
    or otherwise when its length is below 30% of the image-wide 5-95
    percentile spread. For a degenerate cluster only one anchor is kept and
    the other is reached by stepping ``illum_norm`` along the cluster's mean
    ISD, normalised to unit length:

    * If the cluster's median signed distance is above the image-wide
      median, the bright anchor is kept.
    * Otherwise the dark anchor is kept.

    Non-degenerate clusters use both percentile anchors directly.

    Args:
        log_img: Log-space image. It is indexed with the (row, col)
            coordinates of mask pixels, and each lookup is taken as that
            pixel's point (presumably an (H, W, C) array — TODO confirm).
        isd_maps: Per-pixel ISD vectors, indexed with the same boolean masks.
        signed_dist_map: Per-pixel signed distances (see
            ``compute_signed_distances``).
        bin_masks: Sequence of boolean masks, one per cluster.
        lengths: Per-cluster lengths, indexed in the same order as
            ``bin_masks``.
        illum_norm: Distance stepped along the unit ISD to derive the
            missing anchor.
        always_use_global: When truthy, every cluster is handled as
            degenerate.

    Returns:
        ``(dark_points, bright_points)`` arrays with one row per mask.
    """
    flat_dists = signed_dist_map.ravel()
    global_spread = np.percentile(flat_dists, 95) - np.percentile(flat_dists, 5)
    global_mid = np.percentile(flat_dists, 50)

    darks = []
    brights = []
    for idx, mask in enumerate(bin_masks):
        mean_isd = isd_maps[mask].mean(axis=0)
        unit_isd = mean_isd / np.linalg.norm(mean_isd)
        cluster_len = lengths[idx]

        cluster_dists = signed_dist_map[mask].ravel()
        lo = np.percentile(cluster_dists, 5)
        hi = np.percentile(cluster_dists, 95)

        # Pixel coordinates in row-major order. This matches the order in
        # which boolean indexing produced ``cluster_dists``, so an index into
        # one is an index into the other.
        coords = np.argwhere(mask)
        lo_pixel = log_img[tuple(coords[np.argmin(np.abs(cluster_dists - lo))])]
        hi_pixel = log_img[tuple(coords[np.argmin(np.abs(cluster_dists - hi))])]

        if not (always_use_global or cluster_len < 0.3 * global_spread):
            # Cluster spans enough of the ISD: trust both observed extremes.
            dark, bright = lo_pixel, hi_pixel
        elif np.median(cluster_dists) > global_mid:
            # Mostly-lit cluster: anchor on the bright end, extrapolate dark.
            bright = hi_pixel
            dark = bright - illum_norm * unit_isd
        else:
            # Mostly-shadowed cluster: anchor on the dark end, extrapolate bright.
            dark = lo_pixel
            bright = dark + illum_norm * unit_isd

        darks.append(dark)
        brights.append(bright)

    logger.info(f"Estimated dark/bright points for {len(bin_masks)} clusters")
    return np.array(darks), np.array(brights)
def estimate_illumination(log_img, log_chroma, isd_maps, bin_masks,
                          always_use_global=True):
    """Run the full illumination-estimation pipeline.

    The stages, in order:

    1. Per-pixel signed distances along the ISD.
    2. Per-cluster lengths from those distances.
    3. A single global illumination norm from the lengths.
    4. Per-cluster dark and bright anchor points.

    Args:
        log_img: Log-space image.
        log_chroma: Chromaticity term subtracted from ``log_img`` before
            projecting.
        isd_maps: Per-pixel ISD vectors.
        bin_masks: Boolean masks, one per material cluster.
        always_use_global: Forwarded to ``estimate_dark_bright_points``.

    Returns:
        Tuple ``(illum_norm, dark_points, bright_points, signed_dist_map)``.
    """
    dist_map = compute_signed_distances(log_img, log_chroma, isd_maps)
    cluster_lengths = estimate_cluster_lengths(dist_map, bin_masks)
    norm = find_illumination_norm(cluster_lengths)
    darks, brights = estimate_dark_bright_points(
        log_img,
        isd_maps,
        dist_map,
        bin_masks,
        cluster_lengths,
        norm,
        always_use_global=always_use_global,
    )
    return norm, darks, brights, dist_map