# FM_PhysMamba_UNET / data/ohaze.py
# root
# Clean upload with correct folder structure
# ea234dc
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from pathlib import Path
class OHAZE_Dataset(Dataset):
    """Paired clean/hazy image dataset laid out in the O-HAZE directory style.

    Expects ``root_dir/GT/*.jpg`` (clean ground truth) and
    ``root_dir/hazy/*.jpg`` (hazy inputs). Files in the two folders are paired
    purely by sorted filename order: the i-th GT file is matched with the i-th
    hazy file.

    Args:
        root_dir: Directory containing the ``GT`` and ``hazy`` sub-folders.
        transform: Optional callable invoked as
            ``transform(clean_img, hazy_img)`` on the two PIL RGB images. It
            must return a ``(clean, hazy)`` pair. When omitted, both images are
            converted to float tensors of shape ``(3, H, W)`` scaled to
            ``[0, 1]``.

    Raises:
        ValueError: If the ``GT`` and ``hazy`` folders contain different
            numbers of ``*.jpg`` files, because pairing by sorted order would
            then be misaligned.
    """

    def __init__(self, root_dir, transform=None):
        self.dataset_dir = Path(root_dir)
        self.clear_img_paths = self.dataset_dir / "GT"
        self.hazy_img_paths = self.dataset_dir / "hazy"
        # sorted() accepts the glob generator directly; no list() needed.
        self.clear_paths = sorted(self.clear_img_paths.glob("*.jpg"))
        self.hazy_paths = sorted(self.hazy_img_paths.glob("*.jpg"))

        # Pairing is positional, so unequal counts would either crash with an
        # opaque IndexError or silently drop / misalign samples. Fail early.
        if len(self.clear_paths) != len(self.hazy_paths):
            raise ValueError(
                f"O-HAZE: found {len(self.clear_paths)} GT images in "
                f"{self.clear_img_paths} but {len(self.hazy_paths)} hazy images "
                f"in {self.hazy_img_paths}; pairs are matched by sorted "
                f"filename order, so the counts must be equal"
            )

        # NOTE(review): assumes GT and hazy filenames sort in the same order
        # (e.g. a shared numeric prefix) -- confirm against the dataset files.
        self.data = [
            {"index": i, "clean": clean, "hazy": hazy}
            for i, (clean, hazy) in enumerate(
                zip(self.clear_paths, self.hazy_paths)
            )
        ]
        self.transform = transform

    def __repr__(self):
        return "O-HAZE"

    def __len__(self):
        """Return the number of clean/hazy pairs."""
        return len(self.data)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path``, return an RGB copy, and close the file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the source file is closed by the ``with`` block.
            return img.convert("RGB")

    @staticmethod
    def _to_tensor(img):
        """Convert an (H, W, 3) uint8 image to a (3, H, W) float tensor in [0, 1]."""
        return torch.as_tensor(np.array(img)).permute(2, 0, 1).float() / 255.0

    def __getitem__(self, idx):
        """Return the ``(clean_img, hazy_img)`` pair at ``idx``.

        If either file of the pair is missing, a message is printed and the
        next index (wrapping around) is tried instead. Each pair is tried at
        most once.

        Raises:
            IndexError: If ``idx`` is out of range. This keeps the sequence
                iteration protocol working.
            FileNotFoundError: If no pair in the dataset can be loaded.
        """
        current = idx
        # Bounded retry instead of recursion: the old recursive fallback hit
        # RecursionError when every file was missing. ``or 1`` guarantees one
        # attempt on an empty dataset, so ``self.data[current]`` still raises
        # IndexError exactly as before.
        for _ in range(len(self) or 1):
            data_item = self.data[current]
            clean_path = data_item["clean"]
            hazy_path = data_item["hazy"]
            try:
                clean_img = self._load_rgb(clean_path)
                hazy_img = self._load_rgb(hazy_path)
            except FileNotFoundError:
                print(f"Error: Missing image file at {clean_path} or {hazy_path}. Skipping")
                current = (current + 1) % len(self)
            else:
                break
        else:
            raise FileNotFoundError(
                f"No loadable clean/hazy image pair found under {self.dataset_dir}"
            )

        if self.transform:
            clean_img, hazy_img = self.transform(clean_img, hazy_img)
        else:
            clean_img = self._to_tensor(clean_img)
            hazy_img = self._to_tensor(hazy_img)
        return clean_img, hazy_img