import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18, ResNet18_Weights

from config import BATCH_SIZE
from core_dataset import CoreDataset


class FeatureExtractor:
    """Turn images into feature vectors using a headless pretrained ResNet18.

    Attributes:
        device: Device the model runs on; every input batch is moved to it.
        model: ResNet18 with its final classification layer removed, in
            eval mode.
        feature_dim: Input width of the removed classification layer, i.e.
            the length of each extracted feature vector.
    """

    def __init__(self, device=None):
        """Create the extractor and load the model.

        Args:
            device: Device to run on. When omitted (or falsy), CUDA is used
                if available, otherwise the CPU.
        """
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = self._load_model()

    def _load_model(self):
        """Load pretrained ResNet18 and remove its classification layer.

        Also records ``self.feature_dim`` so callers (and the empty-dataset
        path of ``extract_features``) know the feature width without running
        a forward pass.

        Returns:
            The truncated network, moved to ``self.device`` and set to
            eval mode.
        """
        backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        # The width of the features feeding the classifier is the width of
        # the vectors produced once the classifier is stripped off.
        self.feature_dim = backbone.fc.in_features
        # Drop the last child module (the final classification layer).
        layers = list(backbone.children())[:-1]
        model = nn.Sequential(*layers).to(self.device)
        model.eval()
        return model

    def extract_features(self, image_dir):
        """Extract features from all images in a directory.

        Args:
            image_dir: Directory handed to ``CoreDataset``. The dataset is
                expected to yield ``(image, path)`` pairs, which is how the
                batches are unpacked below.

        Returns:
            A ``(features, image_paths)`` tuple. ``features`` is a NumPy
            array with one row per image, and ``image_paths`` lists the
            matching paths in the same order. If the dataset yields no
            batches, ``features`` has shape ``(0, feature_dim)`` and
            ``image_paths`` is empty.
        """
        dataset = CoreDataset(image_dir)
        # shuffle=False keeps feature rows aligned with dataset order.
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

        chunks = []
        image_paths = []

        print("Extracting features from images...")
        with torch.no_grad():
            for batch, paths in dataloader:
                outputs = self.model(batch.to(self.device))
                # Collapse everything after the batch axis into one vector
                # per image.
                outputs = torch.flatten(outputs, start_dim=1)
                chunks.append(outputs.cpu().numpy())
                image_paths.extend(paths)

        if chunks:
            features = np.vstack(chunks)
        else:
            # np.vstack raises ValueError on an empty list; return a
            # well-formed empty matrix instead of crashing on an empty
            # directory.
            features = np.empty((0, self.feature_dim), dtype=np.float32)

        print(f"Extracted features shape: {features.shape}")
        return features, image_paths


if __name__ == "__main__":
    from config import IMAGE_DIR

    extractor = FeatureExtractor()
    features, paths = extractor.extract_features(IMAGE_DIR)
    print(f"Extracted features for {len(paths)} images")