# File size: 9,198 Bytes
# ea234dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import copy
import matplotlib.pyplot as plt
from torchvision.transforms import v2
from typing import Union, Tuple
from torch.utils.data import Subset, Dataset

# --- Rich Imports ---
from rich.console import Console
from rich.tree import Tree
from rich.panel import Panel
from rich.text import Text
from rich.syntax import Syntax

console = Console()

def print_transform_summary(
    name: str, geometric_sync: v2.Compose, haze_only: v2.Compose, common: v2.Compose
):
    """
    Render a structured overview of the augmentation pipeline with Rich.

    The three pipeline components (synchronized geometric, hazy-only
    appearance, and common tensor/normalization transforms) are shown as
    branches of a tree, wrapped in a titled panel on the shared console.
    """
    # (title, explanatory note, transform) for each pipeline stage, in order.
    stages = (
        (
            "[bold magenta]1. Geometric (Synchronous)[/]",
            "[dim]Applies identically to CLEAR & HAZY for alignment[/]",
            geometric_sync,
        ),
        (
            "[bold yellow]2. Appearance (Hazy-Only)[/]",
            "[dim]Simulates real-world haze variations[/]",
            haze_only,
        ),
        (
            "[bold green]3. Common (Tensor & Norm)[/]",
            "[dim]Final prep: ToTensor, Normalize[/]",
            common,
        ),
    )

    tree = Tree(f"[bold cyan]{name}[/]")
    for title, note, transform in stages:
        branch = tree.add(title)
        branch.add(note)
        branch.add(str(transform))

    console.print(
        Panel(tree, title="[bold]Augmentation Pipeline[/]", expand=False, border_style="blue")
    )


def restandardize_tensor(
    tensor: torch.Tensor,
    mean: Union[torch.Tensor, Tuple[float, float, float]] = (0.5, 0.5, 0.5),
    std: Union[torch.Tensor, Tuple[float, float, float]] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """
    Reverse a z-score normalization: ``tensor * std + mean``, clamped to [0, 1].

    Args:
        tensor: Normalized image tensor, either (C, H, W) or batched (N, C, H, W).
        mean: Per-channel means used during normalization. Sequences are
            reshaped to (C, 1, 1) for broadcasting; a tensor argument is used
            as-is — NOTE(review): callers passing a tensor must supply a
            broadcastable shape themselves.
        std: Per-channel stds, handled identically to ``mean``.

    Returns:
        De-normalized tensor of the same shape, clamped to [0.0, 1.0].
    """
    # Defaults are immutable tuples (not lists) to match the declared Tuple
    # annotation and avoid the mutable-default-argument pitfall.
    if not isinstance(mean, torch.Tensor):
        mean = torch.tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    if not isinstance(std, torch.Tensor):
        std = torch.tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)

    # Batched input: add a leading dim so (C,1,1) stats broadcast over (N,C,H,W).
    if tensor.dim() == 4:
        mean = mean.unsqueeze(0)
        std = std.unsqueeze(0)

    de_normalized_tensor = (tensor * std) + mean
    return torch.clamp(de_normalized_tensor, 0.0, 1.0)


def get_haze_transforms(
    dataset_name: str,
    resize_size: int = 640,
    split: str = "train",
    verbose: bool = False,
):
    """
    Build a paired (clear, hazy) transform callable using torchvision v2.

    Args:
        dataset_name: One of "OHAZE", "DENSEHAZE", "NHHAZE", "RESIDE-INDOOR",
            "HAZE4K", "RESIDE-OUTDOOR". Only validated when split == "train".
        resize_size: Square size of the synchronized RandomCrop used in training.
        split: "train", "sanity", or anything else (treated as val/test).
        verbose: If True, print a summary of the chosen pipeline.

    Returns:
        A callable ``(clear_img, hazy_img) -> (clear_tensor, hazy_tensor)``.

    Behavior per split:
        train  : synchronized RandomCrop/flips + hazy-only appearance jitter
                 + tensor conversion/normalization.
        sanity : strictly deterministic — tensor conversion/normalization ONLY
                 (no crop, no flips, no jitter, and no resize).
        other  : val/test — keeps the ORIGINAL resolution; tensor
                 conversion/normalization only.

    Raises:
        ValueError: If split == "train" and dataset_name is unknown.
    """

    # --- 1. Common block: ToImage -> float32 in [0,1] -> Normalize to [-1,1] ---
    # The train and eval pipelines previously duplicated this Compose verbatim;
    # neither performs any resizing here, so one shared instance suffices.
    common_transforms = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    # --- 2. Sanity-check logic ---
    if split == "sanity":
        # STRICTLY DETERMINISTIC: no RandomCrop, no flips, no jitter —
        # only tensor conversion and normalization.
        def sanity_transform(clear_img, hazy_img):
            clean_img = common_transforms(clear_img)
            hazy_img = common_transforms(hazy_img)
            return clean_img, hazy_img

        if verbose:
            print("Transform Mode: SANITY (Deterministic, no augmentation)")
        return sanity_transform

    # --- 3. Training logic ---
    if split == "train":
        # Geometric ops must hit BOTH images identically to keep pairs aligned.
        geometric_sync_transforms = v2.Compose([
            v2.RandomCrop(resize_size, pad_if_needed=True),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
        ])

        # Dataset-specific appearance augmentation, applied to the HAZY image
        # only. Intensities are kept low on purpose: the goal is domain
        # randomization (robustness), not data distortion.
        if dataset_name == "OHAZE":
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.0),
            ])
        elif dataset_name == "DENSEHAZE":
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.0),
                v2.RandomGrayscale(p=0.2),
            ])
        elif dataset_name == "NHHAZE":
            # NH-HAZE is non-homogeneous and very small (~55 pairs): augment
            # slightly more aggressively to prevent overfitting, while
            # preserving the 'patchy' haze structure (hue stays 0).
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(
                    brightness=0.1,  # Moderate brightness changes
                    contrast=0.1,    # Moderate contrast
                    saturation=0.1,  # Haze affects saturation significantly
                    hue=0.0          # Keep hue 0 to preserve realistic outdoor colors
                ),
                # Grayscale pushes the model toward structure/texture instead
                # of memorizing the color cast of the few training images.
                v2.RandomGrayscale(p=0.15),
            ])
        elif dataset_name in ["RESIDE-INDOOR", "HAZE4K"]:
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(
                    brightness=0.15,
                    contrast=0.15,
                    saturation=0.15,
                    hue=0.01
                )
            ])
        elif dataset_name == "RESIDE-OUTDOOR":
            # Stronger saturation/hue jitter forces generalization to
            # different weather conditions / times of day.
            haze_only_transforms = v2.Compose([
                v2.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,  # Stronger saturation jitter
                    hue=0.05         # Slight color shifting (simulates time-of-day)
                ),
                # Occasional grayscale forces reliance on structure, not color.
                v2.RandomGrayscale(p=0.1),
            ])
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        def haze_transform(clear_img, hazy_img):
            # Order matters: geometric (sync, keeps pair aligned), then
            # appearance (hazy only), then common (tensor + normalize).
            clean_img, hazy_img = geometric_sync_transforms(clear_img, hazy_img)
            hazy_img = haze_only_transforms(hazy_img)
            clean_img = common_transforms(clean_img)
            hazy_img = common_transforms(hazy_img)
            return clean_img, hazy_img

        if verbose:
            print_transform_summary(
                f"{dataset_name} | TRAIN | {resize_size}x{resize_size}",
                geometric_sync_transforms,
                haze_only_transforms,
                common_transforms,
            )
        return haze_transform

    # --- 4. Validation / test logic ---
    else:
        def val_transform(clear_img, hazy_img):
            # Only tensor conversion + normalization — NO resizing, so
            # evaluation runs at the original resolution.
            clean_img = common_transforms(clear_img)
            hazy_img = common_transforms(hazy_img)
            return clean_img, hazy_img

        if verbose:
            print_transform_summary(
                f"{dataset_name} | VAL/TEST | Original Size",
                v2.Identity(),
                v2.Identity(),
                common_transforms,
            )
        return val_transform


def partition_dataset(dataset, train_transform, val_transform, train_ratio=0.8, seed=None):
    """
    Randomly split ``dataset`` into train/val Subsets with independent transforms.

    The dataset is deep-copied once per split on purpose: each Subset then
    wraps its OWN copy, so assigning a transform to one copy can never leak
    augmentation into the other split.

    Args:
        dataset: Map-style dataset exposing ``__len__`` and a writable
            ``transform`` attribute.
        train_transform: Transform injected into the training copy.
        val_transform: Transform injected into the validation copy.
        train_ratio: Fraction of samples assigned to the training split.
        seed: Optional int for a reproducible permutation. ``None`` (default)
            preserves the previous behavior of using torch's global RNG.

    Returns:
        ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` instances.
    """
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=generator).tolist()
    num_train = int(len(dataset) * train_ratio)

    train_subset = Subset(copy.deepcopy(dataset), indices[:num_train])
    val_subset = Subset(copy.deepcopy(dataset), indices[num_train:])

    # Inject transforms into the per-split copies.
    train_subset.dataset.transform = train_transform
    val_subset.dataset.transform = val_transform

    return train_subset, val_subset


def plotting_pair_images(dataset, num_instances=3, start_index=0, save_figure=False):
    """
    Plot ``num_instances`` (clean, hazy) pairs from ``dataset`` side by side.

    Assumes ``dataset[i]`` returns a pair of normalized CHW tensors
    (mean/std 0.5) — they are de-normalized via ``restandardize_tensor``
    before display. TODO confirm against the dataset's __getitem__.

    Args:
        dataset: Indexable dataset yielding (clean_tensor, hazy_tensor) pairs.
        num_instances: Number of consecutive pairs to plot (rows).
        start_index: First dataset index to plot.
        save_figure: If True, also save the figure as a 300-dpi PNG.
    """
    N_COLS = 2
    N_ROWS = num_instances

    # squeeze=False keeps `axes` 2-D even when num_instances == 1; without it
    # plt.subplots returns a 1-D array and axes[row][col] raises (bug fix).
    fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(8, 5 * N_ROWS), squeeze=False)
    fig.suptitle("GT vs Haze Comparison", fontsize=16)

    for row_index, i in enumerate(range(start_index, start_index + num_instances)):
        clean, hazy = dataset[i]
        clean = restandardize_tensor(clean)
        hazy = restandardize_tensor(hazy)

        # imshow expects HWC, tensors are CHW — permute for display.
        axes[row_index][0].imshow(clean.permute(1, 2, 0))
        axes[row_index][0].set_title(f"Clean {i}")
        axes[row_index][0].axis("off")

        axes[row_index][1].imshow(hazy.permute(1, 2, 0))
        axes[row_index][1].set_title(f"Hazy {i}")
        axes[row_index][1].axis("off")

    # Leave room for the suptitle above the subplot grid.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save_figure:
        path = "hazy_clear_comparison.png"
        console.print(f"[bold green]Saving visualization to: {path}[/]")
        plt.savefig(path, dpi=300, bbox_inches="tight")

    plt.show()