Spaces:
Sleeping
Sleeping
| """ | |
| DisasterSense | Image Preprocessing | |
| Transforms, dataset class, class weights, and dataloaders | |
| for EfficientNet-B0 fine-tuning on CrisisMMD damage severity. | |
| """ | |
| import os | |
| import pandas as pd | |
| from pathlib import Path | |
| from PIL import Image | |
| from collections import Counter | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
# Filesystem layout: root of the raw image tree and the directory holding
# the preprocessed split CSVs.
IMAGE_BASE = Path("data/raw/CrisisMMD_v2.0")
PROCESSED = Path("data/processed")

# Damage-severity label names mapped to integer class indices; indices are
# assigned in the order the names are listed (0, 1, 2).
LABEL_MAP = {
    name: index
    for index, name in enumerate(
        ("little_or_no_damage", "mild_damage", "severe_damage")
    )
}
NUM_CLASSES = len(LABEL_MAP)

# Per-channel mean / std used by the Normalize step of both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Training-time pipeline: resize a little larger than the final crop, apply
# random crop / flip / colour jitter / rotation, then convert to a tensor and
# normalise with the module-level mean and std.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(240, 240)),
        transforms.RandomCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
# Evaluation-time pipeline: deterministic resize straight to 224x224, tensor
# conversion and the same normalisation as training (no random augmentation).
eval_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
class CrisisDataset(Dataset):
    """Image/label dataset backed by a CSV with ``image`` and ``label`` columns.

    Rows whose image file does not exist under ``image_base`` are dropped at
    construction time, so every remaining index refers to a file that was
    present when the dataset was built.

    Args:
        csv_path: CSV file readable by ``pandas.read_csv``. It must provide
            an ``image`` column (path relative to ``image_base``) and a
            ``label`` column (a key of ``LABEL_MAP``).
        image_base: Directory the ``image`` paths are relative to. A plain
            string is accepted and converted to ``Path``.
        transform: Optional callable applied to the RGB image before it is
            returned.
    """

    def __init__(self, csv_path, image_base, transform=None):
        self.df = pd.read_csv(csv_path)
        # Coerce so that ``self.image_base / name`` also works for str input.
        self.image_base = Path(image_base)
        self.transform = transform
        self._drop_missing()

    def _drop_missing(self):
        """Remove rows whose image file is absent and renumber the index."""
        exists = self.df["image"].apply(
            lambda rel: (self.image_base / rel).exists()
        )
        # Coerce to bool so the mask is a boolean indexer even for an empty
        # frame.
        exists = exists.astype(bool)
        dropped = (~exists).sum()
        if dropped:
            print(f"Dropped {dropped} rows with missing images.")
        self.df = self.df[exists].reset_index(drop=True)

    def __len__(self):
        """Return the number of rows that survived the missing-file filter."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for row ``idx``.

        The image is opened, converted to RGB and, when a transform was
        supplied, passed through it.

        Raises:
            KeyError: If the row's label is not a key of ``LABEL_MAP``.
        """
        row = self.df.iloc[idx]
        # The context manager closes the underlying file as soon as the
        # converted copy has been made, instead of leaving it to the GC.
        with Image.open(self.image_base / row["image"]) as handle:
            img = handle.convert("RGB")
        label = LABEL_MAP[row["label"]]
        if self.transform:
            img = self.transform(img)
        return img, label
def compute_class_weights(csv_path, label_map=None):
    """Compute inverse-frequency class weights from a split CSV.

    Each class gets ``total / (num_classes * count)``, where ``total`` is the
    number of values in the CSV's ``label`` column, so rarer classes receive
    larger weights. The returned tensor is ordered by class index (the
    values of the label map), not by label name.

    Args:
        csv_path: CSV file with a ``label`` column.
        label_map: Optional mapping of label name to class index. Defaults
            to the module-level ``LABEL_MAP``.

    Returns:
        A 1-D ``torch.float`` tensor with one weight per class.

    Raises:
        ValueError: If a class in the label map has no rows in the CSV, since
            its weight would require dividing by zero.
    """
    if label_map is None:
        label_map = LABEL_MAP
    df = pd.read_csv(csv_path)
    counts = Counter(df["label"])
    total = sum(counts.values())
    num_classes = len(label_map)

    weights = []
    # Walk the classes by index rather than alphabetically by name, so the
    # position of each weight matches the integer the label is encoded as.
    # Assumes the indices are contiguous from 0 - TODO confirm for custom maps.
    for label in sorted(label_map, key=label_map.get):
        count = counts[label]
        if count == 0:
            raise ValueError(
                f"class {label!r} has no samples in {csv_path}; "
                "cannot compute its weight"
            )
        w = total / (num_classes * count)
        weights.append(w)
        # NOTE(review): the 'β' below looks like a mis-encoded arrow; confirm
        # the intended character before changing the message.
        print(f" {label:25s} β {count:4d} samples | weight: {w:.3f}")
    return torch.tensor(weights, dtype=torch.float)
def build_dataloaders(batch_size=32, num_workers=0):
    """Build the train / dev / test ``DataLoader`` objects.

    Each split reads its CSV from ``PROCESSED`` and its images from
    ``IMAGE_BASE``. The training split uses ``train_transforms`` and is
    shuffled; dev and test use ``eval_transforms`` and keep file order.

    Args:
        batch_size: Number of samples per batch for every split.
        num_workers: Worker processes per loader. The default of 0 loads
            data in the calling process.

    Returns:
        A dict mapping ``"train"``, ``"dev"`` and ``"test"`` to their
        ``DataLoader``.
    """
    splits = {
        "train": (PROCESSED / "damage_train.csv", train_transforms),
        "dev": (PROCESSED / "damage_dev.csv", eval_transforms),
        "test": (PROCESSED / "damage_test.csv", eval_transforms),
    }
    # The CUDA check does not depend on the split, so evaluate it once.
    pin_memory = torch.cuda.is_available()

    loaders = {}
    for split, (csv_path, tfm) in splits.items():
        ds = CrisisDataset(csv_path, IMAGE_BASE, transform=tfm)
        loaders[split] = DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=(split == "train"),
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        # NOTE(review): the 'β' below looks like a mis-encoded arrow; confirm
        # the intended character before changing the message.
        print(f"{split:6s} β {len(ds):,} samples | {len(loaders[split])} batches")
    return loaders
def verify_batch(loaders):
    """Pull one training batch and sanity-check its shape and value range.

    Prints the batch shape and the min/max value, then checks that every
    sample is ``(3, 224, 224)`` and that all values lie in ``[-3.0, 3.0]``.

    Args:
        loaders: Mapping with a ``"train"`` entry that iterates over
            ``(images, labels)`` batches.

    Raises:
        AssertionError: If the per-sample shape or the value range is not as
            expected.
    """
    images, _labels = next(iter(loaders["train"]))
    print(f"Batch shape : {images.shape}")
    print(f"Pixel range : [{images.min():.3f}, {images.max():.3f}]")

    # Raise explicitly instead of using ``assert`` statements, which are
    # removed under ``python -O`` and would let this function report success
    # without checking anything.
    if images.shape[1:] != (3, 224, 224):
        raise AssertionError(
            f"unexpected sample shape {tuple(images.shape[1:])}; "
            "expected (3, 224, 224)"
        )
    if not (-3.0 <= images.min() and images.max() <= 3.0):
        raise AssertionError(
            "pixel values fall outside the expected [-3.0, 3.0] range"
        )
    # NOTE(review): the 'β' below looks like a mis-encoded check mark;
    # confirm the intended character before changing the message.
    print("Sanity checks passed β")
def main():
    """Report class weights, build the loaders, verify a batch, save weights."""
    # NOTE(review): the 'β' runs in the messages below look like mis-encoded
    # rule / arrow characters; they are reproduced as found.
    print("ββ Class Weights βββββββββββββββββββββββββββββββββββββ")
    train_csv = PROCESSED / "damage_train.csv"
    weights = compute_class_weights(train_csv)
    print(f"\nWeights: {weights}")

    print("\nββ DataLoaders βββββββββββββββββββββββββββββββββββββββ")
    loaders = build_dataloaders()

    print("\nββ Verification ββββββββββββββββββββββββββββββββββββββ")
    verify_batch(loaders)

    # Persist the weights only after the batch check has passed.
    os.makedirs("models", exist_ok=True)
    torch.save(weights, "models/class_weights.pt")
    print("Saved β models/class_weights.pt")


if __name__ == "__main__":
    main()