FM_PhysMamba_UNET / data /nhhaze.py
root
Clean upload with correct folder structure
ea234dc
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from pathlib import Path
import os
class NH_Haze_Dataset(Dataset):
    """Paired clear/hazy images read from ``root_dir/<split>/{clear,hazy}/*.png``.

    Each clear image ``<id>_<anything>.png`` is paired with the hazy image
    ``<id>_hazy.png``, where ``<id>`` is the part of the clear file's stem
    before the first underscore. Clear images without a matching hazy file are
    skipped with a printed warning.

    Args:
        root_dir: Dataset root directory.
        transform: Optional callable invoked as ``transform(clean_img, hazy_img)``
            on the two PIL RGB images; it must return a ``(clean, hazy)`` pair.
            When ``None``, both images are converted to float tensors of shape
            ``(3, H, W)`` scaled to ``[0, 1]``.
        split: Sub-directory of ``root_dir`` to read (default ``"train"``).
    """

    def __init__(self, root_dir, transform=None, split="train"):
        self.dataset_dir = Path(root_dir) / split
        self.clear_img_paths = self.dataset_dir / "clear"
        self.hazy_img_paths = self.dataset_dir / "hazy"
        # Sorted so sample order does not depend on filesystem glob order.
        self.clear_paths = sorted(self.clear_img_paths.glob("*.png"))
        # Set of hazy *file names* (not paths) for O(1) membership checks below.
        self.hazy_paths = {p.name for p in self.hazy_img_paths.glob("*.png")}
        self.data = []
        for i, clear_path in enumerate(self.clear_paths):
            id_image = clear_path.stem.split("_")[0]
            hazy_filename = f"{id_image}_hazy.png"
            if hazy_filename not in self.hazy_paths:
                # Report the id: no hazy path exists for it, and the previous
                # code referenced an unassigned/stale variable here.
                print(f"Warning: The id {id_image} does not have the hazy image")
                continue
            self.data.append(
                {
                    # Position in ``self.clear_paths`` (not in ``self.data``,
                    # which may be shorter because of skipped ids).
                    "index": i,
                    "clean": clear_path,
                    "hazy": self.hazy_img_paths / hazy_filename,
                }
            )
        self.transform = transform

    def __repr__(self):
        return "NH-HAZE"

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return it as a fully loaded RGB PIL image."""
        # The context manager closes the file handle; ``convert`` returns a new
        # image whose pixel data is already loaded, so it outlives the close.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _to_tensor(img):
        """Convert an RGB PIL image to a float ``(3, H, W)`` tensor in ``[0, 1]``."""
        # np.array of an RGB image is (H, W, 3) uint8; move channels first.
        return torch.as_tensor(np.array(img)).permute(2, 0, 1).float() / 255.0

    def __getitem__(self, idx):
        """Return the ``(clean_img, hazy_img)`` pair at ``idx``.

        If either file of a pair cannot be found, an error is printed and the
        next pair (wrapping around to the start) is tried instead.

        Raises:
            IndexError: if ``idx`` is out of range (keeps the sequence protocol).
            FileNotFoundError: if no pair in the dataset can be loaded.
        """
        # Indexing first preserves IndexError for out-of-range / empty dataset.
        data_item = self.data[idx]
        num_items = len(self.data)
        # Bounded retry: visit each pair at most once instead of recursing
        # forever when every file is missing.
        for offset in range(num_items):
            if offset:
                data_item = self.data[(idx + offset) % num_items]
            clean_path = data_item["clean"]
            hazy_path = data_item["hazy"]
            try:
                clean_img = self._load_rgb(clean_path)
                hazy_img = self._load_rgb(hazy_path)
            except FileNotFoundError:
                print(f"Error cannot load the file at {clean_path} and {hazy_path}")
                continue
            if self.transform is not None:
                clean_img, hazy_img = self.transform(clean_img, hazy_img)
            else:
                clean_img = self._to_tensor(clean_img)
                hazy_img = self._to_tensor(hazy_img)
            return clean_img, hazy_img
        raise FileNotFoundError(
            f"Could not load any clear/hazy image pair from {self.dataset_dir}"
        )