File size: 11,923 Bytes
f2eef96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import os, json
import numpy as np
import pandas as pd
from scipy.signal import welch
from scipy.stats import skew, kurtosis
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, roc_auc_score
import shap
import matplotlib.pyplot as plt
import seaborn as sns
from joblib import dump  # ใช้สำหรับบันทึก model


# ======================== DATA LOADING ========================
def load_tremor_data(base_path, folders, min_records=5):
    """Load tremor recordings from JSON files into one long-format DataFrame.

    Two JSON layouts are accepted: the recording object may be stored at
    ``data["recording"]`` or nested at ``data["data"]["recording"]``.

    Parameters:
    - base_path: directory that contains one sub-directory per entry of
      ``folders``.
    - folders: mapping of sub-directory name -> label written to every row
      loaded from that sub-directory.
    - min_records: minimum number of ``recordedData`` entries a file needs to
      be kept (default 5).

    Returns:
    - DataFrame with one row per recorded sample: the columns named by the
      file's ``recordingFormat`` plus ``ts``, ``label`` and ``file``.
      An empty DataFrame is returned when nothing could be loaded.

    Unreadable folders, unreadable or malformed JSON files, and files without
    a usable recording are reported on stdout and skipped.
    """
    all_data = []

    for folder, label in folders.items():
        folder_path = os.path.join(base_path, folder)
        print(f"📂 Loading folder: {folder_path}")

        try:
            # Sorted so the row order does not depend on the order in which
            # the filesystem happens to list the directory.
            file_names = sorted(os.listdir(folder_path))
        except OSError as e:
            # A missing/unreadable class folder is skipped like a bad file.
            print(f"❌ Error reading folder {folder_path}: {e}")
            continue

        for file_name in file_names:
            if not file_name.endswith(".json"):
                continue

            file_path = os.path.join(folder_path, file_name)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"❌ Error reading {file_name}: {e}")
                continue

            # Locate the recording object in either supported layout; guard
            # against top-level or nested values that are not JSON objects.
            rec = None
            if isinstance(data, dict):
                if "recording" in data:
                    rec = data["recording"]
                elif isinstance(data.get("data"), dict) and "recording" in data["data"]:
                    rec = data["data"]["recording"]
            if not isinstance(rec, dict):
                print(f"⚠️ Skip: {file_name} (no 'recording' field found)")
                continue

            records = rec.get("recordedData", [])
            fmt = rec.get("recordingFormat", [])

            if not records or not fmt or len(records) < min_records:
                print(f"⚠️ Skip empty or too short: {file_name}")
                continue

            try:
                df = pd.DataFrame([r["data"] for r in records], columns=fmt)
                df["ts"] = [r.get("ts", None) for r in records]
                df["label"] = label
                df["file"] = file_name
                all_data.append(df)
            except Exception as e:
                # Deliberately broad: malformed records can fail in several
                # ways (missing key, wrong type, column-count mismatch).
                print(f"⚠️ Parse error {file_name}: {e}")
                continue

    if not all_data:
        print("❌ No valid files found.")
        return pd.DataFrame()

    df_all = pd.concat(all_data, ignore_index=True)
    print(f"✅ Loaded total rows: {len(df_all)}, files: {len(all_data)}")
    return df_all


# ======================== FEATURE EXTRACTION ========================
def compute_rms(x):
    """Return the root-mean-square of ``x`` (square root of the mean of squares)."""
    mean_square = np.mean(x ** 2)
    return np.sqrt(mean_square)
def compute_sma(x, y, z):
    """Return the signal magnitude area: the mean of ``|x| + |y| + |z|``."""
    abs_total = np.abs(x) + np.abs(y) + np.abs(z)
    return np.mean(abs_total)
def compute_vector_mag(x, y, z):
    """Return the element-wise Euclidean magnitude of the ``(x, y, z)`` components."""
    squared_sum = x ** 2 + y ** 2 + z ** 2
    return np.sqrt(squared_sum)
def compute_entropy(signal, bins=30):
    """Return an entropy-style summary of the histogram of ``signal``.

    Parameters:
    - signal: samples to histogram.
    - bins: ``bins`` argument forwarded to ``np.histogram`` (default 30).

    Returns:
    - ``-sum(h * log(h))`` over the non-empty bins, where ``h`` are the
      histogram values.

    NOTE(review): the histogram is built with ``density=True``, so ``h`` holds
    density values (scaled by bin width), not probabilities that sum to 1.
    """
    density = np.histogram(signal, bins=bins, density=True)[0]
    # Empty bins are dropped so that log() is never evaluated at zero.
    occupied = density[density > 0]
    return -np.sum(occupied * np.log(occupied))


def compute_freq_features(signal, fs=50):
    """Compute frequency-domain features from the Welch PSD of a 1-D signal.

    Parameters:
    - signal: 1-D sequence of samples.
    - fs: sampling frequency forwarded to ``scipy.signal.welch`` (default 50).

    Returns:
    - dict with
      ``dom_freq``: frequency of the largest PSD bin,
      ``band_power_4_6``: PSD integrated over 4 <= f <= 6 (trapezoidal rule),
      ``spec_entropy``: entropy of the PSD normalised to sum to 1.
      All three are 0 for an empty signal. ``spec_entropy`` is 0 when the PSD
      carries no power at all (e.g. a flat signal), instead of NaN from 0/0.
    """
    signal = np.asarray(signal)
    if signal.size == 0:
        # Nothing to analyse; avoid calling welch with nperseg=0.
        return {"dom_freq": 0, "band_power_4_6": 0, "spec_entropy": 0}

    f, Pxx = welch(signal, fs=fs, nperseg=min(256, len(signal)))
    if len(Pxx) == 0:
        return {"dom_freq": 0, "band_power_4_6": 0, "spec_entropy": 0}

    dom_freq = f[np.argmax(Pxx)]

    band_mask = (f >= 4) & (f <= 6)
    # np.trapz was renamed to np.trapezoid in NumPy 2; support both.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    band_power = integrate(Pxx[band_mask], f[band_mask])

    total_power = np.sum(Pxx)
    if total_power == 0:
        # A zero PSD would give 0/0 = NaN below; define its entropy as 0.
        spec_entropy = 0.0
    else:
        Pxx_norm = Pxx / total_power
        # The small epsilon keeps log() finite for bins with zero power.
        spec_entropy = -np.sum(Pxx_norm * np.log(Pxx_norm + 1e-12))

    return {"dom_freq": dom_freq, "band_power_4_6": band_power, "spec_entropy": spec_entropy}


def extract_essential_features(df, fs=50):
    """Summarise one recording into a flat dict of features.

    Parameters:
    - df: DataFrame holding the columns ``ax``, ``ay``, ``az``, ``gx``, ``gy``,
      ``gz``, ``label`` and ``file``.
    - fs: sampling frequency forwarded to ``compute_freq_features``.

    Returns:
    - dict with, for every sensor column, the time-domain statistics
      (rms, mean, std, skew, kurtosis, entropy) followed by the keys returned
      by ``compute_freq_features``; then ``acc_sma``, ``gyro_sma``,
      ``acc_gyro_corr``, and finally ``label`` and ``file`` taken from the
      first row of ``df``.
    """
    sensor_columns = ("ax", "ay", "az", "gx", "gy", "gz")
    # (suffix, function) pairs, applied in this order so the resulting key
    # order (and therefore the feature-column order) stays stable.
    time_stats = (
        ("rms", compute_rms),
        ("mean", np.mean),
        ("std", np.std),
        ("skew", skew),
        ("kurtosis", kurtosis),
        ("entropy", compute_entropy),
    )

    row = {}
    for name in sensor_columns:
        samples = df[name].values
        for suffix, stat in time_stats:
            row[f"{name}_{suffix}"] = stat(samples)
        spectral = compute_freq_features(samples, fs)
        row.update({f"{name}_{key}": value for key, value in spectral.items()})

    acc_axes = (df["ax"], df["ay"], df["az"])
    gyro_axes = (df["gx"], df["gy"], df["gz"])
    row["acc_sma"] = compute_sma(*acc_axes)
    row["gyro_sma"] = compute_sma(*gyro_axes)

    # Correlation between the two vector-magnitude series.
    corr_matrix = np.corrcoef(
        compute_vector_mag(*acc_axes),
        compute_vector_mag(*gyro_axes),
    )
    row["acc_gyro_corr"] = corr_matrix[0, 1]

    row["label"] = df["label"].iloc[0]
    row["file"] = df["file"].iloc[0]
    return row


def create_feature_dataset(df_all, fs=50):
    """Build the per-file feature table from the long-format sample table.

    Parameters:
    - df_all: DataFrame of samples with a ``file`` column identifying the
      recording each row belongs to.
    - fs: sampling frequency forwarded to ``extract_essential_features``.

    Returns:
    - DataFrame with one row of features per distinct ``file`` value; an empty
      DataFrame when ``df_all`` is empty.
    """
    # An empty frame may have no columns at all, in which case
    # groupby("file") would raise KeyError.
    if df_all.empty:
        return pd.DataFrame()

    features = [extract_essential_features(g, fs) for _, g in df_all.groupby("file")]
    return pd.DataFrame(features)

# ======================== VISUALIZATION FUNCTIONS ========================
def plot_pca_clustering(df_features, X_scaled, model):
    """Plot the samples in 2-D PCA space, coloured by label and styled by prediction.

    Parameters:
    - df_features: feature DataFrame; must contain a ``label`` column and have
      as many rows as ``X_scaled``.
    - X_scaled: scaled feature matrix.
    - model: trained model exposing ``predict``.

    Returns:
    - pca: the fitted 2-component PCA object.
    - df_plot: copy of ``df_features`` with ``pca1``, ``pca2`` and ``pred``
      columns added.
    """
    pca = PCA(n_components=2)
    projected = pca.fit_transform(X_scaled)

    # Work on a copy so the caller's DataFrame is left untouched.
    df_plot = df_features.assign(
        pca1=projected[:, 0],
        pca2=projected[:, 1],
        pred=model.predict(X_scaled),
    )

    label_colors = {"normal": "#4CAF50", "pd": "#E91E63"}

    plt.figure(figsize=(8, 6))
    sns.scatterplot(
        data=df_plot,
        x="pca1",
        y="pca2",
        hue="label",
        style="pred",
        palette=label_colors,
        s=90,
        alpha=0.9,
    )
    plt.title("🧩 PCA Clustering Visualization (PD vs Normal)", fontsize=14)
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.legend(title="Label / Prediction")
    plt.show()

    return pca, df_plot

def plot_pca_biplot(df_features, X_scaled, X, pca=None):
    """Plot a PCA biplot: projected samples plus feature loading vectors.

    Parameters:
    - df_features: feature DataFrame; must contain a ``label`` column and have
      as many rows as ``X_scaled``.
    - X_scaled: scaled feature matrix.
    - X: original feature DataFrame; its column names label the loadings.
    - pca: optional, already fitted 2-component PCA object. When omitted a new
      2-component PCA is fitted on ``X_scaled``.

    Returns:
    - loadings: DataFrame of loading vectors (index = feature names,
      columns ``PCA1`` and ``PCA2``).
    - df_plot: copy of ``df_features`` with ``pca1`` and ``pca2`` columns added.
    """
    if pca is None:
        pca = PCA(n_components=2)
        X_pca = pca.fit_transform(X_scaled)
    else:
        X_pca = pca.transform(X_scaled)

    # Build the DataFrame used for plotting.
    df_plot = df_features.copy()
    df_plot["pca1"] = X_pca[:, 0]
    df_plot["pca2"] = X_pca[:, 1]

    loadings = pd.DataFrame(
        pca.components_.T,
        columns=['PCA1', 'PCA2'],
        index=X.columns
    )

    # Show the features with the largest loadings on PCA1 and PCA2.
    # NOTE(review): ranking is by signed value, so strongly negative loadings
    # are not listed — confirm this is intended.
    print("\n📊 Top 10 features influencing PCA1:")
    print(loadings['PCA1'].sort_values(ascending=False).head(10))
    print("\n📊 Top 10 features influencing PCA2:")
    print(loadings['PCA2'].sort_values(ascending=False).head(10))

    # Plot the samples, then overlay the loading vectors (biplot).
    plt.figure(figsize=(10, 8))
    sns.scatterplot(
        data=df_plot,
        x="pca1", y="pca2",
        hue="label",
        palette={"normal": "#4CAF50", "pd": "#E91E63"},
        s=80, alpha=0.9
    )

    # Iterate over rows positionally: the index holds feature names, so
    # integer keys such as ``loadings.PCA1[i]`` rely on pandas' deprecated
    # positional fallback.
    for feature, (load1, load2) in zip(loadings.index, loadings[['PCA1', 'PCA2']].to_numpy()):
        plt.arrow(0, 0, load1 * 10, load2 * 10,
                  color='gray', alpha=0.5, head_width=0.3)
        plt.text(load1 * 11, load2 * 11,
                 feature, fontsize=8, color='black')

    plt.title("📈 PCA Biplot: Feature Loading Direction", fontsize=13)
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.grid(True, alpha=0.3)
    plt.show()

    return loadings, df_plot

def plot_roc_curve(y_true, y_proba, model_name="Random Forest"):
    """Plot the ROC curve and report its AUC.

    Parameters:
    - y_true: true target values.
    - y_proba: predicted scores/probabilities passed to ``roc_curve`` and
      ``roc_auc_score``.
    - model_name: model name shown in the plot title.

    Returns:
    - roc_auc: ROC AUC score.
    - fpr: false positive rates.
    - tpr: true positive rates.
    """
    # The thresholds returned by roc_curve are not needed for the plot.
    fpr, tpr, _ = roc_curve(y_true, y_proba)
    roc_auc = roc_auc_score(y_true, y_proba)

    curve_label = f"ROC curve (AUC = {roc_auc:.2f})"
    plot_title = f"🧩 ROC Curve – {model_name} (PD vs Normal)"

    plt.figure(figsize=(6, 6))
    plt.plot(fpr, tpr, color="#E91E63", lw=2, label=curve_label)
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], color="gray", linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(plot_title)
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.show()

    return roc_auc, fpr, tpr

def plot_shap_analysis(model, X_scaled, X, plot_type="both"):
    """Run SHAP TreeExplainer on the model and draw summary plots for class 1.

    Parameters:
    - model: trained tree-based model.
    - X_scaled: scaled feature matrix the SHAP values are computed on.
    - X: original feature data passed to ``shap.summary_plot`` as the
      feature values/names.
    - plot_type: which plots to draw: ``"bar"``, ``"beeswarm"`` or ``"both"``.
      Any other value draws nothing.

    Returns:
    - explainer: the SHAP explainer.
    - shap_values: SHAP values exactly as returned by the explainer.
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_scaled)

    # Select the values for class index 1. Older SHAP releases return a list
    # with one (samples, features) array per class; newer releases return a
    # single (samples, features, classes) array, where ``shap_values[1]``
    # would pick sample 1 instead of class 1.
    if isinstance(shap_values, list):
        class1_values = shap_values[1]
    else:
        stacked = np.asarray(shap_values)
        class1_values = stacked[:, :, 1] if stacked.ndim == 3 else stacked

    if plot_type in ["bar", "both"]:
        shap.summary_plot(class1_values, X, plot_type="bar", show=False)
        plt.title("SHAP Feature Importance (Bar Plot)")
        plt.tight_layout()
        plt.show()

    if plot_type in ["beeswarm", "both"]:
        shap.summary_plot(class1_values, X, show=False)
        plt.title("SHAP Feature Importance (Beeswarm Plot)")
        plt.tight_layout()
        plt.show()

    return explainer, shap_values


# ======================== MODEL TRAINING ========================
def train_random_forest(X, y, n_estimators=300, max_depth=6, random_state=42):
    """Train a RandomForest classifier after dropping rows with missing values.

    Parameters:
    - X: feature matrix (anything accepted by ``pd.DataFrame``).
    - y: labels, one per row of ``X``.
    - n_estimators, max_depth, random_state: forwarded to
      ``RandomForestClassifier``.

    Returns:
    - model: the fitted classifier.
    - scaler: the ``StandardScaler`` fitted on the kept rows.
    - X_scaled: scaled features of the kept rows only. Rows with a missing
      label or any missing feature are removed, and the matching labels are
      not returned, so this can be shorter than the ``X``/``y`` passed in.
    """
    # Join features and labels so incomplete rows are dropped together.
    # A single dropna over all columns also covers a missing label.
    complete = pd.DataFrame(X).assign(label=y).dropna(axis=0, how="any")

    y_clean = complete.pop("label").values
    X_clean = complete.values

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_clean)

    forest_params = {
        "n_estimators": n_estimators,
        "max_depth": max_depth,
        "random_state": random_state,
    }
    model = RandomForestClassifier(**forest_params)
    model.fit(X_scaled, y_clean)

    print(f"✅ Training complete ({len(y_clean)} samples used)")
    return model, scaler, X_scaled


def evaluate_model(model, X_scaled, y_true):
    """Print a confusion matrix and classification report for the model.

    Parameters:
    - model: fitted classifier exposing ``predict`` and ``predict_proba``.
    - X_scaled: scaled feature matrix to predict on.
    - y_true: true labels for the rows of ``X_scaled``.

    Returns:
    - y_pred: predicted labels.
    - y_proba: column 1 of ``predict_proba``.

    The report names the classes "Normal" and "PD", so two classes are
    expected.
    """
    y_pred = model.predict(X_scaled)
    class_probabilities = model.predict_proba(X_scaled)
    y_proba = class_probabilities[:, 1]

    matrix = confusion_matrix(y_true, y_pred)
    print("\nConfusion Matrix:", matrix, sep="\n")

    report = classification_report(y_true, y_pred, target_names=["Normal", "PD"])
    print("\nClassification Report:", report, sep="\n")

    return y_pred, y_proba


# ======================== SAVE MODEL ========================
def save_rf_model(model, scaler, feature_names, base_path):
    """Persist the model, scaler and feature names in one joblib file.

    Parameters:
    - model: trained model to store under the ``"model"`` key.
    - scaler: fitted scaler to store under the ``"scaler"`` key.
    - feature_names: feature names to store under the ``"features"`` key.
    - base_path: directory in which ``tremor_rf_model.joblib`` is written.

    Returns:
    - save_path: full path of the written file.
    """
    save_path = os.path.join(base_path, "tremor_rf_model.joblib")
    bundle = {"model": model, "scaler": scaler, "features": feature_names}
    dump(bundle, save_path)
    print(f"💾 Model saved to {save_path}")
    return save_path