# NOTE: the three lines that were here ("Spaces:" / "Sleeping" / "Sleeping")
# were a Hugging Face Spaces page-status banner accidentally pasted into this
# source file; they are not Python and would break the module at import time.
| # src/training/train_svm.py | |
| import os | |
| import json | |
| import argparse | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms, datasets | |
| import numpy as np | |
| from sklearn.svm import LinearSVC | |
| from sklearn.metrics import accuracy_score | |
| import joblib | |
def get_transforms():
    """Preprocessing pipeline: resize to 64x64, convert to 1-channel grayscale,
    and emit a float tensor scaled to [0, 1]."""
    steps = [
        transforms.Resize((64, 64)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # yields a (1, 64, 64) tensor in [0, 1]
    ]
    return transforms.Compose(steps)
def build_datasets(data_root: str):
    """Return the (trainval, test) Oxford-IIIT Pet splits.

    Both splits share the same preprocessing transform and category targets,
    and are downloaded on demand into *data_root*.
    """
    tx = get_transforms()

    def _make_split(split_name):
        # Only the split name differs between the two datasets.
        return datasets.OxfordIIITPet(
            root=data_root,
            split=split_name,
            target_types="category",
            transform=tx,
            download=True,
        )

    return _make_split("trainval"), _make_split("test")
def dataset_to_numpy(dataset, batch_size: int = 64, num_workers: int = 2):
    """
    Convert a map-style dataset of (image_tensor, int_target) pairs to
    (X, y) numpy arrays.

    Each image tensor is flattened per-sample, so for the (1, 64, 64)
    grayscale pipeline in this file X has 4096 columns.

    Args:
        dataset: Any dataset yielding (image_tensor, integer_target).
        batch_size: Batch size used while iterating; does not affect output.
        num_workers: DataLoader worker processes (0 = load in-process).

    Returns:
        X: (N, D) float array of flattened pixels.
        y: (N,) integer label array.
    """
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    xs = []
    ys = []
    for images, targets in loader:
        # images: (B, C, H, W). reshape (unlike view) also handles
        # non-contiguous batches without raising.
        flat = images.reshape(images.shape[0], -1)
        xs.append(flat.numpy())
        ys.append(targets.numpy())
    if not xs:
        # Empty dataset: np.concatenate would raise on an empty list.
        return np.empty((0, 0), dtype=np.float32), np.empty((0,), dtype=np.int64)
    X = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    return X, y
def ensure_labels_json(train_ds, labels_path: str):
    """
    Create (or load) the class-index -> class-name mapping JSON.

    JSON object keys are always strings, so a mapping built with int keys
    would not round-trip: the first run would return {0: "cat"} while every
    later run (loading the file) returned {"0": "cat"}. Keys are therefore
    written as strings from the start, so callers see the same key type on
    every run.

    Args:
        train_ds: Dataset exposing a .categories list (OxfordIIITPet does).
        labels_path: Destination path for the mapping JSON.

    Returns:
        dict mapping str(class index) -> class name.
    """
    parent = os.path.dirname(labels_path)
    if parent:  # dirname is "" for a bare filename; makedirs("") raises
        os.makedirs(parent, exist_ok=True)
    if os.path.exists(labels_path):
        # Mapping already built on a previous run; reuse it verbatim.
        with open(labels_path, "r") as f:
            return json.load(f)
    # OxfordIIITPet: category targets are indices into .categories
    id_to_name = {str(i): name for i, name in enumerate(train_ds.categories)}
    with open(labels_path, "w") as f:
        json.dump(id_to_name, f, indent=2)
    return id_to_name
def train_svm(
    data_root: str = "data/oxford-iiit-pet",
    ckpt_path: str = "checkpoints/svm_model.joblib",
    labels_path: str = "configs/labels.json",
):
    """
    Train a linear SVM on raw 64x64 grayscale pixels of Oxford-IIIT Pet.

    Downloads the dataset if needed, writes the labels mapping when missing,
    fits a LinearSVC on flattened pixel features, reports train/test
    accuracy, and saves the model plus metrics to *ckpt_path* via joblib.

    Args:
        data_root: Root directory for the Oxford-IIIT Pet dataset.
        ckpt_path: Output path for the joblib checkpoint.
        labels_path: Path to the class-index -> name JSON mapping.
    """
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:  # dirname is "" for a bare filename; makedirs("") raises
        os.makedirs(ckpt_dir, exist_ok=True)
    print(f"[+] Loading datasets from {data_root} ...")
    train_ds, test_ds = build_datasets(data_root)
    print("[+] Building labels.json (if missing) ...")
    labels = ensure_labels_json(train_ds, labels_path)
    num_classes = len(labels)
    print(f"[+] Num classes (from labels.json): {num_classes}")
    print("[+] Converting train dataset to numpy features ...")
    X_train, y_train = dataset_to_numpy(train_ds)
    print(f"    X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print("[+] Converting test dataset to numpy features ...")
    X_test, y_test = dataset_to_numpy(test_ds)
    print(f"    X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")
    print("[+] Training Linear SVM on raw pixels ...")
    svm = LinearSVC(
        C=1.0,
        penalty="l2",
        loss="squared_hinge",
        max_iter=2000,
        # dual=True (the default) suits this data: the trainval split has
        # fewer samples than the 4096 pixel features (n_samples < n_features).
    )
    svm.fit(X_train, y_train)
    print("[+] Evaluating on train and test ...")
    y_pred_train = svm.predict(X_train)
    y_pred_test = svm.predict(X_test)
    train_acc = accuracy_score(y_train, y_pred_train)
    test_acc = accuracy_score(y_test, y_pred_test)
    print(f"    Train accuracy: {train_acc:.4f}")
    print(f"    Test accuracy : {test_acc:.4f}")
    print(f"[+] Saving SVM model to {ckpt_path} ...")
    # Bundle the model with its metrics and the labels location so the
    # checkpoint is self-describing at inference time.
    joblib.dump(
        {
            "model": svm,
            "labels_path": labels_path,
            "train_acc": float(train_acc),
            "test_acc": float(test_acc),
        },
        ckpt_path,
    )
    print("[+] Done.")
def parse_args():
    """Parse command-line options for SVM training."""
    parser = argparse.ArgumentParser(description="Train Linear SVM on raw pixel features.")
    # (flag, default, help) triples — all options are plain string paths.
    options = [
        ("--data-root", "data/oxford-iiit-pet",
         "Root directory for Oxford-IIIT Pet dataset."),
        ("--ckpt-path", "checkpoints/svm_model.joblib",
         "Where to save the trained SVM model."),
        ("--labels-path", "configs/labels.json",
         "Path to labels mapping JSON."),
    ]
    for flag, default_value, help_text in options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    return parser.parse_args()
| if __name__ == "__main__": | |
| args = parse_args() | |
| train_svm( | |
| data_root=args.data_root, | |
| ckpt_path=args.ckpt_path, | |
| labels_path=args.labels_path, | |
| ) | |