File size: 7,758 Bytes
832948a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""Temporal Embedding: Visualize how brain activity evolves within a recording.

Takes one representative file per intent type, computes features for each
window, projects all windows to 2D with a jointly fitted PCA, and draws each
recording's trajectory with arrows colored by intent and points marked by
phase (rest -> initiation -> sustained).
Saves results/intent_evolution.png."""
import numpy as np
from pathlib import Path
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

from preprocess import bandpass_filter, extract_active_segment, normalize_channels, segment_windows
from features import extract_psd_features, extract_stat_features, extract_cross_channel_features

# Two levels above this script's resolved location (parent of the script's directory).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Output directory for the saved figure.
RESULTS_DIR = PROJECT_ROOT / "results"

# One representative file per intent type (filenames are resolved under
# PROJECT_ROOT / "data"). Dict order is significant: it fixes the order in which
# recordings are processed, stacked for the joint PCA, and drawn.
INTENT_FILES = {
    "Left Fist":       "2562e7bd-14.npz",
    "Right Fist":      "0b2dbd41-34.npz",
    "Both Fists":      "4787dfb9-10.npz",
    "Tongue Tapping":  "0b2dbd41-16.npz",
    "Relax":           "2161ecb6-12.npz",
}

# Plot color per intent: used for the trajectory line, arrows, marker edges,
# sustained-phase marker fill, start/end labels, and the legend entry.
INTENT_COLORS = {
    "Left Fist":       "#4285f4",
    "Right Fist":      "#ea4335",
    "Both Fists":      "#34a853",
    "Tongue Tapping":  "#fbbc04",
    "Relax":           "#9aa0a6",
}

# Action name shown in the start annotation and legend for each intent.
# NOTE(review): presumably the control command each intent maps to downstream — confirm.
INTENT_ACTIONS = {
    "Left Fist":       "LEFT",
    "Right Fist":      "RIGHT",
    "Both Fists":      "FORWARD",
    "Tongue Tapping":  "BACKWARD",
    "Relax":           "STOP",
}

# Matplotlib marker per phase. Keys must match the phase names produced by
# assign_phases() ("rest", "initiation", "sustained").
PHASE_MARKERS = {
    "rest":        "o",
    "initiation":  "^",
    "sustained":   "s",
}


def extract_windows_and_features(npz_path):
    """Extract per-window feature vectors from a single ``.npz`` recording.

    Pipeline: load ``feature_eeg`` and the ``label`` info dict, band-pass
    filter, keep the active segment (using ``info["duration"]``), normalize
    channels, split into windows via ``segment_windows(normed, 500, 250)``
    (presumably window length 500 and step 250 samples — confirm against
    ``preprocess``), then concatenate PSD, statistical, and cross-channel
    features for each window.

    Args:
        npz_path: Path (or str) to the ``.npz`` recording.

    Returns:
        ``(features, info)``: ``features`` is ``np.array`` of the per-window
        feature vectors (one row per window; an empty array if no windows were
        produced) and ``info`` is the dict stored under ``label``.
        ``(None, None)`` if the filtered signal contains NaN or inf values.
    """
    # NOTE(review): allow_pickle=True is required to read the object-typed
    # "label" entry, but unpickling can execute code — only load trusted files.
    # The context manager closes the underlying file handle; both arrays are
    # fully materialized by indexing/.item() before the file is closed.
    with np.load(str(npz_path), allow_pickle=True) as arr:
        eeg = arr["feature_eeg"]
        info = arr["label"].item()

    filtered = bandpass_filter(eeg)
    # Reject recordings whose filtered signal is not entirely finite (NaN/inf).
    if not np.all(np.isfinite(filtered)):
        return None, None

    active = extract_active_segment(filtered, info["duration"])
    normed = normalize_channels(active)
    windows = segment_windows(normed, 500, 250)

    feats = [
        np.concatenate([
            extract_psd_features(w),
            extract_stat_features(w),
            extract_cross_channel_features(w),
        ])
        for w in windows
    ]
    return np.array(feats), info


def assign_phases(n_windows, label, *, n_rest=2):
    """Assign a phase name to each window of a recording, in order.

    "Relax" recordings are "rest" throughout. For every other intent, the
    first ``n_rest`` windows are "rest", the next window is "initiation", and
    all remaining windows are "sustained". The default ``n_rest=2`` reproduces
    the original fixed split (windows 0-1 rest, window 2 initiation).

    Args:
        n_windows: Number of windows in the recording (<= 0 yields ``[]``).
        label: Intent label, e.g. a key of INTENT_FILES.
        n_rest: Number of leading windows treated as rest for non-Relax intents.

    Returns:
        List of ``n_windows`` phase names, each a key of PHASE_MARKERS.
    """
    if label == "Relax":
        return ["rest"] * n_windows

    def phase_of(i):
        if i < n_rest:
            return "rest"
        return "initiation" if i == n_rest else "sustained"

    return [phase_of(i) for i in range(n_windows)]


def _load_intent_data(data_dir):
    """Featurize the representative recording for each intent in INTENT_FILES.

    Args:
        data_dir: Directory containing the ``.npz`` recordings.

    Returns:
        Dict mapping intent label -> {"features", "phases", "info"}, in
        INTENT_FILES order. Recordings that are missing, fail the NaN/inf check
        in extract_windows_and_features, or produce no windows are reported
        and left out.
    """
    intent_data = {}
    for label, fname in INTENT_FILES.items():
        fpath = data_dir / fname
        if not fpath.is_file():
            print(f"  Skipping {fname} (file not found)")
            continue
        feats, info = extract_windows_and_features(fpath)
        if feats is None:
            print(f"  Skipping {fname} (filter error)")
            continue
        if len(feats) == 0:
            # An empty recording would break np.vstack and the start/end labels.
            print(f"  Skipping {fname} (no windows)")
            continue

        phases = assign_phases(len(feats), label)
        intent_data[label] = {"features": feats, "phases": phases, "info": info}
        print(f"  {label:17s}: {len(feats)} windows from {fname}")
    return intent_data


def _fit_joint_pca(intent_data):
    """Project every window of every intent into one shared 2-D PCA space.

    Features are standardized and the PCA is fit on all windows together so
    trajectories of different intents are directly comparable. Each entry of
    ``intent_data`` gains a ``"pca"`` array with its projected points.

    Returns:
        The fitted ``PCA`` (its explained-variance ratios label the axes).
    """
    X_all = np.vstack([data["features"] for data in intent_data.values()])
    X_scaled = StandardScaler().fit_transform(X_all)
    pca = PCA(n_components=2)
    X_2d = pca.fit_transform(X_scaled)

    # Rows were stacked in intent_data order; slice them back the same way.
    start = 0
    for data in intent_data.values():
        stop = start + len(data["features"])
        data["pca"] = X_2d[start:stop]
        start = stop
    return pca


def _plot_trajectory(ax, label, data):
    """Draw one intent's trajectory: faint line, step arrows, phase markers, start/end labels."""
    pts = data["pca"]
    phases = np.asarray(data["phases"])
    color = INTENT_COLORS[label]
    action = INTENT_ACTIONS[label]

    # Faint connecting line underneath the arrows.
    ax.plot(pts[:, 0], pts[:, 1], "-", color=color, alpha=0.3, linewidth=1.5)

    # One arrow per consecutive pair of windows to show temporal direction.
    for p_from, p_to in zip(pts[:-1], pts[1:]):
        ax.annotate("", xy=(p_to[0], p_to[1]),
                    xytext=(p_from[0], p_from[1]),
                    arrowprops=dict(arrowstyle="->", color=color, alpha=0.5, lw=1.2))

    # Marker shape encodes phase; rest is grey, initiation white and enlarged,
    # sustained filled with the intent color. Edges always use the intent color.
    face_colors = {"rest": "#666666", "initiation": "white", "sustained": color}
    for phase, marker in PHASE_MARKERS.items():
        mask = phases == phase
        if not mask.any():
            continue
        ax.scatter(pts[mask, 0], pts[mask, 1], marker=marker,
                   s=80 if phase == "initiation" else 40, c=face_colors[phase],
                   edgecolors=color, linewidths=1.5, zorder=5)

    # Label start and end of the trajectory.
    ax.annotate(f"{action}\nstart", xy=(pts[0, 0], pts[0, 1]),
                fontsize=8, color=color, fontweight="bold",
                ha="center", va="bottom",
                xytext=(0, 8), textcoords="offset points")
    ax.annotate("end", xy=(pts[-1, 0], pts[-1, 1]),
                fontsize=8, color=color,
                ha="center", va="top",
                xytext=(0, -8), textcoords="offset points")


def _decorate_axes(ax, pca, labels):
    """Add legend entries, PCA axis labels, title, and dark-theme styling.

    Args:
        ax: Target axes.
        pca: Fitted PCA whose explained-variance ratios label the axes.
        labels: Intent labels actually plotted; only these get legend entries.
    """
    # Proxy artists: empty plots/scatters that exist only to populate the legend.
    for label in labels:
        ax.plot([], [], "o-", color=INTENT_COLORS[label], markersize=6,
                label=f"{label} -> {INTENT_ACTIONS[label]}")
    ax.scatter([], [], marker=PHASE_MARKERS["rest"], s=40, c="#666666",
               edgecolors="white", label="Rest phase")
    ax.scatter([], [], marker=PHASE_MARKERS["initiation"], s=80, c="white",
               edgecolors="white", label="Initiation")
    ax.scatter([], [], marker=PHASE_MARKERS["sustained"], s=40, c="#aaaaaa",
               edgecolors="white", label="Sustained")

    ax.legend(loc="upper left", fontsize=10, facecolor="#16213e",
              edgecolor="#444", labelcolor="white", ncol=2)

    ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)",
                  fontsize=12, color="white")
    ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)",
                  fontsize=12, color="white")
    ax.set_title("Intent Evolution: How Brain Activity Changes Within a Recording",
                 fontsize=14, fontweight="bold", color="white", pad=15)

    ax.tick_params(colors="white")
    for spine in ax.spines.values():
        spine.set_color("#444")


def main():
    """Build the intent-evolution figure and save it to RESULTS_DIR/intent_evolution.png.

    Exits early with a message (without plotting) if fewer than two windows
    survive loading, since a 2-component PCA cannot be fit on less.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    data_dir = PROJECT_ROOT / "data"

    print("=" * 70)
    print("  Temporal Embedding: Intent Evolution Visualization")
    print("=" * 70)

    intent_data = _load_intent_data(data_dir)
    n_windows = sum(len(data["features"]) for data in intent_data.values())
    if n_windows < 2:
        print("\n  Not enough windows to fit PCA; nothing to plot.")
        return

    pca = _fit_joint_pca(intent_data)
    print(f"\n  PCA explained variance: {pca.explained_variance_ratio_[0]:.1%}, {pca.explained_variance_ratio_[1]:.1%}")

    fig, ax = plt.subplots(figsize=(12, 9))
    fig.patch.set_facecolor("#1a1a2e")
    ax.set_facecolor("#16213e")

    for label, data in intent_data.items():
        _plot_trajectory(ax, label, data)

    _decorate_axes(ax, pca, list(intent_data))

    # Subtitle
    fig.text(0.5, 0.01,
             "Each trajectory shows one EEG recording projected to feature space. "
             "Arrows show temporal progression. Phases: rest (circle) -> initiation (triangle) -> sustained (square).",
             ha="center", fontsize=9, color="#aaaaaa")

    # Reserve the bottom strip for the subtitle.
    fig.tight_layout(rect=[0, 0.03, 1, 1])
    out_path = str(RESULTS_DIR / "intent_evolution.png")
    fig.savefig(out_path, dpi=150, facecolor=fig.get_facecolor(), bbox_inches="tight")
    plt.close(fig)
    print(f"\n  Saved to {out_path}")


if __name__ == "__main__":
    main()