Spaces:
Sleeping
Sleeping
# src/training/train_resnet_pt_lr.py
| import os | |
| import argparse | |
| import json | |
| import numpy as np | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import accuracy_score | |
| import joblib | |
def load_features(features_dir: str):
    """Load the precomputed ResNet18 train/test features and labels.

    Expects four NumPy files inside ``features_dir``:
    ``X_train_resnet18.npy``, ``y_train.npy``, ``X_test_resnet18.npy`` and
    ``y_test.npy``.

    Args:
        features_dir: Directory containing the ``.npy`` files.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` holding the arrays exactly
        as returned by ``np.load``.

    Raises:
        FileNotFoundError: If any of the four files is missing; every missing
            path is listed in the message.
    """
    filenames = (
        "X_train_resnet18.npy",
        "y_train.npy",
        "X_test_resnet18.npy",
        "y_test.npy",
    )
    paths = [os.path.join(features_dir, name) for name in filenames]

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would turn a clear error into a confusing np.load failure.
    # All files are checked before any loading so the user sees every gap at once.
    missing = [path for path in paths if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError("Missing: " + ", ".join(missing))

    X_train, y_train, X_test, y_test = (np.load(path) for path in paths)
    return X_train, y_train, X_test, y_test
def main(
    features_dir: str = "data/resnet18_features",
    ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
    labels_path: str = "configs/labels.json",
):
    """Train a logistic-regression head on precomputed ResNet18 features and save it.

    Loads train/test arrays via ``load_features`` and fits a ``LogisticRegression``
    (L2, saga, multinomial). It prints train/test accuracy, then writes a joblib
    checkpoint containing the fitted model plus metadata (backbone name, feature
    dimension, labels path, and accuracies).

    Args:
        features_dir: Directory containing the ``.npy`` feature/label files.
        ckpt_path: Destination of the joblib checkpoint. Missing parent
            directories are created; a bare filename is written to the
            current working directory.
        labels_path: Path to the labels JSON. It is only read here to log the
            class count; the path itself is stored in the checkpoint.
    """
    # os.path.dirname("head.joblib") == "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    print(f"[+] Loading features from {features_dir} ...")
    X_train, y_train, X_test, y_test = load_features(features_dir)
    print(f" X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print(f" X_test shape : {X_test.shape}, y_test shape : {y_test.shape}")

    num_features = X_train.shape[1]
    print(f"[+] Feature dimension: {num_features}")

    # Labels mapping is not strictly needed for training, but we keep the path
    # around for inference later.
    if os.path.exists(labels_path):
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(labels_path, "r", encoding="utf-8") as f:
            num_classes = len(json.load(f))
        print(f"[+] Loaded labels from {labels_path}, num_classes={num_classes}")
    else:
        print(f"[!] Warning: {labels_path} not found. Inference will need this later.")

    print("[+] Training Logistic Regression on ResNet18 features ...")
    # NOTE(review): sklearn documents that sag/saga converge fast only when
    # features have roughly the same scale; the features are used unscaled here.
    clf = LogisticRegression(
        penalty="l2",
        C=1.0,
        solver="saga",
        # NOTE(review): `multi_class` is deprecated since scikit-learn 1.5, where
        # multinomial becomes the default for n_classes >= 3. It is kept so behavior
        # is unchanged (including binary targets); revisit when upgrading sklearn.
        multi_class="multinomial",
        max_iter=1000,
        # NOTE(review): per sklearn docs, n_jobs only parallelizes over classes
        # for multi_class="ovr", so it is likely a no-op with multinomial.
        n_jobs=-1,
        verbose=1,
    )
    clf.fit(X_train, y_train)

    print("[+] Evaluating ...")
    train_acc = accuracy_score(y_train, clf.predict(X_train))
    test_acc = accuracy_score(y_test, clf.predict(X_test))
    print(f" Train accuracy: {train_acc:.4f}")
    print(f" Test accuracy : {test_acc:.4f}")

    print(f"[+] Saving LR head to {ckpt_path} ...")
    # Plain Python scalars (int/float) keep the metadata independent of numpy types.
    payload = {
        "model": clf,
        "backbone": "resnet18_imagenet",
        "feature_dim": int(num_features),
        "labels_path": labels_path,
        "train_acc": float(train_acc),
        "test_acc": float(test_acc),
    }
    joblib.dump(payload, ckpt_path)
    print("[+] Done training ResNet PT + LR.")
def parse_args(argv=None):
    """Parse command-line options for training the logistic-regression head.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before. Passing
            an explicit list lets tests and other scripts reuse the parser.

    Returns:
        ``argparse.Namespace`` with string attributes ``features_dir``,
        ``ckpt_path`` and ``labels_path``.
    """
    parser = argparse.ArgumentParser(
        description="Train Logistic Regression head on ResNet18 (pretrained) features."
    )
    # These defaults mirror main()'s keyword defaults; keep the two in sync.
    parser.add_argument(
        "--features-dir",
        type=str,
        default="data/resnet18_features",
        help="Directory containing X_train_resnet18.npy etc.",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_lr_head.joblib",
        help="Where to save LR head checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: forward each parsed CLI flag to the matching
    # keyword argument of main().
    cli = parse_args()
    main(
        features_dir=cli.features_dir,
        ckpt_path=cli.ckpt_path,
        labels_path=cli.labels_path,
    )