# disastersense / preprocess.py
# AsmitaG11's picture
# Upload 7 files
# fcc9242 verified
# Raw
# History Blame Contribute Delete
# 4.27 kB
"""
DisasterSense | Image Preprocessing
Transforms, dataset class, class weights, and dataloaders
for EfficientNet-B0 fine-tuning on CrisisMMD damage severity.
"""
import os
import pandas as pd
from pathlib import Path
from PIL import Image
from collections import Counter
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Filesystem layout: root of the raw CrisisMMD images and the directory
# holding the processed split CSVs.
IMAGE_BASE = Path("data") / "raw" / "CrisisMMD_v2.0"
PROCESSED = Path("data") / "processed"

# Damage-severity label names mapped to integer class indices.
LABEL_MAP = dict(little_or_no_damage=0, mild_damage=1, severe_damage=2)
NUM_CLASSES = len(LABEL_MAP)

# Per-channel mean / std handed to transforms.Normalize by both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def _tensor_and_normalize():
    """Return the steps shared by both pipelines: ToTensor, then Normalize."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


# Training pipeline: resize to 240x240, take a random 224 crop, apply random
# flip / colour jitter / rotation, then convert to a normalised tensor.
train_transforms = transforms.Compose(
    [
        transforms.Resize((240, 240)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        transforms.RandomRotation(degrees=10),
        *_tensor_and_normalize(),
    ]
)

# Evaluation pipeline: resize straight to 224x224 with no random steps.
eval_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        *_tensor_and_normalize(),
    ]
)
class CrisisDataset(Dataset):
    """Image/label dataset backed by a CSV with ``image`` and ``label`` columns.

    Rows whose image file does not exist under ``image_base`` are removed
    when the dataset is constructed.

    Args:
        csv_path: CSV file readable by ``pandas.read_csv``.
        image_base: Directory that the ``image`` column is relative to.
            A ``str`` or any path-like value is accepted.
        transform: Optional callable applied to the RGB ``PIL.Image``
            before it is returned.
    """

    def __init__(self, csv_path, image_base, transform=None):
        self.df = pd.read_csv(csv_path)
        # Coerce so callers may pass a plain string; the `/` joins below
        # need a Path on the left-hand side.
        self.image_base = Path(image_base)
        self.transform = transform
        self._drop_missing()

    def _drop_missing(self):
        """Remove rows whose image file is absent and report how many went."""
        base = self.image_base
        # astype(bool) keeps the mask boolean even for a CSV with zero rows,
        # where apply() would otherwise hand back an object-dtype Series.
        valid = (
            self.df["image"]
            .apply(lambda rel: (base / rel).exists())
            .astype(bool)
        )
        dropped = int((~valid).sum())
        if dropped:
            print(f"Dropped {dropped} rows with missing images.")
        self.df = self.df[valid].reset_index(drop=True)

    def __len__(self):
        """Return the number of rows that survived the missing-file filter."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was given; the label is the integer looked up in ``LABEL_MAP``.

        Raises:
            KeyError: If the row's label is not a key of ``LABEL_MAP``.
        """
        row = self.df.iloc[idx]
        # Image.open is lazy; the context manager releases the file handle
        # once convert() has produced an independent in-memory copy.
        with Image.open(self.image_base / row["image"]) as handle:
            img = handle.convert("RGB")
        label = LABEL_MAP[row["label"]]
        if self.transform is not None:
            img = self.transform(img)
        return img, label
def compute_class_weights(csv_path):
    """Compute inverse-frequency class weights from a split CSV.

    Each class receives ``total / (NUM_CLASSES * count)``, where ``total`` is
    the number of rows in the CSV and ``count`` the rows carrying that label.
    One summary line per class is printed.

    Args:
        csv_path: CSV file with a ``label`` column holding ``LABEL_MAP`` keys.

    Returns:
        A float tensor with one weight per ``LABEL_MAP`` entry, ordered by
        ascending class index. A class with no rows in the CSV gets 0.0.
    """
    df = pd.read_csv(csv_path)
    counts = Counter(df["label"])
    total = sum(counts.values())

    weights = []
    # Order by class index, not by label name, so the tensor position keeps
    # matching the index even if labels are renamed or added later.
    for label in sorted(LABEL_MAP, key=LABEL_MAP.get):
        n_samples = counts[label]
        # Counter yields 0 for an absent class; avoid dividing by zero.
        w = total / (NUM_CLASSES * n_samples) if n_samples else 0.0
        weights.append(w)
        print(f" {label:25s} → {n_samples:4d} samples | weight: {w:.3f}")
    return torch.tensor(weights, dtype=torch.float)
def build_dataloaders(batch_size=32, num_workers=0):
    """Build DataLoaders for the train, dev and test splits.

    Only the train loader shuffles and uses ``train_transforms``; dev and
    test use ``eval_transforms``. One summary line per split is printed.

    Args:
        batch_size: Number of samples per batch, used for every split.
        num_workers: Worker count passed to each ``DataLoader``. The default
            of 0 keeps the previous behaviour.

    Returns:
        Dict mapping ``"train"``, ``"dev"`` and ``"test"`` to a DataLoader.
    """
    splits = {
        "train": (PROCESSED / "damage_train.csv", train_transforms),
        "dev": (PROCESSED / "damage_dev.csv", eval_transforms),
        "test": (PROCESSED / "damage_test.csv", eval_transforms),
    }
    # Same answer for every split, so ask once.
    pin_memory = torch.cuda.is_available()

    loaders = {}
    for split, (csv_path, tfm) in splits.items():
        ds = CrisisDataset(csv_path, IMAGE_BASE, transform=tfm)
        loader = DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        loaders[split] = loader
        print(f"{split:6s} → {len(ds):,} samples | {len(loader)} batches")
    return loaders
def verify_batch(loaders):
    """Pull one training batch and sanity-check its shape and value range.

    Prints the batch shape and the min/max value, then a confirmation line
    when both checks pass.

    Args:
        loaders: Mapping with a ``"train"`` entry that iterates over
            ``(images, labels)`` batches.

    Raises:
        AssertionError: If the per-sample shape is not ``(3, 224, 224)`` or
            the values do not all lie within ``[-3.0, 3.0]``.
    """
    images, _ = next(iter(loaders["train"]))
    lo, hi = images.min(), images.max()
    print(f"Batch shape : {images.shape}")
    print(f"Pixel range : [{lo:.3f}, {hi:.3f}]")

    # Raise explicitly instead of using `assert` so the checks still run
    # when Python is started with -O.
    if tuple(images.shape[1:]) != (3, 224, 224):
        raise AssertionError(
            f"expected per-sample shape (3, 224, 224), got {tuple(images.shape[1:])}"
        )
    # Written as a negated conjunction so a NaN min/max also fails the check.
    if not (-3.0 <= lo and hi <= 3.0):
        raise AssertionError(
            f"pixel values outside [-3.0, 3.0]: [{lo:.3f}, {hi:.3f}]"
        )
    print("Sanity checks passed ✓")
if __name__ == "__main__":
    # Script entry point: compute class weights from the training split,
    # build the three dataloaders, sanity-check one training batch, then
    # save the weights tensor under models/.
    print("── Class Weights ─────────────────────────────────────")
    weights = compute_class_weights(PROCESSED / "damage_train.csv")
    print(f"\nWeights: {weights}")

    print("\n── DataLoaders ───────────────────────────────────────")
    loaders = build_dataloaders()

    print("\n── Verification ──────────────────────────────────────")
    verify_batch(loaders)

    # Create the output directory on first run, then persist the weights.
    os.makedirs("models", exist_ok=True)
    torch.save(weights, "models/class_weights.pt")
    print("Saved → models/class_weights.pt")