# thoughtlink / src/stage1_binary.py
# Uploaded by Srilekha23 via huggingface_hub ("Upload folder using huggingface_hub"), commit 832948a (verified).
"""Phase 4: Stage 1 binary classifier -- Rest vs Active."""
import os
import numpy as np
import joblib
from pathlib import Path
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import cross_val_score, GroupKFold
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
# Repository root: the directory two levels above this (resolved) source file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Where trained model pickles are written.
MODELS_DIR = PROJECT_ROOT.joinpath("models")
# Where text reports and figures are written.
RESULTS_DIR = PROJECT_ROOT.joinpath("results")
def to_binary(label_str):
    """Collapse a class label into the Stage 1 binary target.

    The label "Relax" becomes 0 (Rest); every other label becomes 1 (Active).
    The comparison is exact and case-sensitive.
    """
    if label_str == "Relax":
        return 0
    return 1
def _load_features(path):
    """Load the feature matrix, string labels and subject IDs from an ``.npz`` archive.

    Returns:
        Tuple ``(X, y_str, subjects)`` read from the archive's ``"X"``, ``"y"``
        and ``"subjects"`` arrays.
    """
    # The context manager closes the underlying zip file once the arrays have
    # been read into memory.
    # NOTE(review): allow_pickle=True allows arbitrary code execution on load;
    # only use it on trusted, locally generated feature files.
    with np.load(str(path), allow_pickle=True) as data:
        return data["X"], data["y"], data["subjects"]


def _split_subjects(subjects):
    """Split subjects into disjoint (train, test) sets, holding out ~20% for test.

    The last ``max(1, n // 5)`` subjects in sorted order form the test set, so
    the split is deterministic.

    Raises:
        ValueError: if there are fewer than 3 distinct subjects. At least one
            subject is needed for testing and at least two for GroupKFold on
            the training subjects.
    """
    unique_subjects = sorted(set(subjects))
    if len(unique_subjects) < 3:
        raise ValueError(
            "Need at least 3 distinct subjects for a cross-subject split, "
            f"got {len(unique_subjects)}: {unique_subjects}"
        )
    n_test = max(1, len(unique_subjects) // 5)
    return set(unique_subjects[:-n_test]), set(unique_subjects[-n_test:])


def _build_models():
    """Return the candidate Stage 1 pipelines keyed by display name (fresh, unfitted)."""
    return {
        "SVM_RBF": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", SVC(kernel="rbf", probability=True, class_weight="balanced", random_state=42)),
        ]),
        "RandomForest": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", RandomForestClassifier(n_estimators=200, class_weight="balanced", random_state=42)),
        ]),
        "GradientBoosting": Pipeline([
            ("scaler", StandardScaler()),
            ("clf", GradientBoostingClassifier(n_estimators=200, random_state=42)),
        ]),
    }


def _plot_confusion(cm, class_names, title, out_path):
    """Save an annotated confusion-matrix heatmap to ``out_path``."""
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names,
                yticklabels=class_names, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    fig.savefig(str(out_path), dpi=150)
    # Close this specific figure, not whichever one happens to be "current".
    plt.close(fig)


def train_stage1():
    """Train and evaluate the Stage 1 (Rest vs Active) binary classifier.

    Loads ``features.npz`` from the project root and maps labels to 0/1 with
    ``to_binary``. It then splits the data by subject so that no subject
    appears in both train and test. Each candidate model is scored with
    GroupKFold cross-validation on the training subjects, and the model with
    the best mean CV accuracy is selected. Every candidate is then refit on
    the full training set and evaluated once on the held-out test subjects.

    Side effects: writes ``models/stage1_binary.pkl``,
    ``results/stage1_report.txt``, ``results/stage1_confusion.png`` and
    ``subject_split.npz`` (so Stage 2 can reuse the same subject split).

    Returns:
        ``(best_model, best_acc)``: the fitted pipeline selected by CV and its
        accuracy on the held-out test subjects.

    Raises:
        ValueError: if the feature file contains fewer than 3 distinct subjects.
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    class_names = ["Rest", "Active"]

    # Load features
    X, y_str, subjects = _load_features(PROJECT_ROOT / "features.npz")

    # Binary labels
    y_binary = np.array([to_binary(s) for s in y_str])
    print(f"Total samples: {len(y_binary)}")
    print(f" Rest (0): {np.sum(y_binary == 0)}")
    print(f" Active (1): {np.sum(y_binary == 1)}")

    # Cross-subject split: hold out ~20% of subjects
    train_subjects, test_subjects = _split_subjects(subjects)
    print("\nSubject split:")
    print(f" Train subjects ({len(train_subjects)}): {sorted(train_subjects)}")
    print(f" Test subjects ({len(test_subjects)}): {sorted(test_subjects)}")

    train_mask = np.array([s in train_subjects for s in subjects])
    test_mask = np.array([s in test_subjects for s in subjects])
    X_train, y_train = X[train_mask], y_binary[train_mask]
    X_test, y_test = X[test_mask], y_binary[test_mask]
    subjects_train = subjects[train_mask]
    print(f" Train samples: {len(y_train)} (Rest: {np.sum(y_train==0)}, Active: {np.sum(y_train==1)})")
    print(f" Test samples: {len(y_test)} (Rest: {np.sum(y_test==0)}, Active: {np.sum(y_test==1)})")

    # Verify no subject leakage. Use an explicit raise (not assert) so the
    # check still runs under `python -O`.
    if set(subjects_train) & set(subjects[test_mask]):
        raise RuntimeError("Subject leakage!")

    models = _build_models()

    # GroupKFold keeps all samples of a training subject in the same fold.
    n_splits = min(5, len(train_subjects))
    print(f"\n--- Cross-Validation ({n_splits}-fold GroupKFold) ---")
    gkf = GroupKFold(n_splits=n_splits)
    cv_results = {}
    for name, model in models.items():
        scores = cross_val_score(
            model, X_train, y_train, cv=gkf, groups=subjects_train, scoring="accuracy"
        )
        cv_results[name] = scores
        print(f" {name}: {scores.mean():.3f} +/- {scores.std():.3f}")

    # Select the model on CV only. Choosing by test accuracy would leak the
    # held-out subjects into model selection and inflate the reported score.
    best_name = max(cv_results, key=lambda k: cv_results[k].mean())

    # Train on full training set and evaluate on test set
    print("\n--- Test Set Evaluation ---")
    test_results = {}
    for name, model in models.items():
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        test_results[name] = (acc, model, y_pred)
        print(f" {name}: {acc:.3f}")

    best_acc, best_model, best_pred = test_results[best_name]
    print(f"\nBest model (by CV mean accuracy): {best_name} (test accuracy: {best_acc:.3f})")

    # Save best model
    model_path = MODELS_DIR / "stage1_binary.pkl"
    joblib.dump(best_model, str(model_path))
    print(f"Saved to {model_path}")

    # Classification report. labels=[0, 1] keeps both classes in the report
    # (and matches target_names) even if the test set or the predictions
    # contain only one class.
    report = classification_report(
        y_test, best_pred, labels=[0, 1], target_names=class_names, zero_division=0
    )
    print(f"\nClassification Report:\n{report}")
    with open(str(RESULTS_DIR / "stage1_report.txt"), "w", encoding="utf-8") as f:
        f.write(f"Stage 1 Binary Classifier: {best_name}\n")
        f.write(f"Test Accuracy: {best_acc:.3f}\n\n")
        f.write(f"Train subjects: {sorted(train_subjects)}\n")
        f.write(f"Test subjects: {sorted(test_subjects)}\n\n")
        f.write("Cross-validation results:\n")
        for name, scores in cv_results.items():
            f.write(f" {name}: {scores.mean():.3f} +/- {scores.std():.3f}\n")
        f.write("\nTest set results:\n")
        for name, (acc, _, _) in test_results.items():
            f.write(f" {name}: {acc:.3f}\n")
        f.write(f"\nBest (selected by CV mean accuracy): {best_name}\n\n")
        f.write(f"Classification Report:\n{report}\n")

    # Confusion matrix plot. labels=[0, 1] forces a 2x2 matrix that matches
    # the two tick labels.
    cm = confusion_matrix(y_test, best_pred, labels=[0, 1])
    _plot_confusion(
        cm,
        class_names,
        f"Stage 1: {best_name} (Acc: {best_acc:.3f})",
        RESULTS_DIR / "stage1_confusion.png",
    )
    print("Confusion matrix saved to results/stage1_confusion.png")

    # Save the train/test subject split for consistency with Stage 2
    np.savez(
        str(PROJECT_ROOT / "subject_split.npz"),
        train_subjects=np.array(sorted(train_subjects)),
        test_subjects=np.array(sorted(test_subjects)),
    )
    return best_model, best_acc
if __name__ == "__main__":
    # Script entry point: run the full Stage 1 pipeline (train, evaluate, save
    # artifacts). The returned (model, accuracy) tuple is not used here.
    train_stage1()