# FM_PhysMamba_UNET / utils.py
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from yacs.config import CfgNode as CN
import os
from torchvision.transforms import v2
# Import your datasets and transforms
# (Ensure these imports match your project structure)
from data import (
RESIDE_Indoor,
RESIDE_Outdoor,
RESIDE_SOTS_Indoor,
RESIDE_SOTS_Outdoor,
Haze4k_Dataset,
OHAZE_Dataset,
DENSE_Haze_Dataset,
NH_Haze_Dataset
)
from data.utils import get_haze_transforms, partition_dataset
def convert_cfg_to_dict(cfg_node):
"""Recursively converts a YACS CfgNode to a standard Python dict."""
if not isinstance(cfg_node, CN):
if isinstance(cfg_node, list):
return [convert_cfg_to_dict(item) for item in cfg_node]
return cfg_node
else:
cfg_dict = dict(cfg_node)
for k, v in cfg_dict.items():
cfg_dict[k] = convert_cfg_to_dict(v)
return cfg_dict
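# Minimal usage sketch for convert_cfg_to_dict (illustrative only; the node
# layout below is a made-up example, not this project's actual config schema):
#
#   cfg = CN()
#   cfg.DATA = CN()
#   cfg.DATA.DATASET_ROOT = "datasets"
#   cfg.DATA.STAGES = [CN({"NAME": "RESIDE-INDOOR", "RESOLUTION": 256})]
#   plain = convert_cfg_to_dict(cfg)
#   assert isinstance(plain["DATA"]["STAGES"][0], dict)  # lists recurse too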
def set_seed(seed):
"""Sets the seed for reproducibility across random, numpy, and torch."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
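# Note: for stricter run-to-run reproducibility you may also want the cuDNN
# flags below (a speed trade-off, which is why they are not set here):
#
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False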
def get_loaders_for_stage(cfg, dataset_name, resolution, batch_size, rank=0):
"""
Creates DataLoaders for a specific training stage.
Automatically handles Single-GPU vs Distributed (DDP) logic.
Args:
dataset_name (str): 'RESIDE-INDOOR', 'RESIDE-OUTDOOR', 'HAZE4K', etc.
resolution: Target size for TRAIN images (e.g., 256).
batch_size: Batch size for TRAIN images.
rank: Process rank (for printing verbose info only on rank 0).
"""
verbose = (rank == 0)
data_cfg = cfg.DATA
# Check if Distributed Processing is Initialized
is_distributed = dist.is_available() and dist.is_initialized()
# --- 1. Define Transforms ---
# Train: Resize + Augment
train_transform = get_haze_transforms(
dataset_name=dataset_name,
resize_size=resolution,
split="train",
verbose=verbose,
)
# Val: Keep Original Size + Normalize
val_transform = get_haze_transforms(
dataset_name=dataset_name,
resize_size=resolution, # Ignored for Val, but passed for API consistency
split="val",
verbose=verbose,
)
# --- 2. Instantiate Datasets Dynamically ---
train_dataset = None
val_dataset = None
if dataset_name == "RESIDE-INDOOR":
if verbose: print(f"Loading RESIDE Indoor (ITS)...")
train_dataset = RESIDE_Indoor(
dataset_path=os.path.join(data_cfg.DATASET_ROOT, data_cfg.RESIDE_INDOOR_PATH),
transform=train_transform,
)
val_dataset = RESIDE_SOTS_Indoor(
dataset_path=os.path.join(data_cfg.DATASET_ROOT, data_cfg.RESIDE_SOTS_PATH),
transform=val_transform,
metadata="metadata_indoor.csv",
)
elif dataset_name == "RESIDE-OUTDOOR":
if verbose: print(f"Loading RESIDE Outdoor (OTS)...")
# IMPORTANT: Ensure your OTS subset file (e.g., dense_haze.txt) is used if needed
# Modify the class init if you need to pass a specific .txt file list
train_dataset = RESIDE_Outdoor(
dataset_path=os.path.join(data_cfg.DATASET_ROOT, "outdoor-training-set"),
transform=train_transform,
)
val_dataset = RESIDE_SOTS_Outdoor(
dataset_path=os.path.join(data_cfg.DATASET_ROOT, "reside-sots"),
transform=val_transform,
)
elif dataset_name == "HAZE4K":
if verbose: print(f"Loading Haze4k...")
train_dataset = Haze4k_Dataset(
dataset_path=os.path.join(data_cfg.DATASET_ROOT, "Haze4k"),
split="train",
transform=train_transform,
)
val_dataset = Haze4k_Dataset(
dataset_path=os.path.join(data_cfg.DATASET_ROOT, "Haze4k"),
split="val",
transform=val_transform,
)
elif dataset_name == "NHHAZE":
if verbose: print(f"Loading NHHAZE...")
train_dataset = NH_Haze_Dataset(
root_dir=os.path.join(data_cfg.DATASET_ROOT, "nh-haze/NH-HAZE"),
split="train",
transform=train_transform,
)
val_dataset = NH_Haze_Dataset(
root_dir=os.path.join(data_cfg.DATASET_ROOT, "nh-haze/NH-HAZE"),
split="val",
transform=val_transform,
)
elif dataset_name == "DENSEHAZE":
if verbose: print("Loading DENSE-HAZE")
densehaze_dataset = DENSE_Haze_Dataset(
os.path.join(data_cfg.DATASET_ROOT, "dense-haze"),
)
train_dataset, val_dataset = partition_dataset(
densehaze_dataset,
train_transform,
val_transform,
            train_ratio=0.91
)
else:
raise ValueError(f"Dataset {dataset_name} not supported in get_loaders_for_stage")
# --- 3. Samplers (Hybrid Logic) ---
if is_distributed:
train_sampler = DistributedSampler(train_dataset, shuffle=True)
val_sampler = DistributedSampler(val_dataset, shuffle=False)
shuffle_train = False # Sampler handles shuffle
else:
train_sampler = None
val_sampler = None
shuffle_train = True # Loader handles shuffle
# --- 4. Loaders ---
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=shuffle_train,
sampler=train_sampler,
num_workers=cfg.NUM_WORKERS,
pin_memory=cfg.PIN_MEMORY,
drop_last=True
)
# Validation Loader: Must use batch_size=1 to handle varying image sizes
val_loader = DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
sampler=val_sampler,
num_workers=cfg.NUM_WORKERS,
pin_memory=cfg.PIN_MEMORY,
)
if verbose:
print(f"Data Loaders Ready. Train: {len(train_loader)} batches, Val: {len(val_loader)} images.")
return train_loader, val_loader
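# Usage sketch for get_loaders_for_stage (hypothetical driver code; assumes cfg
# provides DATA.DATASET_ROOT, DATA.RESIDE_INDOOR_PATH, DATA.RESIDE_SOTS_PATH,
# NUM_WORKERS and PIN_MEMORY, and that each dataset yields (hazy, clean) pairs):
#
#   train_loader, val_loader = get_loaders_for_stage(
#       cfg, dataset_name="RESIDE-INDOOR", resolution=256, batch_size=8, rank=0
#   )
#   for hazy, clean in train_loader:
#       ...  # one training step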
# --- The Evaluation Loader Factory ---
def get_eval_loader(
dataset_name: str,
dataset_root: str,
resolution: int = 256,
num_workers: int = 4,
pin_memory: bool = True
):
"""
Creates a DataLoader for Evaluation (Validation/Test).
Automatically handles Single-GPU and Multi-GPU (DDP) scenarios.
Args:
        dataset_name: 'RESIDE-INDOOR', 'RESIDE-OUTDOOR', 'OHAZE', 'DENSEHAZE', or 'NH-HAZE'.
dataset_root: Path to the main 'dataset' folder.
"""
    # 1. Setup Transforms (val split: normalize only; resize_size is ignored)
val_transform = get_haze_transforms(
dataset_name=dataset_name,
resize_size=resolution,
split="val",
verbose=False
)
# 2. Select Dataset
dataset = None
name_upper = dataset_name.upper()
if name_upper == "RESIDE-INDOOR":
dataset = RESIDE_SOTS_Indoor(
dataset_path=os.path.join(dataset_root, "reside-sots"),
transform=val_transform,
metadata="metadata_indoor.csv"
)
elif name_upper == "RESIDE-OUTDOOR":
dataset = RESIDE_SOTS_Outdoor(
dataset_path=os.path.join(dataset_root, "reside-sots"),
transform=val_transform
)
elif name_upper == "OHAZE":
dataset = OHAZE_Dataset(
root_dir=os.path.join(dataset_root, "o-haze/O-HAZY"),
transform=val_transform
)
elif name_upper == "DENSEHAZE":
dataset = DENSE_Haze_Dataset(
root_dir=os.path.join(dataset_root, "dense-haze"),
transform=val_transform
)
elif name_upper == "NH-HAZE":
dataset = NH_Haze_Dataset(
root_dir=os.path.join(dataset_root, "nh-haze/NH-HAZE"),
transform=val_transform,
split = "test"
)
else:
raise ValueError(f"Unknown evaluation dataset: {dataset_name}")
# 3. DDP Logic
is_distributed = dist.is_available() and dist.is_initialized()
if is_distributed:
# Splits data among GPUs so they don't evaluate the same images
sampler = DistributedSampler(dataset, shuffle=False)
else:
sampler = None
    # Only print on rank 0 to avoid console spam
    if not is_distributed or dist.get_rank() == 0:
print(f"[{dataset_name}] Eval Dataset loaded: {len(dataset)} images. (DDP: {is_distributed})")
# 4. Create Loader
loader = DataLoader(
dataset,
batch_size=1, # Always 1 for eval to handle different sizes
shuffle=False, # Never shuffle eval data
sampler=sampler, # Handles the DDP splitting
num_workers=num_workers,
pin_memory=pin_memory
)
return loader
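# Usage sketch (assumes the folder layout hard-coded above, e.g.
# <dataset_root>/reside-sots for the SOTS splits, and that each dataset
# yields (hazy, clean) pairs):
#
#   loader = get_eval_loader("RESIDE-INDOOR", dataset_root="datasets")
#   for hazy, clean in loader:  # batch_size is fixed to 1
#       ...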
# --- Padding Utilities for Inference ---
# Use these inside your evaluation loop if you encounter size mismatches
# with model architecture requirements (e.g. a UNet whose spatial dims must be divisible by 16)
def pad_to_multiple(image_tensor, multiple=16):
    """Reflect-pads a (B, C, H, W) tensor on the bottom/right so both spatial
    dims are divisible by `multiple`. Returns (padded_tensor, pad_h, pad_w)."""
    b, c, h, w = image_tensor.shape
    pad_h = (multiple - (h % multiple)) % multiple
    pad_w = (multiple - (w % multiple)) % multiple
    padded_tensor = F.pad(image_tensor, (0, pad_w, 0, pad_h), mode="reflect")
    return padded_tensor, pad_h, pad_w
def unpad(padded_tensor, pad_h, pad_w):
    """Crops away the bottom/right padding added by pad_to_multiple."""
if pad_h == 0 and pad_w == 0:
return padded_tensor
h_padded, w_padded = padded_tensor.shape[2], padded_tensor.shape[3]
return padded_tensor[:, :, : h_padded - pad_h, : w_padded - pad_w]
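# Typical inference-time pairing of the two helpers (sketch; `model` stands in
# for any network whose downsampling path needs dims divisible by 16):
#
#   padded, ph, pw = pad_to_multiple(hazy_tensor, multiple=16)
#   out = model(padded)
#   out = unpad(out, ph, pw)  # back to the original H x W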
@torch.no_grad()
def predict_large_image(solver, full_img_tensor, device, progress=None, tile_size=256, overlap_ratio=0.25, batch_size=4, nfe=10):
"""
Performs sliding-window inference with Rich progress bar integration.
"""
b, c, h, w = full_img_tensor.shape
    # Accumulation canvas, per-pixel weight counter, and tile stride
output_canvas = torch.zeros((1, c, h, w), device=device)
count_map = torch.zeros((1, 1, h, w), device=device)
stride = int(tile_size * (1 - overlap_ratio))
    # Tile start coordinates; the last tile is shifted back so it stays in
    # bounds (assumes h, w >= tile_size; pad smaller inputs first)
    h_starts = list(range(0, h - tile_size + stride, stride))
    w_starts = list(range(0, w - tile_size + stride, stride))
    if h_starts[-1] + tile_size > h: h_starts[-1] = h - tile_size
    if w_starts[-1] + tile_size > w: w_starts[-1] = w - tile_size
    h_starts = sorted(set(h_starts))
    w_starts = sorted(set(w_starts))
    # 2D triangular window for smooth overlap blending
    def get_weight_mask(size):
        coords = torch.linspace(0, 1, size, device=device)
        # Clamp so tile edges never get exactly zero weight; otherwise pixels
        # covered only by a tile edge would be divided by ~0 during blending
        mask_1d = (1 - torch.abs(2 * coords - 1)).clamp_min(1e-3)
        mask_1d = mask_1d.unsqueeze(0)
        mask_2d = mask_1d.t() * mask_1d
        return mask_2d.unsqueeze(0).unsqueeze(0)
weight_mask = get_weight_mask(tile_size)
# --- RICH INTEGRATION START ---
tiles = []
coords = []
all_patches = [(y, x) for y in h_starts for x in w_starts]
    # Create a temporary sub-task if a Rich progress bar is provided
    if progress is not None:
        # Note: `transient` is a Progress() constructor option, not an
        # add_task() kwarg, so the sub-task is removed explicitly after the loop
        tile_task_id = progress.add_task(f" └─ Tiling ({len(all_patches)} patches)", total=len(all_patches))
for i, (y, x) in enumerate(all_patches):
# Extract Crop
crop = full_img_tensor[:, :, y:y+tile_size, x:x+tile_size].to(device)
tiles.append(crop)
coords.append((y, x))
# Inference condition
if len(tiles) == batch_size or i == len(all_patches) - 1:
batch_tensor = torch.cat(tiles, dim=0)
with torch.amp.autocast("cuda"):
prediction_batch = solver.sample(batch_tensor, nfe=nfe)
# Stitching
for j, pred_tile in enumerate(prediction_batch):
y_c, x_c = coords[j]
pred_tile = pred_tile.unsqueeze(0)
output_canvas[:, :, y_c:y_c+tile_size, x_c:x_c+tile_size] += pred_tile * weight_mask
count_map[:, :, y_c:y_c+tile_size, x_c:x_c+tile_size] += weight_mask
tiles = []
coords = []
# Update Rich Progress
if progress is not None:
progress.update(tile_task_id, advance=1)
    if progress is not None:
        progress.remove_task(tile_task_id)  # drop the finished sub-task
    # --- RICH INTEGRATION END ---
    final_output = output_canvas / (count_map + 1e-8)
return final_output
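# Usage sketch with a Rich progress bar (assumes `solver.sample(batch, nfe=...)`
# as used above; `hazy` is a (1, 3, H, W) tensor with H, W >= tile_size):
#
#   from rich.progress import Progress
#   with Progress() as progress:
#       pred = predict_large_image(solver, hazy, device="cuda",
#                                  progress=progress, tile_size=256, nfe=10)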
@torch.no_grad()
def predict_large_image_vectorized(solver, full_img_tensor, device, progress=None, tile_size=256, overlap_ratio=0.25, batch_size=4, nfe=10):
"""
Vectorized sliding window inference using F.unfold/F.fold.
Much faster preparation than manual slicing loops.
"""
b, c, h, w = full_img_tensor.shape
# 1. Calculate Padding
# Unfold drops pixels if they don't fit the stride. We pad to ensure coverage.
stride = int(tile_size * (1 - overlap_ratio))
    # Pad so (padded_size - tile_size) is divisible by stride, and so the
    # padded image is at least one tile in each dimension
    if h >= tile_size:
        pad_h = (stride - (h - tile_size) % stride) % stride
    else:
        pad_h = tile_size - h
    if w >= tile_size:
        pad_w = (stride - (w - tile_size) % stride) % stride
    else:
        pad_w = tile_size - w
    # Reflect padding avoids hard border seams but requires pad < dim, so fall
    # back to replicate padding when the input is smaller than the padding
    pad_mode = "reflect" if (pad_h < h and pad_w < w) else "replicate"
    img_padded = F.pad(full_img_tensor, (0, pad_w, 0, pad_h), mode=pad_mode)
hp, wp = img_padded.shape[2], img_padded.shape[3]
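    # Worked example of the padding math: h=600, tile_size=256, stride=192 gives
    # (600 - 256) % 192 = 152, so pad_h = (192 - 152) % 192 = 40; after padding,
    # (640 - 256) % 192 == 0 and unfold covers every row of the image.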
# 2. Vectorized Unfold (Extract all patches at once)
# Output shape: (1, C * tile_size * tile_size, Num_Patches)
patches_raw = F.unfold(img_padded, kernel_size=tile_size, stride=stride)
# Reshape to (Num_Patches, C, tile_size, tile_size) for the model
# 1. Transpose -> (1, Num_Patches, Flattened_Patch)
# 2. View -> (Num_Patches, C, tile_size, tile_size)
num_patches = patches_raw.shape[2]
patches_raw = patches_raw.transpose(1, 2).view(num_patches, c, tile_size, tile_size)
# 3. Create Weight Mask (for smooth blending)
# We create one mask and repeat it for all patches
    def get_weight_mask(size):
        coords = torch.linspace(0, 1, size, device=device)
        # Clamped triangular window (see predict_large_image for the rationale)
        mask_1d = (1 - torch.abs(2 * coords - 1)).clamp_min(1e-3)
        mask_2d = mask_1d.unsqueeze(0).t() * mask_1d.unsqueeze(0)
        return mask_2d.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
weight_patch = get_weight_mask(tile_size) # (1, 1, 256, 256)
# 4. Batched Inference Loop
# We can't process ALL patches at once (OOM), so we chunk them by 'batch_size'
pred_patches_list = []
# Rich Progress Setup
    task_id = None
    if progress is not None:
        # (`transient` is not an add_task() kwarg; the task is removed below)
        task_id = progress.add_task(f" └─ Vectorized ({num_patches} patches)", total=num_patches)
for i in range(0, num_patches, batch_size):
# Select batch
chunk = patches_raw[i : i + batch_size].to(device)
with torch.amp.autocast("cuda"):
# Model Output: (Batch, C, 256, 256)
pred_chunk = solver.sample(chunk, nfe=nfe)
# Apply weighting *immediately* to save memory later
# (Batch, C, 256, 256) * (1, 1, 256, 256)
pred_chunk_weighted = pred_chunk * weight_patch
# Flatten back to (Batch, C*H*W) for folding
pred_chunk_flat = pred_chunk_weighted.view(pred_chunk.shape[0], -1)
pred_patches_list.append(pred_chunk_flat)
if progress is not None:
progress.update(task_id, advance=chunk.shape[0])
    if progress is not None:
        progress.remove_task(task_id)  # drop the finished sub-task
    # Concatenate all processed patches: (Num_Patches, Flattened_Pixels)
    pred_patches_all = torch.cat(pred_patches_list, dim=0)
# 5. Fold (Stitch back together)
# Reshape to (1, Flattened_Pixels, Num_Patches) for F.fold
pred_patches_all = pred_patches_all.t().unsqueeze(0)
# Fold creates the summed image
output_sum = F.fold(
pred_patches_all,
output_size=(hp, wp),
kernel_size=tile_size,
stride=stride
)
# 6. Normalize Weights (Account for overlaps)
# We do the same "Unfold -> Fold" process for the weights to know what to divide by
    # Every patch contributes the same blending window, so fold the window itself
    weight_flat = weight_patch.reshape(1, -1).repeat(num_patches, 1).t().unsqueeze(0)
    weight_sum = F.fold(
        weight_flat,
output_size=(hp, wp),
kernel_size=tile_size,
stride=stride
)
# 7. Final Normalize & Crop
final_img = output_sum / (weight_sum + 1e-8)
# Crop back to original size (remove padding)
final_img = final_img[:, :, :h, :w]
return final_img
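# Sanity-check sketch: with an identity "solver", the weighted fold divided by
# the folded weights must reproduce the input up to float error (assumes a CUDA
# device, since the function's autocast call targets "cuda"):
#
#   class _Identity:
#       def sample(self, x, nfe=1):
#           return x
#   x = torch.rand(1, 3, 300, 500, device="cuda")
#   y = predict_large_image_vectorized(_Identity(), x, device="cuda")
#   assert torch.allclose(y, x, atol=1e-2)  # fp16 autocast tolerance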
def preprocess_single_image(image, device="cuda"):
"""
Preprocesses a single raw image (PIL Image or Numpy array) for inference.
Steps:
1. Converts to Tensor (v2.ToImage)
2. Scales to [0, 1] float32 (v2.ToDtype)
3. Normalizes using mean=0.5, std=0.5 (Matches training)
4. Adds Batch Dimension (1, C, H, W)
5. Moves to Device
"""
# Exact same logic as 'eval_common' in get_haze_transforms
transform = v2.Compose([
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Apply transform: (C, H, W)
img_tensor = transform(image)
# Add batch dimension: (1, C, H, W)
img_tensor = img_tensor.unsqueeze(0)
return img_tensor.to(device)
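# End-to-end single-image inference sketch (hypothetical glue code; the last
# line inverts the Normalize(mean=0.5, std=0.5) applied above):
#
#   from PIL import Image
#   hazy = preprocess_single_image(Image.open("hazy.png").convert("RGB"))
#   pred = predict_large_image_vectorized(solver, hazy, device="cuda")
#   pred = (pred * 0.5 + 0.5).clamp(0, 1)  # back to [0, 1] for saving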