| | import torch |
| | import copy |
| | import matplotlib.pyplot as plt |
| | from torchvision.transforms import v2 |
| | from typing import Union, Tuple |
| | from torch.utils.data import Subset, Dataset |
| |
|
| | |
| | from rich.console import Console |
| | from rich.tree import Tree |
| | from rich.panel import Panel |
| | from rich.text import Text |
| | from rich.syntax import Syntax |
| |
|
| | console = Console() |
| |
|
def print_transform_summary(
    name: str, geometric_sync: v2.Compose, haze_only: v2.Compose, common: v2.Compose
):
    """
    Render a Rich tree summarizing the three stages of the augmentation pipeline.

    Args:
        name: Heading for the tree (dataset / split / size label).
        geometric_sync: Transforms applied identically to both images.
        haze_only: Transforms applied to the hazy image only.
        common: Final tensor-conversion / normalization transforms.
    """
    summary_tree = Tree(f"[bold cyan]{name}[/]")

    # (branch title, explanatory note, transform pipeline) for each stage.
    stages = [
        (
            "[bold magenta]1. Geometric (Synchronous)[/]",
            "[dim]Applies identically to CLEAR & HAZY for alignment[/]",
            geometric_sync,
        ),
        (
            "[bold yellow]2. Appearance (Hazy-Only)[/]",
            "[dim]Simulates real-world haze variations[/]",
            haze_only,
        ),
        (
            "[bold green]3. Common (Tensor & Norm)[/]",
            "[dim]Final prep: ToTensor, Normalize[/]",
            common,
        ),
    ]
    for title, note, pipeline in stages:
        branch = summary_tree.add(title)
        branch.add(note)
        branch.add(str(pipeline))

    console.print(
        Panel(summary_tree, title="[bold]Augmentation Pipeline[/]", expand=False, border_style="blue")
    )
| |
|
| |
|
def restandardize_tensor(
    tensor: torch.Tensor,
    mean: Union[torch.Tensor, Tuple[float, float, float]] = (0.5, 0.5, 0.5),
    std: Union[torch.Tensor, Tuple[float, float, float]] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """
    Reverse a z-score normalization: (tensor * std) + mean, clamped to [0, 1].

    Args:
        tensor: Normalized image tensor, either (C, H, W) or (B, C, H, W).
        mean: Per-channel means used during normalization.
        std: Per-channel standard deviations used during normalization.

    Returns:
        De-normalized tensor with the same shape as the input, values in [0, 1].
    """
    # Tuple defaults instead of mutable list defaults (also matches the
    # declared Tuple annotation).
    if not isinstance(mean, torch.Tensor):
        mean = torch.tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    if not isinstance(std, torch.Tensor):
        std = torch.tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)

    # Add a leading batch axis to the stats so they broadcast over (B, C, H, W).
    if tensor.dim() == 4:
        mean = mean.unsqueeze(0)
        std = std.unsqueeze(0)

    de_normalized_tensor = (tensor * std) + mean
    # Clamp: rounding in normalize/de-normalize can push values slightly
    # outside the displayable [0, 1] range.
    return torch.clamp(de_normalized_tensor, 0.0, 1.0)
| |
|
| |
|
def get_haze_transforms(
    dataset_name: str,
    resize_size: int = 640,
    split: str = "train",
    verbose: bool = False,
):
    """
    Build a paired (clear, hazy) transform callable using torchvision v2.

    Training: random-crops to `resize_size` (synchronously on both images)
    and applies hazy-only photometric augmentation.
    Sanity: deterministic, no augmentation — tensor conversion only.
    Validation/test: DOES NOT RESIZE (keeps original resolution).

    Args:
        dataset_name: One of OHAZE, DENSEHAZE, NHHAZE, RESIDE-INDOOR,
            HAZE4K, RESIDE-OUTDOOR (only checked when split == "train").
        resize_size: Side length of the training random crop.
        split: "train", "sanity", or anything else for val/test.
        verbose: Print a Rich summary of the chosen pipeline.

    Returns:
        A callable (clear_img, hazy_img) -> (clean_tensor, hazy_tensor).

    Raises:
        ValueError: If `split == "train"` and `dataset_name` is unknown.
    """
    # Final stage is identical for every split, so build it once instead of
    # duplicating train/eval pipelines: to image -> float32 in [0, 1] -> z-score.
    common = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    if split == "sanity":
        def sanity_transform(clear_img, hazy_img):
            # No geometric or photometric augmentation — conversion only.
            return common(clear_img), common(hazy_img)

        if verbose:
            # Message corrected: this mode performs no resize, only
            # deterministic tensor conversion.
            print("Transform Mode: SANITY (Deterministic, No Augmentation)")
        return sanity_transform

    if split == "train":
        # Applied to BOTH images with shared randomness so the pair stays aligned.
        geometric_sync_transforms = v2.Compose([
            v2.RandomCrop(resize_size, pad_if_needed=True),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
        ])

        # Per-dataset photometric augmentation, applied to the hazy image only.
        # Jitter strength grows with the (synthetic) dataset's variability.
        if dataset_name == "OHAZE":
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.0),
            ])
        elif dataset_name == "DENSEHAZE":
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.0),
                v2.RandomGrayscale(p=0.2),
            ])
        elif dataset_name == "NHHAZE":
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(
                    brightness=0.1,
                    contrast=0.1,
                    saturation=0.1,
                    hue=0.0,
                ),
                v2.RandomGrayscale(p=0.15),
            ])
        elif dataset_name in ["RESIDE-INDOOR", "HAZE4K"]:
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(
                    brightness=0.15,
                    contrast=0.15,
                    saturation=0.15,
                    hue=0.01,
                ),
            ])
        elif dataset_name == "RESIDE-OUTDOOR":
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.05,
                ),
                v2.RandomGrayscale(p=0.1),
            ])
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        def haze_transform(clear_img, hazy_img):
            # Single call so both images share the same random crop/flip params.
            clean_img, hazy_img = geometric_sync_transforms(clear_img, hazy_img)
            hazy_img = haze_only_transforms(hazy_img)
            return common(clean_img), common(hazy_img)

        if verbose:
            print_transform_summary(
                f"{dataset_name} | TRAIN | {resize_size}x{resize_size}",
                geometric_sync_transforms,
                haze_only_transforms,
                common,
            )
        return haze_transform

    # Val/test: keep original resolution, tensor conversion only.
    def val_transform(clear_img, hazy_img):
        return common(clear_img), common(hazy_img)

    if verbose:
        print_transform_summary(
            f"{dataset_name} | VAL/TEST | Original Size",
            v2.Identity(),
            v2.Identity(),
            common,
        )
    return val_transform
| |
|
| |
|
def partition_dataset(dataset, train_transform, val_transform, train_ratio=0.8, generator=None):
    """
    Randomly split `dataset` into train/val subsets with independent transforms.

    Args:
        dataset: A dataset with `__len__` and a settable `.transform` attribute.
        train_transform: Transform assigned to the training subset's dataset copy.
        val_transform: Transform assigned to the validation subset's dataset copy.
        train_ratio: Fraction of samples assigned to training (0.0–1.0).
        generator: Optional `torch.Generator` for a reproducible split
            (default `None` preserves the previous global-RNG behavior).

    Returns:
        Tuple of (train_subset, val_subset) as `torch.utils.data.Subset`.

    Raises:
        ValueError: If `train_ratio` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    indices = torch.randperm(len(dataset), generator=generator).tolist()
    num_train = int(len(dataset) * train_ratio)

    # Deep-copy so each subset carries its own `.transform` without mutating
    # the caller's dataset or the sibling split.
    train_subset = Subset(copy.deepcopy(dataset), indices[:num_train])
    val_subset = Subset(copy.deepcopy(dataset), indices[num_train:])

    train_subset.dataset.transform = train_transform
    val_subset.dataset.transform = val_transform

    return train_subset, val_subset
| |
|
| |
|
def plotting_pair_images(dataset, num_instances=3, start_index=0, save_figure=False):
    """
    Plot clean/hazy image pairs from `dataset` side by side.

    Args:
        dataset: Indexable dataset yielding (clean, hazy) normalized tensors.
        num_instances: Number of consecutive pairs (rows) to display.
        start_index: Dataset index of the first pair.
        save_figure: If True, also save the figure to disk at 300 dpi.
    """
    N_COLS = 2
    N_ROWS = num_instances
    end_index = start_index + num_instances

    # squeeze=False keeps `axes` 2-D even when num_instances == 1; without it
    # plt.subplots squeezes a (1, 2) grid to 1-D and axes[row][col] raises.
    fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(8, 5 * N_ROWS), squeeze=False)
    fig.suptitle("GT vs Haze Comparison", fontsize=16)

    for row_index, i in enumerate(range(start_index, end_index)):
        clean, hazy = dataset[i]
        # Undo normalization so imshow receives values in [0, 1].
        clean = restandardize_tensor(clean)
        hazy = restandardize_tensor(hazy)

        # permute (C, H, W) -> (H, W, C) for matplotlib.
        axes[row_index][0].imshow(clean.permute(1, 2, 0))
        axes[row_index][0].set_title(f"Clean {i}")
        axes[row_index][0].axis("off")

        axes[row_index][1].imshow(hazy.permute(1, 2, 0))
        axes[row_index][1].set_title(f"Hazy {i}")
        axes[row_index][1].axis("off")

    # Leave headroom for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save_figure:
        path = "hazy_clear_comparison.png"
        console.print(f"[bold green]Saving visualization to: {path}[/]")
        plt.savefig(path, dpi=300, bbox_inches="tight")

    plt.show()