DamageLensAI / src /data /dataset.py
junaid17's picture
Upload 43 files
eef8873 verified
Raw
History Blame Contribute Delete
4.58 kB
import logging
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import ConvNextImageProcessor
from src.config import (
BATCH_SIZE,
NUM_WORKERS
)
from src.data.ingestion import collect_image_paths
from src.data.preprocessing import split_dataset
from src.data.augmentation import (
get_resnet_train_transforms,
get_resnet_val_transforms,
get_fusion_train_transforms,
get_fusion_val_transforms
)
logger = logging.getLogger(__name__)
class ResNetDataset(Dataset):
    """Dataset that yields ``(image, label)`` pairs for the ResNet model.

    Args:
        samples: Indexable collection of ``(image_path, label)`` pairs.
        transforms: Optional callable applied to the RGB PIL image. When it
            is omitted (or falsy) the PIL image is returned unchanged.
    """

    def __init__(self, samples, transforms=None):
        self.samples = samples
        self.transforms = transforms

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(image, label)``."""
        image_path, label = self.samples[idx]

        # Image.open() is lazy and keeps the underlying file open. convert()
        # returns an independent copy, so the handle can be released as soon
        # as the copy exists instead of waiting for garbage collection.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        if self.transforms:
            image = self.transforms(image)
        return image, label
class FusionDataset(Dataset):
    """Dataset that prepares one image for both branches of the fusion model.

    Each item is a dict with the keys ``"pixel_values_eff"`` (output of
    ``transforms``), ``"pixel_values_cnx"`` (the ConvNeXt processor's
    ``pixel_values`` with the leading dimension squeezed away) and
    ``"labels"``.

    Args:
        samples: Indexable collection of ``(image_path, label)`` pairs.
        transforms: Callable applied to the RGB PIL image. It is required
            when items are fetched; ``__getitem__`` raises without it.
        convnext_model_name: Name passed to
            ``ConvNextImageProcessor.from_pretrained`` when no ``processor``
            is supplied.
        processor: Optional ready-made image processor. Passing one lets
            several datasets share a single instance instead of each
            loading its own.
    """

    def __init__(
        self,
        samples,
        transforms=None,
        convnext_model_name="facebook/convnext-small-224",
        processor=None,
    ):
        self.samples = samples
        self.transforms = transforms
        if processor is None:
            logger.info("Loading ConvNeXt processor...")
            processor = ConvNextImageProcessor.from_pretrained(
                convnext_model_name
            )
        self.processor = processor

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the two model inputs and the label for sample ``idx``.

        Raises:
            ValueError: If the dataset was built without ``transforms``.
        """
        # Fail fast: check the configuration before touching the disk.
        if not self.transforms:
            raise ValueError("Fusion transforms are required.")

        image_path, label = self.samples[idx]

        # Image.open() is lazy and keeps the underlying file open. convert()
        # returns an independent copy, so the handle can be released here.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        eff_tensor = self.transforms(image)

        # NOTE(review): the processor receives the original RGB image, not
        # the output of `transforms`, so whatever `transforms` does to the
        # image is not seen by this branch - confirm this is intended.
        convnext_inputs = self.processor(
            images=image,
            return_tensors="pt"
        )
        # The processor output is batched; drop that leading dimension.
        convnext_tensor = convnext_inputs["pixel_values"].squeeze(0)

        return {
            "pixel_values_eff": eff_tensor,
            "pixel_values_cnx": convnext_tensor,
            "labels": label
        }
def create_resnet_dataloaders():
    """Build the training and validation loaders for the ResNet model.

    Collects the image paths, splits them into train and validation sets,
    and wraps each set in a ``ResNetDataset`` with its own transforms.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader
        shuffles its data.
    """
    logger.info("Creating ResNet dataloaders...")

    train_split, val_split = split_dataset(collect_image_paths())

    train_set = ResNetDataset(
        train_split, transforms=get_resnet_train_transforms()
    )
    val_set = ResNetDataset(
        val_split, transforms=get_resnet_val_transforms()
    )

    # Settings common to both loaders; they differ only in shuffling.
    shared_options = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
    loaders = (
        DataLoader(train_set, shuffle=True, **shared_options),
        DataLoader(val_set, shuffle=False, **shared_options),
    )

    logger.info("ResNet dataloaders created successfully.")
    return loaders
def create_fusion_dataloaders():
    """Build the training and validation loaders for the fusion model.

    Collects the image paths, splits them into train and validation sets,
    and wraps each set in a ``FusionDataset`` with its own transforms.

    Returns:
        tuple: ``(train_loader, val_loader)``. Only the training loader
        shuffles its data.
    """
    logger.info("Creating Fusion dataloaders...")

    train_split, val_split = split_dataset(collect_image_paths())

    train_set = FusionDataset(
        train_split, transforms=get_fusion_train_transforms()
    )
    val_set = FusionDataset(
        val_split, transforms=get_fusion_val_transforms()
    )

    # Settings common to both loaders; they differ only in shuffling.
    shared_options = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
    loaders = (
        DataLoader(train_set, shuffle=True, **shared_options),
        DataLoader(val_set, shuffle=False, **shared_options),
    )

    logger.info("Fusion dataloaders created successfully.")
    return loaders
if __name__ == "__main__":
    # Manual smoke test: fetch one training batch from each pipeline and
    # print the shapes of what comes back.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    print("\nTesting ResNet dataloaders...\n")
    train_loader, val_loader = create_resnet_dataloaders()
    images, labels = next(iter(train_loader))
    for caption, tensor in (
        ("ResNet batch shape:", images),
        ("ResNet labels shape:", labels),
    ):
        print(caption, tensor.shape)

    print("\nTesting Fusion dataloaders...\n")
    train_loader, val_loader = create_fusion_dataloaders()
    batch = next(iter(train_loader))
    for caption, key in (
        ("Fusion EfficientNet batch shape:", "pixel_values_eff"),
        ("Fusion ConvNeXt batch shape:", "pixel_values_cnx"),
        ("Fusion labels shape:", "labels"),
    ):
        print(caption, batch[key].shape)