# src/training/train_resnet_pt_lr.py
"""Train a logistic-regression head on precomputed ResNet18 features.

The features directory must contain ``X_train_resnet18.npy``,
``y_train.npy``, ``X_test_resnet18.npy`` and ``y_test.npy``. The fitted
classifier is saved with joblib inside a small metadata dict.
"""
import argparse
import json
import os

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import joblib

DEFAULT_FEATURES_DIR = "data/resnet18_features"
DEFAULT_CKPT_PATH = "checkpoints/resnet_pt_lr_head.joblib"
DEFAULT_LABELS_PATH = "configs/labels.json"

# Order matters: load_features returns arrays in this order.
_FEATURE_FILES = (
    "X_train_resnet18.npy",
    "y_train.npy",
    "X_test_resnet18.npy",
    "y_test.npy",
)


def load_features(features_dir: str):
    """Load the train/test feature matrices and label vectors.

    Args:
        features_dir: Directory containing the four ``.npy`` files named in
            ``_FEATURE_FILES``.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` as returned by ``np.load``.

    Raises:
        FileNotFoundError: If any of the expected files is missing.
    """
    paths = [os.path.join(features_dir, name) for name in _FEATURE_FILES]
    # Explicit raise rather than `assert`: asserts are stripped under `python -O`.
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError("Missing: " + ", ".join(missing))

    X_train, y_train, X_test, y_test = (np.load(p) for p in paths)
    return X_train, y_train, X_test, y_test


def main(
    features_dir: str = DEFAULT_FEATURES_DIR,
    ckpt_path: str = DEFAULT_CKPT_PATH,
    labels_path: str = DEFAULT_LABELS_PATH,
):
    """Fit a logistic-regression head, report accuracy, and save a checkpoint.

    Args:
        features_dir: Directory holding the precomputed feature/label arrays.
        ckpt_path: Destination of the joblib checkpoint. Its parent
            directory is created if needed.
        labels_path: Path to the label-mapping JSON. It is only reported here
            and recorded in the checkpoint for later inference.

    Raises:
        FileNotFoundError: If any feature file is missing.
        ValueError: If the feature matrices are not 2-D or the train and test
            feature dimensions differ.
    """
    # os.path.dirname() returns "" for a bare filename, and os.makedirs("")
    # raises, so only create a directory when there is one.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    print(f"[+] Loading features from {features_dir} ...")
    X_train, y_train, X_test, y_test = load_features(features_dir)
    print(f"    X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print(f"    X_test shape : {X_test.shape}, y_test shape : {y_test.shape}")

    # Validate before fitting so a bad feature dump fails fast instead of
    # after a potentially long training run.
    if X_train.ndim != 2 or X_test.ndim != 2:
        raise ValueError(
            f"Expected 2-D feature matrices, got X_train {X_train.shape} "
            f"and X_test {X_test.shape}"
        )
    if X_train.shape[1] != X_test.shape[1]:
        raise ValueError(
            f"Feature dimension mismatch: train={X_train.shape[1]}, "
            f"test={X_test.shape[1]}"
        )

    num_features = X_train.shape[1]
    print(f"[+] Feature dimension: {num_features}")

    # The label mapping is not needed for training. Its path is stored in the
    # checkpoint so inference can locate it later.
    if os.path.exists(labels_path):
        with open(labels_path, "r", encoding="utf-8") as f:
            labels = json.load(f)
        num_classes = len(labels)
        print(f"[+] Loaded labels from {labels_path}, num_classes={num_classes}")
    else:
        print(f"[!] Warning: {labels_path} not found. Inference will need this later.")

    print("[+] Training Logistic Regression on ResNet18 features ...")
    clf = LogisticRegression(
        penalty="l2",
        C=1.0,
        solver="saga",
        # NOTE: `multi_class` is deprecated since scikit-learn 1.5. Kept here
        # to pin the multinomial formulation on older versions.
        multi_class="multinomial",
        max_iter=1000,
        # Per sklearn docs, n_jobs only parallelises over classes for OvR fits.
        n_jobs=-1,
        verbose=1,
    )
    clf.fit(X_train, y_train)

    print("[+] Evaluating ...")
    train_acc = accuracy_score(y_train, clf.predict(X_train))
    test_acc = accuracy_score(y_test, clf.predict(X_test))
    print(f"    Train accuracy: {train_acc:.4f}")
    print(f"    Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving LR head to {ckpt_path} ...")
    payload = {
        "model": clf,
        "backbone": "resnet18_imagenet",
        "feature_dim": int(num_features),
        "labels_path": labels_path,
        "train_acc": float(train_acc),
        "test_acc": float(test_acc),
    }
    joblib.dump(payload, ckpt_path)
    print("[+] Done training ResNet PT + LR.")


def parse_args():
    """Parse command-line arguments for the training script."""
    parser = argparse.ArgumentParser(
        description="Train Logistic Regression head on ResNet18 (pretrained) features."
    )
    parser.add_argument(
        "--features-dir",
        type=str,
        default=DEFAULT_FEATURES_DIR,
        help="Directory containing X_train_resnet18.npy etc.",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default=DEFAULT_CKPT_PATH,
        help="Where to save LR head checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default=DEFAULT_LABELS_PATH,
        help="Path to labels mapping JSON.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(
        features_dir=args.features_dir,
        ckpt_path=args.ckpt_path,
        labels_path=args.labels_path,
    )