# src/training/extract_resnet_features.py
import os
import argparse

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import resnet18, ResNet18_Weights


def build_datasets(data_root: str, preprocess):
    """
    Build Oxford-IIIT Pet train/test datasets with ResNet preprocessing.
    """
    train_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="trainval",
        target_types="category",
        transform=preprocess,
        download=True,
    )
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=preprocess,
        download=True,
    )
    return train_ds, test_ds


def build_dataloaders(train_ds, test_ds, batch_size: int = 64, num_workers: int = 2):
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=False,  # don't shuffle; we want deterministic feature arrays
        num_workers=num_workers,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader


def build_resnet18_backbone(device: torch.device):
    """
    Load ResNet18 pretrained on ImageNet, replace final fc with Identity.

    Returns:
        model (nn.Module), feature_dim (int), preprocess (transform)
    """
    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights)
    feature_dim = model.fc.in_features  # 512

    # Replace final classifier with identity to get penultimate features
    model.fc = nn.Identity()
    model.to(device)
    model.eval()

    # Official preprocessing pipeline for these weights (resize + crop + norm)
    preprocess = weights.transforms()
    return model, feature_dim, preprocess


def extract_features(model, loader, device: torch.device):
    """
    Run images through the model and collect features + labels.

    Returns:
        X: (N, feature_dim) numpy array
        y: (N,) numpy array
    """
    features_list = []
    labels_list = []
    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device)
            outputs = model(images)  # (B, feature_dim)
            features_list.append(outputs.cpu().numpy())
            labels_list.append(targets.numpy())
    X = np.concatenate(features_list, axis=0)
    y = np.concatenate(labels_list, axis=0)
    return X, y


def main(
    data_root: str = "data/oxford-iiit-pet",
    out_dir: str = "data/resnet18_features",
    batch_size: int = 64,
    num_workers: int = 2,
):
    os.makedirs(out_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[+] Using device: {device}")

    print("[+] Building ResNet18 backbone and preprocessing ...")
    model, feature_dim, preprocess = build_resnet18_backbone(device)
    print(f"[+] Feature dimension: {feature_dim}")

    print(f"[+] Loading Oxford-IIIT Pet from {data_root} ...")
    train_ds, test_ds = build_datasets(data_root, preprocess)

    print("[+] Building dataloaders ...")
    train_loader, test_loader = build_dataloaders(
        train_ds, test_ds, batch_size=batch_size, num_workers=num_workers
    )

    print("[+] Extracting train features ...")
    X_train, y_train = extract_features(model, train_loader, device)
    print(f"    X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

    print("[+] Extracting test features ...")
    X_test, y_test = extract_features(model, test_loader, device)
    print(f"    X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

    # Save to .npy
    x_train_path = os.path.join(out_dir, "X_train_resnet18.npy")
    y_train_path = os.path.join(out_dir, "y_train.npy")
    x_test_path = os.path.join(out_dir, "X_test_resnet18.npy")
    y_test_path = os.path.join(out_dir, "y_test.npy")

    print(f"[+] Saving features to {out_dir} ...")
    np.save(x_train_path, X_train)
    np.save(y_train_path, y_train)
    np.save(x_test_path, X_test)
    np.save(y_test_path, y_test)
    print("[+] Done extracting ResNet18 features.")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Extract ResNet18 (pretrained) features for Oxford-IIIT Pet."
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="data/resnet18_features",
        help="Directory to save .npy feature files.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Batch size for feature extraction.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=2,
        help="Num workers for dataloader.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(
        data_root=args.data_root,
        out_dir=args.out_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
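
# ---------------------------------------------------------------------------
# Usage sketch, kept as a comment for reference. It assumes the default paths
# above; the downstream scikit-learn linear probe is an illustrative choice,
# not necessarily what this repo trains on the saved features:
#
#   python src/training/extract_resnet_features.py \
#       --data-root data/oxford-iiit-pet \
#       --out-dir data/resnet18_features
#
#   # then, e.g. in a separate script:
#   import numpy as np
#   from sklearn.linear_model import LogisticRegression
#
#   X_train = np.load("data/resnet18_features/X_train_resnet18.npy")
#   y_train = np.load("data/resnet18_features/y_train.npy")
#   X_test = np.load("data/resnet18_features/X_test_resnet18.npy")
#   y_test = np.load("data/resnet18_features/y_test.npy")
#
#   clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
#   print("test accuracy:", clf.score(X_test, y_test))
# ---------------------------------------------------------------------------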