Spaces:
Runtime error
Runtime error
| """Phase 9: Visualizations and interpretability.""" | |
| import os | |
| import numpy as np | |
| import joblib | |
| from pathlib import Path | |
| from collections import Counter | |
| from scipy.signal import welch | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| PROJECT_ROOT = Path(__file__).resolve().parent.parent | |
| DATA_DIR = PROJECT_ROOT / "data" | |
| RESULTS_DIR = PROJECT_ROOT / "results" | |
| MODELS_DIR = PROJECT_ROOT / "models" | |
| CHANNEL_NAMES = ["AFF6", "AFp2", "AFp1", "AFF5", "FCz", "CPz"] | |
| FS = 500.0 | |
| # Actual label strings from the dataset | |
| LABELS_5CLASS = ["Both Fists", "Left Fist", "Relax", "Right Fist", "Tongue Tapping"] | |
| def plot_psd_per_class(): | |
| """Plot average PSD per class for all 6 channels.""" | |
| print("Generating PSD per class plots...") | |
| data = np.load(str(PROJECT_ROOT / "preprocessed_data.npz"), allow_pickle=True) | |
| X = data["X"] # (n_windows, 500, 6) | |
| y = data["y"] | |
| fig, axes = plt.subplots(1, 5, figsize=(25, 5), sharey=True) | |
| fig.suptitle("Average PSD per Class (All 6 Channels)", fontsize=14) | |
| colors = plt.cm.tab10(np.linspace(0, 1, 6)) | |
| for idx, label in enumerate(LABELS_5CLASS): | |
| ax = axes[idx] | |
| mask = y == label | |
| windows = X[mask] | |
| # Subsample if too many | |
| if len(windows) > 500: | |
| rng = np.random.RandomState(42) | |
| sel = rng.choice(len(windows), 500, replace=False) | |
| windows = windows[sel] | |
| for ch in range(6): | |
| all_psd = [] | |
| for w in windows: | |
| freqs, psd = welch(w[:, ch], fs=FS, nperseg=256) | |
| all_psd.append(psd) | |
| avg_psd = np.mean(all_psd, axis=0) | |
| ax.semilogy(freqs, avg_psd, color=colors[ch], label=CHANNEL_NAMES[ch], alpha=0.8) | |
| ax.set_title(label) | |
| ax.set_xlabel("Frequency (Hz)") | |
| ax.set_xlim(0, 50) | |
| ax.axvline(8, color="gray", linestyle="--", alpha=0.3) | |
| ax.axvline(30, color="gray", linestyle="--", alpha=0.3) | |
| if idx == 0: | |
| ax.set_ylabel("PSD (log scale)") | |
| ax.legend(fontsize=7) | |
| plt.tight_layout() | |
| plt.savefig(str(RESULTS_DIR / "psd_per_class.png"), dpi=150) | |
| plt.close() | |
| print(" Saved psd_per_class.png") | |
| def plot_tsne(): | |
| """t-SNE visualization of feature space colored by label.""" | |
| print("Generating t-SNE plot...") | |
| from sklearn.manifold import TSNE | |
| data = np.load(str(PROJECT_ROOT / "features.npz"), allow_pickle=True) | |
| X = data["X"] | |
| y = data["y"] | |
| # Subsample for speed | |
| n_max = 3000 | |
| if len(X) > n_max: | |
| rng = np.random.RandomState(42) | |
| idx = rng.choice(len(X), n_max, replace=False) | |
| X_sub = X[idx] | |
| y_sub = y[idx] | |
| else: | |
| X_sub = X | |
| y_sub = y | |
| tsne = TSNE(n_components=2, random_state=42, perplexity=30) | |
| X_2d = tsne.fit_transform(X_sub) | |
| fig, ax = plt.subplots(figsize=(10, 8)) | |
| colors = {"Both Fists": "C0", "Left Fist": "C1", "Relax": "C2", "Right Fist": "C3", "Tongue Tapping": "C4"} | |
| for label in LABELS_5CLASS: | |
| mask = y_sub == label | |
| ax.scatter(X_2d[mask, 0], X_2d[mask, 1], label=label, alpha=0.5, s=10, | |
| color=colors.get(label, "gray")) | |
| ax.set_title("t-SNE of EEG Feature Space (5 classes)") | |
| ax.legend(markerscale=3) | |
| ax.set_xlabel("t-SNE 1") | |
| ax.set_ylabel("t-SNE 2") | |
| plt.tight_layout() | |
| plt.savefig(str(RESULTS_DIR / "feature_tsne.png"), dpi=150) | |
| plt.close() | |
| print(" Saved feature_tsne.png") | |
| def plot_temporal_timeline(): | |
| """Plot confidence and action timeline for a single test file.""" | |
| print("Generating temporal timeline...") | |
| from preprocess import bandpass_filter, extract_active_segment, normalize_channels, segment_windows | |
| from features import extract_psd_features, extract_stat_features, extract_cross_channel_features | |
| from smoothing import MajorityVoteSmoother, ConfidenceGate, HysteresisFilter | |
| # Pick a test file | |
| test_files = sorted(DATA_DIR.glob("*.npz")) | |
| sample_file = test_files[0] | |
| arr = np.load(str(sample_file), allow_pickle=True) | |
| label_info = arr["label"].item() | |
| gt_label = label_info["label"] | |
| eeg_raw = arr["feature_eeg"] | |
| # Load models | |
| stage1 = joblib.load(str(MODELS_DIR / "stage1_binary.pkl")) | |
| stage2 = joblib.load(str(MODELS_DIR / "stage2_direction.pkl")) | |
| # Direction map | |
| DIRECTION_TO_ACTION = {0: "FORWARD", 1: "LEFT", 2: "RIGHT"} | |
| # Preprocess | |
| eeg_filtered = bandpass_filter(eeg_raw) | |
| duration = label_info["duration"] | |
| eeg_active = extract_active_segment(eeg_filtered, duration) | |
| eeg_norm = normalize_channels(eeg_active) | |
| windows = segment_windows(eeg_norm, 500, 250) | |
| # Raw predictions | |
| raw_actions = [] | |
| raw_confidences = [] | |
| smoother = MajorityVoteSmoother(5) | |
| hysteresis = HysteresisFilter(3) | |
| gate = ConfidenceGate(0.6, 0.4) | |
| smoothed_actions = [] | |
| for w in windows: | |
| features = np.concatenate([ | |
| extract_psd_features(w), | |
| extract_stat_features(w), | |
| extract_cross_channel_features(w), | |
| ]).reshape(1, -1) | |
| s1_pred = stage1.predict(features)[0] | |
| s1_proba = stage1.predict_proba(features)[0] | |
| s1_active = float(s1_proba[1]) if len(s1_proba) > 1 else float(s1_proba[0]) | |
| s2_pred = 0 | |
| s2_proba = 0.0 | |
| if s1_pred == 1: | |
| s2_pred = int(stage2.predict(features)[0]) | |
| s2_proba = float(np.max(stage2.predict_proba(features)[0])) | |
| raw_action = gate.decide(s1_active, s1_pred, s2_proba, s2_pred, DIRECTION_TO_ACTION) | |
| raw_actions.append(raw_action) | |
| conf = s2_proba if s1_pred == 1 else (1.0 - s1_active) | |
| raw_confidences.append(conf) | |
| s1 = smoother.update(raw_action) | |
| s2 = hysteresis.update(s1) | |
| smoothed_actions.append(s2) | |
| # Plot | |
| action_to_num = {"STOP": 0, "LEFT": 1, "FORWARD": 2, "RIGHT": 3} | |
| action_labels = ["STOP", "LEFT", "FORWARD", "RIGHT"] | |
| time_axis = np.arange(len(raw_actions)) * 0.5 # 0.5s per step | |
| fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 8), sharex=True) | |
| fig.suptitle(f"Temporal Timeline: {sample_file.name} (GT: {gt_label})", fontsize=13) | |
| # Confidence | |
| ax1.plot(time_axis, raw_confidences, "g-", linewidth=1.5) | |
| ax1.set_ylabel("Confidence") | |
| ax1.set_ylim(0, 1) | |
| ax1.set_title("Confidence Score") | |
| ax1.axhline(0.6, color="red", linestyle="--", alpha=0.5, label="Stage1 threshold") | |
| ax1.axhline(0.4, color="orange", linestyle="--", alpha=0.5, label="Stage2 threshold") | |
| ax1.legend(fontsize=8) | |
| # Raw actions | |
| raw_nums = [action_to_num.get(a, 0) for a in raw_actions] | |
| ax2.step(time_axis, raw_nums, "b-", where="mid", linewidth=1.5) | |
| ax2.set_yticks([0, 1, 2, 3]) | |
| ax2.set_yticklabels(action_labels) | |
| ax2.set_title("Raw Predictions") | |
| # Smoothed actions | |
| smooth_nums = [action_to_num.get(a, 0) for a in smoothed_actions] | |
| ax3.step(time_axis, smooth_nums, "r-", where="mid", linewidth=2) | |
| ax3.set_yticks([0, 1, 2, 3]) | |
| ax3.set_yticklabels(action_labels) | |
| ax3.set_title("Smoothed Predictions (MajorityVote + Hysteresis)") | |
| ax3.set_xlabel("Time (s)") | |
| plt.tight_layout() | |
| plt.savefig(str(RESULTS_DIR / "temporal_timeline.png"), dpi=150) | |
| plt.close() | |
| print(" Saved temporal_timeline.png") | |
| def plot_channel_importance(): | |
| """Feature importance from RandomForest model.""" | |
| print("Generating channel importance plot...") | |
| # Load Stage 1 model (RandomForest) | |
| model = joblib.load(str(MODELS_DIR / "stage1_binary.pkl")) | |
| # Get the classifier from pipeline | |
| clf = model.named_steps["clf"] | |
| if not hasattr(clf, "feature_importances_"): | |
| print(" Model doesn't have feature_importances_, skipping.") | |
| return | |
| importances = clf.feature_importances_ | |
| # Feature layout: 24 PSD + 42 Stat + 3 Cross = 69 | |
| # PSD: 4 features per channel x 6 channels = 24 | |
| # Stat: 7 features per channel x 6 channels = 42 | |
| # Cross: 3 asymmetry features | |
| psd_names = ["theta", "alpha", "beta", "a/b"] | |
| stat_names = ["var", "MAV", "RMS", "peak", "kurt", "skew", "ZC"] | |
| feature_names = [] | |
| for ch in range(6): | |
| for pn in psd_names: | |
| feature_names.append(f"{CHANNEL_NAMES[ch]}_{pn}") | |
| for ch in range(6): | |
| for sn in stat_names: | |
| feature_names.append(f"{CHANNEL_NAMES[ch]}_{sn}") | |
| feature_names.extend(["Asym_AFF5-AFF6", "Asym_AFp1-AFp2", "Diff_FCz-CPz"]) | |
| # Group by channel | |
| channel_importance = {} | |
| for ch_idx, ch_name in enumerate(CHANNEL_NAMES): | |
| # PSD features: indices ch_idx*4 to ch_idx*4+3 | |
| psd_imp = importances[ch_idx*4:ch_idx*4+4].sum() | |
| # Stat features: indices 24+ch_idx*7 to 24+ch_idx*7+6 | |
| stat_imp = importances[24+ch_idx*7:24+ch_idx*7+7].sum() | |
| channel_importance[ch_name] = psd_imp + stat_imp | |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) | |
| # Per-channel total importance | |
| channels = list(channel_importance.keys()) | |
| values = list(channel_importance.values()) | |
| colors = ["#ff7f7f" if ch in ["AFF6", "AFp2"] else "#7f7fff" if ch in ["AFp1", "AFF5"] else "#7fff7f" for ch in channels] | |
| ax1.barh(channels, values, color=colors) | |
| ax1.set_title("Feature Importance by Channel (Stage 1)") | |
| ax1.set_xlabel("Total Importance") | |
| # Top 15 individual features | |
| sorted_idx = np.argsort(importances)[::-1][:15] | |
| top_names = [feature_names[i] for i in sorted_idx] | |
| top_vals = importances[sorted_idx] | |
| ax2.barh(range(len(top_names)), top_vals) | |
| ax2.set_yticks(range(len(top_names))) | |
| ax2.set_yticklabels(top_names) | |
| ax2.set_title("Top 15 Individual Features") | |
| ax2.set_xlabel("Importance") | |
| ax2.invert_yaxis() | |
| plt.tight_layout() | |
| plt.savefig(str(RESULTS_DIR / "channel_importance.png"), dpi=150) | |
| plt.close() | |
| print(" Saved channel_importance.png") | |
| def plot_smoothing_comparison(): | |
| """Side-by-side raw vs smoothed predictions for a test file.""" | |
| print("Generating smoothing comparison plot...") | |
| from preprocess import bandpass_filter, extract_active_segment, normalize_channels, segment_windows | |
| from features import extract_psd_features, extract_stat_features, extract_cross_channel_features | |
| from smoothing import MajorityVoteSmoother, HysteresisFilter, ConfidenceGate | |
| test_files = sorted(DATA_DIR.glob("*.npz")) | |
| # Use a couple different files for variety | |
| sample_files = test_files[:3] | |
| fig, axes = plt.subplots(len(sample_files), 2, figsize=(16, 4 * len(sample_files))) | |
| fig.suptitle("Smoothing Effect: Raw vs Smoothed Predictions", fontsize=14) | |
| stage1 = joblib.load(str(MODELS_DIR / "stage1_binary.pkl")) | |
| stage2 = joblib.load(str(MODELS_DIR / "stage2_direction.pkl")) | |
| DIRECTION_TO_ACTION = {0: "FORWARD", 1: "LEFT", 2: "RIGHT"} | |
| action_to_num = {"STOP": 0, "LEFT": 1, "FORWARD": 2, "RIGHT": 3} | |
| for file_idx, fpath in enumerate(sample_files): | |
| arr = np.load(str(fpath), allow_pickle=True) | |
| label_info = arr["label"].item() | |
| gt = label_info["label"] | |
| eeg = arr["feature_eeg"] | |
| eeg_filt = bandpass_filter(eeg) | |
| if np.any(np.isnan(eeg_filt)): | |
| continue | |
| eeg_active = extract_active_segment(eeg_filt, label_info["duration"]) | |
| eeg_norm = normalize_channels(eeg_active) | |
| windows = segment_windows(eeg_norm, 500, 250) | |
| raw_actions = [] | |
| smoothed_actions = [] | |
| smoother = MajorityVoteSmoother(5) | |
| hysteresis = HysteresisFilter(3) | |
| gate = ConfidenceGate(0.6, 0.4) | |
| for w in windows: | |
| features = np.concatenate([ | |
| extract_psd_features(w), | |
| extract_stat_features(w), | |
| extract_cross_channel_features(w), | |
| ]).reshape(1, -1) | |
| s1_pred = stage1.predict(features)[0] | |
| s1_proba = stage1.predict_proba(features)[0] | |
| s1_active = float(s1_proba[1]) if len(s1_proba) > 1 else float(s1_proba[0]) | |
| s2_pred = 0 | |
| s2_proba = 0.0 | |
| if s1_pred == 1: | |
| s2_pred = int(stage2.predict(features)[0]) | |
| s2_proba = float(np.max(stage2.predict_proba(features)[0])) | |
| raw = gate.decide(s1_active, s1_pred, s2_proba, s2_pred, DIRECTION_TO_ACTION) | |
| raw_actions.append(raw) | |
| s = smoother.update(raw) | |
| smoothed_actions.append(hysteresis.update(s)) | |
| t = np.arange(len(raw_actions)) * 0.5 | |
| ax_raw = axes[file_idx, 0] if len(sample_files) > 1 else axes[0] | |
| ax_smooth = axes[file_idx, 1] if len(sample_files) > 1 else axes[1] | |
| raw_nums = [action_to_num.get(a, 0) for a in raw_actions] | |
| smooth_nums = [action_to_num.get(a, 0) for a in smoothed_actions] | |
| ax_raw.step(t, raw_nums, "b-", where="mid") | |
| ax_raw.set_yticks([0, 1, 2, 3]) | |
| ax_raw.set_yticklabels(["STOP", "LEFT", "FWD", "RIGHT"]) | |
| ax_raw.set_title(f"Raw ({fpath.name}, GT: {gt})") | |
| raw_switches = sum(1 for i in range(1, len(raw_actions)) if raw_actions[i] != raw_actions[i-1]) | |
| ax_raw.text(0.02, 0.98, f"Switches: {raw_switches}", transform=ax_raw.transAxes, | |
| va="top", fontsize=9, bbox=dict(boxstyle="round", facecolor="wheat")) | |
| ax_smooth.step(t, smooth_nums, "r-", where="mid", linewidth=2) | |
| ax_smooth.set_yticks([0, 1, 2, 3]) | |
| ax_smooth.set_yticklabels(["STOP", "LEFT", "FWD", "RIGHT"]) | |
| ax_smooth.set_title(f"Smoothed ({fpath.name}, GT: {gt})") | |
| smooth_switches = sum(1 for i in range(1, len(smoothed_actions)) if smoothed_actions[i] != smoothed_actions[i-1]) | |
| ax_smooth.text(0.02, 0.98, f"Switches: {smooth_switches}", transform=ax_smooth.transAxes, | |
| va="top", fontsize=9, bbox=dict(boxstyle="round", facecolor="wheat")) | |
| plt.tight_layout() | |
| plt.savefig(str(RESULTS_DIR / "smoothing_comparison.png"), dpi=150) | |
| plt.close() | |
| print(" Saved smoothing_comparison.png") | |
| def plot_channel_layout(): | |
| """Simple head diagram showing 6 channel positions.""" | |
| print("Generating channel layout diagram...") | |
| fig, ax = plt.subplots(figsize=(8, 8)) | |
| # Draw head outline | |
| theta = np.linspace(0, 2 * np.pi, 100) | |
| ax.plot(np.cos(theta), np.sin(theta), "k-", linewidth=2) | |
| # Nose | |
| ax.plot([0, 0.1, 0], [1, 1.15, 1], "k-", linewidth=2) | |
| # Ears | |
| ax.plot([-1, -1.1, -1], [0.1, 0, -0.1], "k-", linewidth=2) | |
| ax.plot([1, 1.1, 1], [0.1, 0, -0.1], "k-", linewidth=2) | |
| # Channel positions (approximate on 10-20 system) | |
| channels = { | |
| "AFF6": (0.35, 0.75, "red", "Right anterior frontal"), | |
| "AFp2": (0.15, 0.85, "red", "Right anterior frontopolar"), | |
| "AFp1": (-0.15, 0.85, "blue", "Left anterior frontopolar"), | |
| "AFF5": (-0.35, 0.75, "blue", "Left anterior frontal"), | |
| "FCz": (0.0, 0.3, "green", "Midline frontocentral"), | |
| "CPz": (0.0, -0.1, "green", "Midline centroparietal"), | |
| } | |
| for name, (x, y, color, desc) in channels.items(): | |
| ax.plot(x, y, "o", markersize=20, color=color, zorder=5) | |
| ax.text(x, y, name, ha="center", va="center", fontsize=8, fontweight="bold", zorder=6) | |
| ax.text(x, y - 0.12, desc, ha="center", va="top", fontsize=6, color="gray") | |
| ax.set_xlim(-1.3, 1.3) | |
| ax.set_ylim(-1.3, 1.3) | |
| ax.set_aspect("equal") | |
| ax.set_title("EEG Channel Layout (6 channels)\nRed=Right, Blue=Left, Green=Midline", fontsize=12) | |
| ax.axis("off") | |
| plt.tight_layout() | |
| plt.savefig(str(RESULTS_DIR / "channel_layout.png"), dpi=150) | |
| plt.close() | |
| print(" Saved channel_layout.png") | |
| def plot_cross_subject_accuracy(): | |
| """Leave-one-subject-out accuracy for Stage 1.""" | |
| print("Generating cross-subject accuracy plot...") | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.metrics import accuracy_score | |
| data = np.load(str(PROJECT_ROOT / "features.npz"), allow_pickle=True) | |
| X = data["X"] | |
| y_str = data["y"] | |
| subjects = data["subjects"] | |
| y_binary = np.array([0 if s == "Relax" else 1 for s in y_str]) | |
| unique_subjects = sorted(set(subjects)) | |
| subject_accuracies = {} | |
| for test_subj in unique_subjects: | |
| train_mask = subjects != test_subj | |
| test_mask = subjects == test_subj | |
| model = Pipeline([ | |
| ("scaler", StandardScaler()), | |
| ("clf", RandomForestClassifier(n_estimators=100, class_weight="balanced", random_state=42)) | |
| ]) | |
| model.fit(X[train_mask], y_binary[train_mask]) | |
| y_pred = model.predict(X[test_mask]) | |
| acc = accuracy_score(y_binary[test_mask], y_pred) | |
| subject_accuracies[test_subj] = acc | |
| fig, ax = plt.subplots(figsize=(10, 5)) | |
| subjects_list = list(subject_accuracies.keys()) | |
| accs = list(subject_accuracies.values()) | |
| bars = ax.bar(range(len(subjects_list)), accs, color="steelblue") | |
| ax.set_xticks(range(len(subjects_list))) | |
| ax.set_xticklabels([s[:8] for s in subjects_list], rotation=45) | |
| ax.set_ylabel("Accuracy") | |
| ax.set_title(f"Leave-One-Subject-Out Accuracy (Stage 1 Binary)\nMean: {np.mean(accs):.3f}") | |
| ax.axhline(np.mean(accs), color="red", linestyle="--", label=f"Mean: {np.mean(accs):.3f}") | |
| ax.axhline(0.5, color="gray", linestyle=":", label="Random baseline") | |
| ax.legend() | |
| ax.set_ylim(0, 1) | |
| for bar, acc in zip(bars, accs): | |
| ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, | |
| f"{acc:.2f}", ha="center", fontsize=9) | |
| plt.tight_layout() | |
| plt.savefig(str(RESULTS_DIR / "cross_subject_accuracy.png"), dpi=150) | |
| plt.close() | |
| print(" Saved cross_subject_accuracy.png") | |
| def generate_all(): | |
| """Generate all visualization plots.""" | |
| RESULTS_DIR.mkdir(parents=True, exist_ok=True) | |
| plot_psd_per_class() | |
| plot_tsne() | |
| plot_temporal_timeline() | |
| plot_channel_importance() | |
| plot_smoothing_comparison() | |
| plot_channel_layout() | |
| plot_cross_subject_accuracy() | |
| print("\nAll visualizations generated in results/") | |
| if __name__ == "__main__": | |
| generate_all() | |