bidr-relight / src /isd_estimation.py
maxhuber's picture
Upload 14 files
3336231 verified
"""ISD (Illumination Spectral Direction) estimation module."""
import numpy as np
import torch
import imageio.v3 as iio
import logging
from src.image_util import resize_with_same_aspect, linear_to_log
from src.models.mock import MockISDModel
from src.models.unet import ResNet50UNet
logger = logging.getLogger(__name__)
def get_device():
    """Return the preferred torch device: MPS, then CUDA, then CPU.

    The chosen backend is reported through the module logger at INFO level.
    """
    # Ordered by preference; each availability check only runs if every
    # earlier candidate was unavailable.
    preferred = (
        ("mps", torch.backends.mps.is_available, "Using MPS (Apple Silicon) device"),
        ("cuda", torch.cuda.is_available, "Using CUDA device"),
    )
    for backend, is_available, message in preferred:
        if is_available():
            logger.info(message)
            return torch.device(backend)
    logger.info("Using CPU device")
    return torch.device("cpu")
def _array_to_integer_image(arr):
    """Convert a caller-supplied numpy array to an integer image; return (img, bit_depth)."""
    if np.issubdtype(arr.dtype, np.integer):
        return arr.copy(), np.iinfo(arr.dtype).bits
    # Non-integer arrays (floats, bools) are treated as 8-bit data: a maximum
    # <= 1.0 means normalized values, otherwise values are assumed to already
    # be in the [0, 255] range. NaNs are zeroed so the range test is reliable.
    img = np.nan_to_num(arr.astype(np.float64), nan=0.0)
    if img.max() <= 1.0:
        img = img * 255
    # Clip before casting: astype(np.uint8) wraps out-of-range values around.
    return np.clip(img, 0, 255).astype(np.uint8), 8


def _file_image_to_integer_image(img):
    """Normalize the dtype of an image read from disk; return (img, bit_depth)."""
    if np.issubdtype(img.dtype, np.integer):
        return img, np.iinfo(img.dtype).bits
    if np.issubdtype(img.dtype, np.floating):
        # Some formats store floats. Normalized [0, 1] data is promoted to
        # 16-bit to keep precision; larger values are assumed to be 8-bit range.
        # Computed in float64 so float16 inputs do not overflow when scaled.
        img = np.nan_to_num(img.astype(np.float64), nan=0.0)
        if img.max() <= 1.0:
            return np.clip(img * 65535, 0, 65535).astype(np.uint16), 16
        return np.clip(img, 0, 255).astype(np.uint8), 8
    logger.warning(f"Unexpected dtype {img.dtype}, defaulting to 8-bit")
    return img.astype(np.uint8), 8


def _ensure_rgb_channels(img):
    """Expand (H, W), (H, W, 1) and gray+alpha (H, W, 2) images to three channels."""
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:
        # Replicate the luminance channel; a second (alpha) channel is dropped.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    return img


def load_and_preprocess_image(img_input, resize_scale=1.0):
    """Load an image (path or numpy array) and convert it to log space.

    Args:
        img_input: Either a file path (str) or a numpy array of shape
            (H, W), (H, W, 1), (H, W, 2), (H, W, 3) or (H, W, 4).
        resize_scale: Scale factor passed to resize_with_same_aspect.

    Returns:
        img: Resized image (H, W, 3) as an integer numpy array
            (uint8/uint16 for float inputs, original dtype for integer inputs).
        bit_depth: Bit depth used for normalization.
        log_img: Log RGB image produced by linear_to_log.
        log_norm_img: log_img divided by log(2**bit_depth - 1), as float32
            (intended to lie in [0, 1]; depends on linear_to_log).
    """
    if isinstance(img_input, np.ndarray):
        img, bit_depth = _array_to_integer_image(img_input)
    else:
        # imageio preserves bit depth better than skimage
        img, bit_depth = _file_image_to_integer_image(iio.imread(img_input))
    logger.info(f"Loaded image: shape={img.shape}, dtype={img.dtype}, bit_depth={bit_depth}, range=[{img.min()}, {img.max()}]")

    # Grayscale inputs would otherwise fail at the channel slice below.
    img = _ensure_rgb_channels(img)
    img = resize_with_same_aspect(img, scale=resize_scale)
    img = img[:, :, :3]  # Drop alpha if present

    log_img = linear_to_log(img)
    # Normalize by the log of the largest representable value for this depth.
    log_norm_img = (log_img / np.log(2**bit_depth - 1)).astype(np.float32)
    return img, bit_depth, log_img, log_norm_img
def get_isd_model(model_type, model_path=None, device=None):
    """Initialize an ISD estimation model in eval mode on the chosen device.

    Args:
        model_type: "unet", "vit", or "mock". Any other value falls back to
            the mock model (with a warning) for backward compatibility.
        model_path: Path to the model checkpoint; required for "unet". The
            checkpoint may be a dict with a "model_state_dict" entry or a
            bare state dict.
        device: torch.device or None (auto-detect via get_device() if None).

    Returns:
        (model, device): the model moved to ``device`` and switched to eval
        mode, and the device that was used.

    Raises:
        ValueError: if model_type is "unet" and model_path is None.
        NotImplementedError: if model_type is "vit".
    """
    if device is None:
        device = get_device()

    if model_type == "unet":
        # Fail fast with a clear message instead of inside torch.load(None),
        # and before spending time building the network.
        if model_path is None:
            raise ValueError("model_path is required for model_type='unet'")
        model = ResNet50UNet(
            in_channels=3,
            out_channels=3,
            pretrained=False,
            se_block=True,
            dropout=0.0,
        )
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            # Checkpoint saved directly from model.state_dict().
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    elif model_type == "vit":
        raise NotImplementedError("ViT model not yet implemented")
    else:
        if model_type != "mock":
            # Previously a typo silently produced the mock model; surface it.
            logger.warning(
                "Unknown model_type %r; falling back to mock ISD model", model_type
            )
        model = MockISDModel()

    model.eval()
    model = model.to(device)
    return model, device
def estimate_isd_map(log_norm_img, model, device):
    """Estimate a per-pixel ISD map for an image.

    Args:
        log_norm_img: Normalized log RGB image (H, W, 3). It is converted to a
            C-contiguous float32 array before inference.
        model: ISD estimation model. It is called on a (1, 3, H, W) tensor, and
            its output is expected to squeeze to (3, H, W).
        device: torch.device to run inference on.

    Returns:
        isd_map: (H, W, 3) numpy array of ISD vectors normalized to unit length
        along the last axis. Pixels whose predicted vector is exactly zero
        stay zero.
    """
    # torch.from_numpy rejects negative-stride views (e.g. img[:, ::-1]), and
    # float64 input does not match float32 model weights (and MPS has no
    # float64 support), so normalize layout and dtype up front.
    img = np.ascontiguousarray(log_norm_img, dtype=np.float32)
    img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        isd_tensor = model(img_tensor)

    # (1, 3, H, W) -> (H, W, 3) on the host
    isd_map = isd_tensor.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()

    norms = np.linalg.norm(isd_map, axis=2, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero; zero vectors stay zero
    return isd_map / norms
def process_image_pair(content_path, style_path, model_type="mock",
                       model_path=None, resize_scale=1/4, device=None):
    """Run ISD estimation on a content image and a style image.

    A single model is built once with get_isd_model and reused for both
    images. Each image goes through load_and_preprocess_image and then
    estimate_isd_map, content first and style second.

    Args:
        content_path: Path to the content image, or a numpy array.
        style_path: Path to the style image, or a numpy array.
        model_type: "unet", "vit", or "mock".
        model_path: Path to the model checkpoint.
        resize_scale: Scale factor passed to load_and_preprocess_image.
        device: torch.device or None (resolved by get_isd_model if None).

    Returns:
        dict with keys "content" and "style". Each maps to a dict holding
        "img", "bit_depth", "log_img", "log_norm_img" and "isd_map".
    """
    model, device = get_isd_model(model_type, model_path, device)

    def _analyze(source):
        # Preprocess one image, then run the shared model on it.
        img, bit_depth, log_img, log_norm_img = load_and_preprocess_image(
            source, resize_scale
        )
        return {
            "img": img,
            "bit_depth": bit_depth,
            "log_img": log_img,
            "log_norm_img": log_norm_img,
            "isd_map": estimate_isd_map(log_norm_img, model, device),
        }

    return {
        "content": _analyze(content_path),
        "style": _analyze(style_path),
    }