# tremor-post-pd-api / tremor_analysis_functions.py
# Tremor feature extraction and Random Forest classification utilities.
import os, json
import numpy as np
import pandas as pd
from scipy.signal import welch
from scipy.stats import skew, kurtosis
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, roc_auc_score
import shap
import matplotlib.pyplot as plt
import seaborn as sns
from joblib import dump  # used to save the trained model bundle
# ======================== DATA LOADING ========================
def load_tremor_data(base_path, folders):
    """Load tremor recordings from JSON files (both the old and new formats).

    Parameters:
    - base_path: root directory containing the per-class folders
    - folders: dict mapping folder name -> class label

    Returns:
    - DataFrame with one row per sample (columns from 'recordingFormat',
      plus 'ts', 'label', 'file'); empty DataFrame if no valid files found.
    """
    all_data = []
    for folder, label in folders.items():
        folder_path = os.path.join(base_path, folder)
        # Robustness fix: a missing folder used to raise in os.listdir and
        # abort the whole load; now it is skipped with a warning.
        if not os.path.isdir(folder_path):
            print(f"⚠️ Skip missing folder: {folder_path}")
            continue
        print(f"📂 Loading folder: {folder_path}")
        for file_name in os.listdir(folder_path):
            if not file_name.endswith(".json"):
                continue
            file_path = os.path.join(folder_path, file_name)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except Exception as e:
                print(f"❌ Error reading {file_name}: {e}")
                continue
            # 'recording' may be at the top level (old format) or nested
            # under 'data' (new format).
            if "recording" in data:
                rec = data["recording"]
            elif "data" in data and "recording" in data["data"]:
                rec = data["data"]["recording"]
            else:
                print(f"⚠️ Skip: {file_name} (no 'recording' field found)")
                continue
            records = rec.get("recordedData", [])
            fmt = rec.get("recordingFormat", [])
            # Require at least 5 samples so downstream stats are meaningful.
            if not records or not fmt or len(records) < 5:
                print(f"⚠️ Skip empty or too short: {file_name}")
                continue
            try:
                df = pd.DataFrame([r["data"] for r in records], columns=fmt)
                df["ts"] = [r.get("ts", None) for r in records]
                df["label"] = label
                df["file"] = file_name
                all_data.append(df)
            except Exception as e:
                print(f"⚠️ Parse error {file_name}: {e}")
                continue
    if not all_data:
        print("❌ No valid files found.")
        return pd.DataFrame()
    df_all = pd.concat(all_data, ignore_index=True)
    print(f"✅ Loaded total rows: {len(df_all)}, files: {len(all_data)}")
    return df_all
# ======================== FEATURE EXTRACTION ========================
def compute_rms(x):
    """Root-mean-square of a signal.

    Generalized to accept any array-like (list, Series, ndarray) by
    converting through np.asarray before squaring.
    """
    arr = np.asarray(x)
    return np.sqrt(np.mean(arr ** 2))
def compute_sma(x, y, z):
    """Signal magnitude area: mean of the summed absolute values of three axes."""
    abs_sum = np.abs(x) + np.abs(y) + np.abs(z)
    return np.mean(abs_sum)
def compute_vector_mag(x, y, z):
    """Element-wise Euclidean magnitude of a 3-axis signal."""
    return np.sqrt(x * x + y * y + z * z)
def compute_entropy(signal, bins=30):
    """Shannon entropy of the signal's density-normalized histogram.

    Empty (zero-density) bins are dropped before taking the log.
    """
    density, _ = np.histogram(signal, bins=bins, density=True)
    nonzero = density[density > 0]
    return -np.sum(nonzero * np.log(nonzero))
def compute_freq_features(signal, fs=50):
    """Frequency-domain features from a Welch power spectral density.

    Parameters:
    - signal: 1-D array-like time series
    - fs: sampling rate in Hz (default 50)

    Returns a dict with:
    - dom_freq: frequency of the PSD peak
    - band_power_4_6: integrated power in the 4-6 Hz band
    - spec_entropy: spectral entropy of the normalized PSD
    """
    f, Pxx = welch(signal, fs=fs, nperseg=min(256, len(signal)))
    if len(Pxx) == 0:
        return {"dom_freq": 0, "band_power_4_6": 0, "spec_entropy": 0}
    dom_freq = f[np.argmax(Pxx)]
    band_mask = (f >= 4) & (f <= 6)
    # np.trapz was renamed to np.trapezoid in NumPy 2.0 (and later removed);
    # fall back for older NumPy versions.
    _trapezoid = getattr(np, "trapezoid", None) or np.trapz
    band_power = _trapezoid(Pxx[band_mask], f[band_mask])
    Pxx_norm = Pxx / np.sum(Pxx)
    # +1e-12 guards log(0) for zero-power frequency bins.
    spec_entropy = -np.sum(Pxx_norm * np.log(Pxx_norm + 1e-12))
    return {"dom_freq": dom_freq, "band_power_4_6": band_power, "spec_entropy": spec_entropy}
def extract_essential_features(df, fs=50):
    """Compute time- and frequency-domain features for one recording.

    Expects df to have columns ax/ay/az/gx/gy/gz plus 'label' and 'file';
    returns a flat dict of per-axis statistics, cross-axis aggregates,
    and the recording's label/file identifiers.
    """
    feats = {}
    for axis in ("ax", "ay", "az", "gx", "gy", "gz"):
        signal = df[axis].values
        feats[f"{axis}_rms"] = compute_rms(signal)
        feats[f"{axis}_mean"] = np.mean(signal)
        feats[f"{axis}_std"] = np.std(signal)
        feats[f"{axis}_skew"] = skew(signal)
        feats[f"{axis}_kurtosis"] = kurtosis(signal)
        feats[f"{axis}_entropy"] = compute_entropy(signal)
        for key, value in compute_freq_features(signal, fs).items():
            feats[f"{axis}_{key}"] = value
    # Cross-axis aggregates: magnitude areas and accel/gyro magnitude correlation.
    feats["acc_sma"] = compute_sma(df["ax"], df["ay"], df["az"])
    feats["gyro_sma"] = compute_sma(df["gx"], df["gy"], df["gz"])
    acc_mag = compute_vector_mag(df["ax"], df["ay"], df["az"])
    gyro_mag = compute_vector_mag(df["gx"], df["gy"], df["gz"])
    feats["acc_gyro_corr"] = np.corrcoef(acc_mag, gyro_mag)[0, 1]
    feats["label"] = df["label"].iloc[0]
    feats["file"] = df["file"].iloc[0]
    return feats
def create_feature_dataset(df_all, fs=50):
    """Build one feature row per recording file from the raw sample-level data."""
    rows = []
    for _, group in df_all.groupby("file"):
        rows.append(extract_essential_features(group, fs))
    return pd.DataFrame(rows)
# ======================== VISUALIZATION FUNCTIONS ========================
def plot_pca_clustering(df_features, X_scaled, model):
    """
    Visualize samples in 2-D PCA space, colored by true label and styled by prediction.

    Parameters:
    - df_features: feature DataFrame (must contain a 'label' column)
    - X_scaled: scaled feature matrix used to fit the PCA
    - model: trained classifier used to annotate predictions

    Returns:
    - pca: fitted PCA object
    - df_plot: copy of df_features with 'pca1', 'pca2' and 'pred' columns added
    """
    pca = PCA(n_components=2)
    components = pca.fit_transform(X_scaled)

    # Assemble the plotting frame.
    df_plot = df_features.copy()
    df_plot["pca1"] = components[:, 0]
    df_plot["pca2"] = components[:, 1]
    df_plot["pred"] = model.predict(X_scaled)

    plt.figure(figsize=(8, 6))
    sns.scatterplot(
        data=df_plot,
        x="pca1", y="pca2",
        hue="label", style="pred",
        palette={"normal": "#4CAF50", "pd": "#E91E63"},
        s=90, alpha=0.9
    )
    plt.title("🧩 PCA Clustering Visualization (PD vs Normal)", fontsize=14)
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.legend(title="Label / Prediction")
    plt.show()
    return pca, df_plot
def plot_pca_biplot(df_features, X_scaled, X, pca=None):
    """
    Plot a PCA biplot with feature loading vectors.

    Parameters:
    - df_features: feature DataFrame (must contain a 'label' column)
    - X_scaled: scaled feature matrix
    - X: original (unscaled) feature DataFrame, used for column names
    - pca: previously fitted PCA object (fitted here if None)

    Returns:
    - loadings: DataFrame of loading vectors (index = feature names)
    - df_plot: copy of df_features with 'pca1'/'pca2' columns added
    """
    if pca is None:
        pca = PCA(n_components=2)
        X_pca = pca.fit_transform(X_scaled)
    else:
        X_pca = pca.transform(X_scaled)

    # Build the plotting frame.
    df_plot = df_features.copy()
    df_plot["pca1"] = X_pca[:, 0]
    df_plot["pca2"] = X_pca[:, 1]

    loadings = pd.DataFrame(
        pca.components_.T,
        columns=['PCA1', 'PCA2'],
        index=X.columns
    )
    # Show the top features influencing PCA1 and PCA2.
    print("\n📊 Top 10 features influencing PCA1:")
    print(loadings['PCA1'].sort_values(ascending=False).head(10))
    print("\n📊 Top 10 features influencing PCA2:")
    print(loadings['PCA2'].sort_values(ascending=False).head(10))

    # Plot loading vectors (biplot).
    plt.figure(figsize=(10, 8))
    sns.scatterplot(
        data=df_plot,
        x="pca1", y="pca2",
        hue="label",
        palette={"normal": "#4CAF50", "pd": "#E91E63"},
        s=80, alpha=0.9
    )
    # Fix: the original used loadings.PCA1[i] with an integer i on a
    # string-labeled Series, relying on deprecated positional fallback
    # (removed in recent pandas). Iterate rows by label instead.
    for feat_name, row in loadings.iterrows():
        plt.arrow(0, 0, row['PCA1'] * 10, row['PCA2'] * 10,
                  color='gray', alpha=0.5, head_width=0.3)
        plt.text(row['PCA1'] * 11, row['PCA2'] * 11,
                 feat_name, fontsize=8, color='black')
    plt.title("📈 PCA Biplot: Feature Loading Direction", fontsize=13)
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.grid(True, alpha=0.3)
    plt.show()
    return loadings, df_plot
def plot_roc_curve(y_true, y_proba, model_name="Random Forest"):
    """
    Draw a ROC curve with its AUC against the chance diagonal.

    Parameters:
    - y_true: ground-truth binary labels
    - y_proba: predicted probabilities for the positive class
    - model_name: name shown in the plot title

    Returns:
    - roc_auc: area under the ROC curve
    - fpr: false positive rates
    - tpr: true positive rates
    """
    roc_auc = roc_auc_score(y_true, y_proba)
    fpr, tpr, _ = roc_curve(y_true, y_proba)

    plt.figure(figsize=(6, 6))
    plt.plot(fpr, tpr, color="#E91E63", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
    plt.plot([0, 1], [0, 1], color="gray", linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"🧩 ROC Curve – {model_name} (PD vs Normal)")
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.show()
    return roc_auc, fpr, tpr
def plot_shap_analysis(model, X_scaled, X, plot_type="both"):
    """
    SHAP analysis and visualization for a tree-based classifier.

    Parameters:
    - model: trained tree-based model
    - X_scaled: scaled feature matrix (passed to the explainer)
    - X: original feature DataFrame (provides names/values for the plots)
    - plot_type: which plots to draw ("bar", "beeswarm", "both")

    Returns:
    - explainer: SHAP TreeExplainer
    - shap_values: raw SHAP values as returned by the explainer
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_scaled)
    # Fix: older shap returns a list of per-class arrays (index [1] = positive
    # class); newer versions return one (n_samples, n_features, n_classes)
    # array. Select the positive-class slice in either case.
    if isinstance(shap_values, list):
        positive_sv = shap_values[1]
    elif getattr(shap_values, "ndim", 2) == 3:
        positive_sv = shap_values[:, :, 1]
    else:
        positive_sv = shap_values
    if plot_type in ["bar", "both"]:
        shap.summary_plot(positive_sv, X, plot_type="bar", show=False)
        plt.title("SHAP Feature Importance (Bar Plot)")
        plt.tight_layout()
        plt.show()
    if plot_type in ["beeswarm", "both"]:
        shap.summary_plot(positive_sv, X, show=False)
        plt.title("SHAP Feature Importance (Beeswarm Plot)")
        plt.tight_layout()
        plt.show()
    return explainer, shap_values
# ======================== MODEL TRAINING ========================
def train_random_forest(X, y, n_estimators=300, max_depth=6, random_state=42):
    """Train a RandomForest on standardized features, handling NaNs in y.

    Rows with a missing label are dropped first, then any remaining row
    containing a missing feature value.

    Returns:
    - model: fitted RandomForestClassifier
    - scaler: fitted StandardScaler
    - X_scaled: the cleaned, scaled training matrix
    """
    frame = pd.DataFrame(X).copy()
    frame["label"] = y
    frame = frame.dropna(subset=["label"]).dropna(axis=0, how="any")

    y_clean = frame["label"].values
    X_clean = frame.drop(columns=["label"]).values

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_clean)

    model = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        random_state=random_state,
    )
    model.fit(X_scaled, y_clean)
    print(f"✅ Training complete ({len(y_clean)} samples used)")
    return model, scaler, X_scaled
def evaluate_model(model, X_scaled, y_true):
    """Print a confusion matrix and classification report; return (y_pred, y_proba).

    y_proba is the predicted probability of the positive (PD) class.
    """
    y_proba = model.predict_proba(X_scaled)[:, 1]
    y_pred = model.predict(X_scaled)
    print("\nConfusion Matrix:")
    print(confusion_matrix(y_true, y_pred))
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, target_names=["Normal", "PD"]))
    return y_pred, y_proba
# ======================== SAVE MODEL ========================
def save_rf_model(model, scaler, feature_names, base_path):
    """Bundle model, scaler and feature names into one joblib file under base_path.

    Returns the path of the written file.
    """
    save_path = os.path.join(base_path, "tremor_rf_model.joblib")
    bundle = {
        "model": model,
        "scaler": scaler,
        "features": feature_names,
    }
    dump(bundle, save_path)
    print(f"💾 Model saved to {save_path}")
    return save_path