# Source: Machine_learning_CS-6140 / src/training/extract_resnet_features.py
# Author: Shashwat98
# Commit: 52dd1ca (verified) - "Upload 37 files"
# src/training/extract_resnet_features.py
import os
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.models import resnet18, ResNet18_Weights
def build_datasets(data_root: str, preprocess):
    """
    Create the Oxford-IIIT Pet "trainval" and "test" splits.

    Both splits use category labels and share the same ``preprocess``
    transform. ``download=True`` is passed so torchvision can fetch the
    dataset into ``data_root`` if it is not already there.

    Returns:
        (train_ds, test_ds): the "trainval" split followed by the "test" split.
    """

    def _split(name):
        # Only the split name differs between the two datasets.
        return datasets.OxfordIIITPet(
            root=data_root,
            split=name,
            target_types="category",
            transform=preprocess,
            download=True,
        )

    train_ds = _split("trainval")
    test_ds = _split("test")
    return train_ds, test_ds
def build_dataloaders(train_ds, test_ds, batch_size: int = 64, num_workers: int = 2):
    """
    Wrap the train and test datasets in DataLoaders with identical settings.

    Neither loader shuffles, so batches come out in dataset order and the
    extracted feature arrays are deterministic across runs.

    Returns:
        (train_loader, test_loader)
    """
    shared = {
        "batch_size": batch_size,
        "shuffle": False,  # keep dataset order -> deterministic feature arrays
        "num_workers": num_workers,
    }
    return DataLoader(train_ds, **shared), DataLoader(test_ds, **shared)
def build_resnet18_backbone(device: torch.device):
    """
    Build an ImageNet-pretrained ResNet18 with its classifier head removed.

    The final ``fc`` layer is replaced by ``nn.Identity`` so the forward pass
    returns penultimate-layer features instead of class logits. The model is
    moved to ``device`` and switched to eval mode.

    Returns:
        model (nn.Module): the headless backbone.
        feature_dim (int): ``in_features`` of the removed ``fc`` layer (512).
        preprocess (transform): the preprocessing pipeline bundled with the
            weights (resize + crop + normalization).
    """
    import torch.nn as nn

    weights = ResNet18_Weights.DEFAULT
    backbone = resnet18(weights=weights)

    # Read the head's input width before discarding the head.
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Identity()

    backbone.to(device)
    backbone.eval()

    # Use the transforms that ship with these weights so inputs match training.
    return backbone, feature_dim, weights.transforms()
def extract_features(model, loader, device: torch.device):
    """
    Run every batch from ``loader`` through ``model`` and stack the results.

    The model is used as-is: this function does not call ``model.eval()``.
    ``build_resnet18_backbone`` already puts the backbone in eval mode.

    Args:
        model: Module mapping an image batch to a feature batch.
        loader: Iterable of ``(images, targets)`` batches, e.g. a DataLoader.
        device: Device each image batch is moved to before the forward pass.

    Returns:
        X: (N, feature_dim) numpy array. All non-batch dimensions of the
           model output are flattened, so e.g. (B, C, 1, 1) pooled maps
           become (B, C). This is 512 for this file's ResNet18 backbone.
        y: (N,) numpy array of labels.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    feature_batches = []
    label_batches = []
    # Pure inference: inference_mode skips autograd bookkeeping entirely.
    with torch.inference_mode():
        for images, targets in loader:
            outputs = model(images.to(device))
            # Guarantee a 2-D (B, feature_dim) block. This is a no-op when the
            # model already returns (B, D).
            feature_batches.append(outputs.flatten(start_dim=1).cpu().numpy())
            label_batches.append(np.asarray(targets))
    if not feature_batches:
        # np.concatenate would otherwise fail with an opaque message.
        raise ValueError("loader yielded no batches; nothing to extract")
    X = np.concatenate(feature_batches, axis=0)
    y = np.concatenate(label_batches, axis=0)
    return X, y
def main(
    data_root: str = "data/oxford-iiit-pet",
    out_dir: str = "data/resnet18_features",
    batch_size: int = 64,
    num_workers: int = 2,
):
    """
    Extract ResNet18 features for both Oxford-IIIT Pet splits and save them.

    Creates ``out_dir`` if needed and writes ``X_train_resnet18.npy``,
    ``y_train.npy``, ``X_test_resnet18.npy`` and ``y_test.npy`` into it.

    Args:
        data_root: Directory passed to the Oxford-IIIT Pet dataset.
        out_dir: Destination directory for the .npy files.
        batch_size: Batch size for the extraction DataLoaders.
        num_workers: Number of DataLoader worker processes.
    """
    os.makedirs(out_dir, exist_ok=True)

    # --- setup -------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[+] Using device: {device}")

    print("[+] Building ResNet18 backbone and preprocessing ...")
    model, feature_dim, preprocess = build_resnet18_backbone(device)
    print(f"[+] Feature dimension: {feature_dim}")

    print(f"[+] Loading Oxford-IIIT Pet from {data_root} ...")
    train_ds, test_ds = build_datasets(data_root, preprocess)

    print("[+] Building dataloaders ...")
    train_loader, test_loader = build_dataloaders(
        train_ds, test_ds, batch_size=batch_size, num_workers=num_workers
    )

    # --- extraction --------------------------------------------------------
    print("[+] Extracting train features ...")
    X_train, y_train = extract_features(model, train_loader, device)
    print(f" X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

    print("[+] Extracting test features ...")
    X_test, y_test = extract_features(model, test_loader, device)
    print(f" X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

    # --- persistence -------------------------------------------------------
    # Insertion order fixes the save order: X_train, y_train, X_test, y_test.
    to_save = {
        "X_train_resnet18.npy": X_train,
        "y_train.npy": y_train,
        "X_test_resnet18.npy": X_test,
        "y_test.npy": y_test,
    }
    print(f"[+] Saving features to {out_dir} ...")
    for filename, array in to_save.items():
        np.save(os.path.join(out_dir, filename), array)
    print("[+] Done extracting ResNet18 features.")
def parse_args():
    """
    Parse the command-line options of the feature-extraction script.

    Returns:
        argparse.Namespace with ``data_root``, ``out_dir``, ``batch_size``
        and ``num_workers`` attributes.
    """
    cli = argparse.ArgumentParser(
        description="Extract ResNet18 (pretrained) features for Oxford-IIIT Pet."
    )
    # (flag, type, default, help) - registered in this order.
    option_table = (
        ("--data-root", str, "data/oxford-iiit-pet",
         "Root directory for Oxford-IIIT Pet dataset."),
        ("--out-dir", str, "data/resnet18_features",
         "Directory to save .npy feature files."),
        ("--batch-size", int, 64,
         "Batch size for feature extraction."),
        ("--num-workers", int, 2,
         "Num workers for dataloader."),
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: forward the parsed CLI options to main().
    cli_args = parse_args()
    main(
        data_root=cli_args.data_root,
        out_dir=cli_args.out_dir,
        batch_size=cli_args.batch_size,
        num_workers=cli_args.num_workers,
    )