# src/training/train_svm.py
import os
import json
import argparse

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
import joblib


def get_transforms():
    return transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # (1, 64, 64) in [0, 1]
    ])


def build_datasets(data_root: str):
    tx = get_transforms()
    train_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="trainval",
        target_types="category",
        transform=tx,
        download=True,
    )
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=tx,
        download=True,
    )
    return train_ds, test_ds


def dataset_to_numpy(dataset):
    """
    Convert a torchvision dataset to (X, y) numpy arrays.

    X: (N, 4096) flattened grayscale pixels
    y: (N,) integer labels
    """
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)
    xs = []
    ys = []
    for images, targets in loader:
        # images: (B, 1, 64, 64)
        b = images.shape[0]
        images = images.view(b, -1)  # (B, 4096)
        xs.append(images.numpy())
        ys.append(targets.numpy())
    X = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    return X, y


def ensure_labels_json(train_ds, labels_path: str):
    os.makedirs(os.path.dirname(labels_path), exist_ok=True)
    if os.path.exists(labels_path):
        with open(labels_path, "r") as f:
            labels = json.load(f)
        # sanity: if it already exists, just return it
        return labels

    # OxfordIIITPet: category targets are indices into .classes.
    # Use str keys so this branch matches what json.load returns above
    # (JSON object keys are always strings).
    id_to_name = {str(i): name for i, name in enumerate(train_ds.classes)}
    with open(labels_path, "w") as f:
        json.dump(id_to_name, f, indent=2)
    return id_to_name


def train_svm(
    data_root: str = "data/oxford-iiit-pet",
    ckpt_path: str = "checkpoints/svm_model.joblib",
    labels_path: str = "configs/labels.json",
):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    print(f"[+] Loading datasets from {data_root} ...")
    train_ds, test_ds = build_datasets(data_root)

    print("[+] Building labels.json (if missing) ...")
    labels = ensure_labels_json(train_ds, labels_path)
    num_classes = len(labels)
    print(f"[+] Num classes (from labels.json): {num_classes}")

    print("[+] Converting train dataset to numpy features ...")
    X_train, y_train = dataset_to_numpy(train_ds)
    print(f"    X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

    print("[+] Converting test dataset to numpy features ...")
    X_test, y_test = dataset_to_numpy(test_ds)
    print(f"    X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

    print("[+] Training Linear SVM on raw pixels ...")
    svm = LinearSVC(
        C=1.0,
        penalty="l2",
        loss="squared_hinge",
        max_iter=2000,
        # dual is left at its default: the dual formulation is preferred when
        # n_samples <= n_features, which is the case here (~3.7k trainval
        # images vs. 4096 pixel features).
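        # NOTE: with raw pixels in [0, 1] and only a few thousand samples,
        # LinearSVC can still hit max_iter and emit a ConvergenceWarning;
        # raising max_iter or standardizing features (e.g. with
        # sklearn.preprocessing.StandardScaler) are common remedies if that
        # happens.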
    )
    svm.fit(X_train, y_train)

    print("[+] Evaluating on train and test ...")
    y_pred_train = svm.predict(X_train)
    y_pred_test = svm.predict(X_test)
    train_acc = accuracy_score(y_train, y_pred_train)
    test_acc = accuracy_score(y_test, y_pred_test)
    print(f"    Train accuracy: {train_acc:.4f}")
    print(f"    Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving SVM model to {ckpt_path} ...")
    joblib.dump(
        {
            "model": svm,
            "labels_path": labels_path,
            "train_acc": float(train_acc),
            "test_acc": float(test_acc),
        },
        ckpt_path,
    )
    print("[+] Done.")


def parse_args():
    parser = argparse.ArgumentParser(description="Train Linear SVM on raw pixel features.")
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/svm_model.joblib",
        help="Where to save the trained SVM model.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    train_svm(
        data_root=args.data_root,
        ckpt_path=args.ckpt_path,
        labels_path=args.labels_path,
    )
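# Example invocation (paths follow the defaults above; adjust to your layout):
#
#   python src/training/train_svm.py \
#       --data-root data/oxford-iiit-pet \
#       --ckpt-path checkpoints/svm_model.joblib \
#       --labels-path configs/labels.json
#
# The checkpoint is a plain dict, so a downstream script can reload it with
# something like `bundle = joblib.load("checkpoints/svm_model.joblib")` and
# call `bundle["model"].predict(features)` on (N, 4096) rows preprocessed the
# same way as in dataset_to_numpy().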