# src/training/train_svm.py
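"""Train a linear SVM baseline for Oxford-IIIT Pet classification.

Images are resized to 64x64, converted to grayscale, and flattened into
4096-dimensional pixel vectors. A LinearSVC is fit on the trainval split,
evaluated on the test split, and saved together with its train/test accuracy
and the labels-mapping path as a joblib checkpoint.

Example:
    python src/training/train_svm.py \
        --data-root data/oxford-iiit-pet \
        --ckpt-path checkpoints/svm_model.joblib \
        --labels-path configs/labels.json
"""
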
import os
import json
import argparse
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
import joblib


def get_transforms():
    return transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # (1, 64, 64) in [0, 1]
    ])


def build_datasets(data_root: str):
    tx = get_transforms()
    train_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="trainval",
        target_types="category",
        transform=tx,
        download=True,
    )
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=tx,
        download=True,
    )
    return train_ds, test_ds


def dataset_to_numpy(dataset):
    """
    Convert a torchvision dataset to (X, y) numpy arrays.

    X: (N, 4096) flattened grayscale pixels
    y: (N,) integer labels
    """
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)
    xs = []
    ys = []
    for images, targets in loader:
        # images: (B, 1, 64, 64)
        b = images.shape[0]
        images = images.view(b, -1)  # (B, 4096)
        xs.append(images.numpy())
        ys.append(targets.numpy())
    X = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    return X, y
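
# For reference: the Oxford-IIIT Pet trainval split contains 3,680 images and
# the test split 3,669, so the arrays produced above come out at roughly
# (3680, 4096) and (3669, 4096).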


def ensure_labels_json(train_ds, labels_path: str):
    """Create the class-index -> class-name mapping JSON if it does not exist yet."""
    os.makedirs(os.path.dirname(labels_path), exist_ok=True)
    if os.path.exists(labels_path):
        # Mapping already exists: load and return it as-is.
        with open(labels_path, "r") as f:
            return json.load(f)
    # OxfordIIITPet "category" targets are indices into the dataset's .classes list.
    # Keys are written as strings so the mapping round-trips unchanged through JSON.
    id_to_name = {str(i): name for i, name in enumerate(train_ds.classes)}
    with open(labels_path, "w") as f:
        json.dump(id_to_name, f, indent=2)
    return id_to_name


def train_svm(
    data_root: str = "data/oxford-iiit-pet",
    ckpt_path: str = "checkpoints/svm_model.joblib",
    labels_path: str = "configs/labels.json",
):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    print(f"[+] Loading datasets from {data_root} ...")
    train_ds, test_ds = build_datasets(data_root)

    print("[+] Building labels.json (if missing) ...")
    labels = ensure_labels_json(train_ds, labels_path)
    num_classes = len(labels)
    print(f"[+] Num classes (from labels.json): {num_classes}")

    print("[+] Converting train dataset to numpy features ...")
    X_train, y_train = dataset_to_numpy(train_ds)
    print(f" X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

    print("[+] Converting test dataset to numpy features ...")
    X_test, y_test = dataset_to_numpy(test_ds)
    print(f" X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")
print("[+] Training Linear SVM on raw pixels ...")
svm = LinearSVC(
C=1.0,
penalty="l2",
loss="squared_hinge",
max_iter=2000,
# dual=True (default) is fine when n_samples > n_features,
# which is the case here.
)
svm.fit(X_train, y_train)
print("[+] Evaluating on train and test ...")
y_pred_train = svm.predict(X_train)
y_pred_test = svm.predict(X_test)
train_acc = accuracy_score(y_train, y_pred_train)
test_acc = accuracy_score(y_test, y_pred_test)
print(f" Train accuracy: {train_acc:.4f}")
print(f" Test accuracy : {test_acc:.4f}")
print(f"[+] Saving SVM model to {ckpt_path} ...")
joblib.dump(
{
"model": svm,
"labels_path": labels_path,
"train_acc": float(train_acc),
"test_acc": float(test_acc),
},
ckpt_path,
)
print("[+] Done.")


def parse_args():
    parser = argparse.ArgumentParser(description="Train Linear SVM on raw pixel features.")
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/svm_model.joblib",
        help="Where to save the trained SVM model.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
train_svm(
data_root=args.data_root,
ckpt_path=args.ckpt_path,
labels_path=args.labels_path,
)