phoner45 committed on
Commit
f2eef96
·
verified ·
1 Parent(s): 5e4ba18

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +109 -0
  2. dockerfile +17 -0
  3. requirements.txt +11 -0
  4. tremor_analysis_functions.py +331 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ import joblib, requests, os, json, io, tempfile
5
+ import pandas as pd
6
+ import numpy as np
7
+ from tremor_analysis_functions import extract_essential_features
8
+
9
+ # =====================================================
10
+ # CONFIG
11
+ # =====================================================
12
# Location of the serialized model bundle on the Hugging Face Hub.
MODEL_REPO = "Chula-PD/tremor-post"  # 👈 change this to the actual repo name
MODEL_FILE = "tremor_rf_model.joblib"
# Direct-download URL of MODEL_FILE on the repo's main branch.
MODEL_URL = "https://huggingface.co/{}/resolve/main/{}".format(MODEL_REPO, MODEL_FILE)
15
+
16
+ # =====================================================
17
+ # INIT FastAPI
18
+ # =====================================================
19
app = FastAPI(title="CheckPD Tremor API", version="1.0")

# Allow CORS so that a React or Streamlit frontend can call this API.
# Wildcard origins must not be combined with allow_credentials=True: the CORS
# spec forbids it and it would let any website make credentialed requests.
# This API does not read cookies or auth headers, so credentials are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
29
+
30
+ # =====================================================
31
+ # LOAD MODEL
32
+ # =====================================================
33
def load_model():
    """
    Load the joblib model bundle, downloading it from Hugging Face if needed.

    If MODEL_FILE does not exist locally it is fetched from MODEL_URL. The
    download is checked for HTTP errors and written to a temporary file that
    is atomically renamed into place, so a failed or partial download is never
    left behind and mistaken for a valid cached model on the next start.

    Returns:
    - model_dict: the object stored in the joblib file

    Raises:
    - requests.RequestException: if the download fails, times out, or returns
      an HTTP error status
    """
    if not os.path.exists(MODEL_FILE):
        print("⬇️ Downloading model from Hugging Face...")
        # A timeout keeps startup from hanging forever on a stalled connection.
        response = requests.get(MODEL_URL, timeout=60)
        # Without this check a 404/401 error page would be saved as the model.
        response.raise_for_status()

        target_dir = os.path.dirname(os.path.abspath(MODEL_FILE))
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(response.content)
            # Same-directory rename: the final file appears only when complete.
            os.replace(tmp_path, MODEL_FILE)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only point MODEL_URL at a repository you trust.
    model_dict = joblib.load(MODEL_FILE)
    print("✅ Model loaded successfully.")
    return model_dict
43
+
44
# Load the bundle once at import time and unpack the three entries that the
# request handlers use: the classifier, the fitted scaler and the ordered
# list of feature names.
model_dict = load_model()
model, scaler, features = (
    model_dict[key] for key in ("model", "scaler", "features")
)
48
+
49
+ # =====================================================
50
+ # HELPER: JSON Preprocessing
51
+ # =====================================================
52
def preprocess_json(json_data):
    """
    Convert a JSON recording from the phone into a scaled feature vector
    that is ready for the model.

    Parameters:
    - json_data: parsed JSON object; the recording is read from
      json_data["recording"] or, failing that, json_data["data"]["recording"]

    Returns:
    - the output of scaler.transform for the single extracted feature row,
      with columns ordered as in the module-level `features` list

    Raises:
    - ValueError: if no 'recording' field is found, or if its
      'recordedData' / 'recordingFormat' entries are missing or empty
    """
    # Locate the recording: top level first, then nested under "data".
    if "recording" not in json_data:
        if not ("data" in json_data and "recording" in json_data["data"]):
            raise ValueError("Invalid JSON format: missing 'recording' field")
        recording = json_data["data"]["recording"]
    else:
        recording = json_data["recording"]

    samples = recording.get("recordedData", [])
    column_names = recording.get("recordingFormat", [])
    if not (samples and column_names):
        raise ValueError("Incomplete recording data")

    # One row per sample; column names come from the recording itself.
    rows = [sample["data"] for sample in samples]
    frame = pd.DataFrame(rows, columns=column_names)
    # extract_essential_features reads these two columns, so give placeholders.
    frame["label"] = "unknown"
    frame["file"] = "uploaded"

    feature_row = extract_essential_features(frame)
    feature_frame = pd.DataFrame([feature_row]).drop(
        columns=["label", "file"], errors="ignore"
    )

    # Align to the feature order stored with the model; absent columns become 0.
    aligned = feature_frame.reindex(columns=features, fill_value=0)
    return scaler.transform(aligned)
80
+
81
+ # =====================================================
82
+ # ENDPOINTS
83
+ # =====================================================
84
@app.get("/")
def home():
    """Return a fixed status message showing that the API is running."""
    status = {"message": "CheckPD Tremor API is running 🚀"}
    return status
87
+
88
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Receive a JSON file from the UI and predict PD/Normal.

    Parameters:
    - file: uploaded JSON recording

    Returns:
    - dict with "prediction" ("PD" when the model predicts 1, otherwise
      "Normal"), "probability_pd" (second column of predict_proba, rounded to
      4 decimals) and "file_name"

    Raises:
    - HTTPException(400): the upload cannot be decoded, parsed or turned
      into features (bad encoding, invalid JSON, missing fields)
    - HTTPException(500): any other failure, including model inference
    """
    # Stage 1: everything that depends on the client's payload. Problems here
    # are the caller's fault, so report 400 instead of a generic 500.
    # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
    try:
        contents = await file.read()
        json_data = json.loads(contents.decode("utf-8"))
        X_scaled = preprocess_json(json_data)
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Stage 2: model inference. Failures here are server-side.
    try:
        y_pred = model.predict(X_scaled)[0]
        y_proba = model.predict_proba(X_scaled)[0][1]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "prediction": "PD" if y_pred == 1 else "Normal",
        "probability_pd": round(float(y_proba), 4),
        "file_name": file.filename,
    }
dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Base image ---
FROM python:3.10-slim

# --- Working directory ---
WORKDIR /app

# --- Install dependencies ---
# Copy only requirements.txt first so this layer stays cached until the
# dependency list changes, instead of reinstalling on every source edit.
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# --- Copy all files ---
COPY . /app

# --- Expose port (Hugging Face Spaces expects 7860) ---
EXPOSE 7860

# --- Run FastAPI server ---
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ joblib
4
+ scikit-learn
5
+ pandas
6
+ numpy
7
+ scipy
8
+ shap
9
+ seaborn
10
+ matplotlib
11
+ requests
tremor_analysis_functions.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json
2
+ import numpy as np
3
+ import pandas as pd
4
+ from scipy.signal import welch
5
+ from scipy.stats import skew, kurtosis
6
+ from sklearn.preprocessing import StandardScaler
7
+ from sklearn.decomposition import PCA
8
+ from sklearn.ensemble import RandomForestClassifier
9
+ from sklearn.metrics import confusion_matrix, classification_report, roc_curve, roc_auc_score
10
+ import shap
11
+ import matplotlib.pyplot as plt
12
+ import seaborn as sns
13
+ from joblib import dump # ใช้สำหรับบันทึก model
14
+
15
+
16
+ # ======================== DATA LOADING ========================
17
def load_tremor_data(base_path, folders):
    """
    Load tremor data from JSON files, supporting both the old and new formats.

    Parameters:
    - base_path: directory that contains the data folders
    - folders: mapping of folder name -> label assigned to every file in it

    Returns:
    - DataFrame with one row per recorded sample: the columns named by each
      file's 'recordingFormat', plus "ts", "label" and "file". An empty
      DataFrame is returned when no file could be loaded.

    Files that cannot be read or parsed, lack a 'recording' field, or hold
    fewer than 5 records are reported on stdout and skipped.
    """
    frames = []

    for folder_name, folder_label in folders.items():
        directory = os.path.join(base_path, folder_name)
        print(f"📂 Loading folder: {directory}")

        json_names = (n for n in os.listdir(directory) if n.endswith(".json"))
        for file_name in json_names:
            try:
                with open(os.path.join(directory, file_name), "r", encoding="utf-8") as handle:
                    payload = json.load(handle)
            except Exception as exc:
                print(f"❌ Error reading {file_name}: {exc}")
                continue

            # The recording sits at the top level or nested under "data".
            if "recording" in payload:
                recording = payload["recording"]
            elif "data" in payload and "recording" in payload["data"]:
                recording = payload["data"]["recording"]
            else:
                print(f"⚠️ Skip: {file_name} (no 'recording' field found)")
                continue

            samples = recording.get("recordedData", [])
            column_names = recording.get("recordingFormat", [])

            usable = samples and column_names and len(samples) >= 5
            if not usable:
                print(f"⚠️ Skip empty or too short: {file_name}")
                continue

            try:
                frame = pd.DataFrame(
                    [sample["data"] for sample in samples], columns=column_names
                )
                frame["ts"] = [sample.get("ts", None) for sample in samples]
                frame["label"] = folder_label
                frame["file"] = file_name
                frames.append(frame)
            except Exception as exc:
                print(f"⚠️ Parse error {file_name}: {exc}")
                continue

    if frames:
        combined = pd.concat(frames, ignore_index=True)
        print(f"✅ Loaded total rows: {len(combined)}, files: {len(frames)}")
        return combined

    print("❌ No valid files found.")
    return pd.DataFrame()
69
+
70
+
71
+ # ======================== FEATURE EXTRACTION ========================
72
def compute_rms(x):
    """Return the root mean square of x: sqrt(mean(x**2))."""
    squared = x ** 2
    return np.sqrt(np.mean(squared))
73
def compute_sma(x, y, z):
    """Return the mean over samples of |x| + |y| + |z|."""
    abs_sum = np.abs(x) + np.abs(y)
    abs_sum = abs_sum + np.abs(z)
    return np.mean(abs_sum)
74
def compute_vector_mag(x, y, z):
    """Return the element-wise Euclidean magnitude sqrt(x**2 + y**2 + z**2)."""
    sum_of_squares = x**2 + y**2 + z**2
    return np.sqrt(sum_of_squares)
75
def compute_entropy(signal, bins=30):
    """
    Return -sum(d * log(d)) over the non-zero bins of a density histogram.

    Parameters:
    - signal: values to histogram
    - bins: number of histogram bins passed to np.histogram

    Returns:
    - the entropy-style score computed from the density values (the
      histogram is built with density=True; empty bins are excluded)
    """
    densities = np.histogram(signal, bins=bins, density=True)[0]
    # Drop empty bins so that log() is never evaluated at zero.
    positive = densities[densities > 0]
    return -np.sum(positive * np.log(positive))
79
+
80
+
81
def compute_freq_features(signal, fs=50):
    """
    Compute frequency-domain features from the Welch power spectral density.

    Parameters:
    - signal: 1-D sequence of samples
    - fs: sampling frequency passed to scipy.signal.welch (default 50)

    Returns:
    - dict with
      "dom_freq": frequency of the largest PSD value,
      "band_power_4_6": PSD integrated (trapezoid rule) over 4 <= f <= 6,
      "spec_entropy": -sum(p * log(p)) of the PSD normalized to sum to 1.
      All three are 0 when the signal or the spectrum is empty.
    """
    empty_result = {"dom_freq": 0, "band_power_4_6": 0, "spec_entropy": 0}
    # Guard before calling welch: with no samples nperseg would be 0, so the
    # empty-spectrum fallback below could never be reached.
    if len(signal) == 0:
        return empty_result

    f, Pxx = welch(signal, fs=fs, nperseg=min(256, len(signal)))
    if len(Pxx) == 0:
        return empty_result

    # np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name is
    # deprecated; use whichever one the installed NumPy provides.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz

    dom_freq = f[np.argmax(Pxx)]
    band_mask = (f >= 4) & (f <= 6)
    band_power = integrate(Pxx[band_mask], f[band_mask])
    Pxx_norm = Pxx / np.sum(Pxx)
    # The small epsilon keeps log() finite for zero-power bins.
    spec_entropy = -np.sum(Pxx_norm * np.log(Pxx_norm + 1e-12))
    return {"dom_freq": dom_freq, "band_power_4_6": band_power, "spec_entropy": spec_entropy}
91
+
92
+
93
def extract_essential_features(df, fs=50):
    """
    Build the feature dictionary for one recording.

    Parameters:
    - df: DataFrame with columns "ax", "ay", "az", "gx", "gy", "gz",
      "label" and "file"
    - fs: sampling frequency forwarded to compute_freq_features

    Returns:
    - dict containing, for every sensor column, its rms, mean, std, skew,
      kurtosis, entropy and the three frequency features; then "acc_sma",
      "gyro_sma", "acc_gyro_corr", and the "label" and "file" values taken
      from the first row
    """
    # (suffix, function) pairs applied to each sensor, in output order.
    time_stats = (
        ("rms", compute_rms),
        ("mean", np.mean),
        ("std", np.std),
        ("skew", skew),
        ("kurtosis", kurtosis),
        ("entropy", compute_entropy),
    )

    feats = {}
    for sensor in ("ax", "ay", "az", "gx", "gy", "gz"):
        values = df[sensor].values
        for suffix, stat in time_stats:
            feats[f"{sensor}_{suffix}"] = stat(values)
        for key, value in compute_freq_features(values, fs).items():
            feats[f"{sensor}_{key}"] = value

    acc_axes = (df["ax"], df["ay"], df["az"])
    gyro_axes = (df["gx"], df["gy"], df["gz"])

    feats["acc_sma"] = compute_sma(*acc_axes)
    feats["gyro_sma"] = compute_sma(*gyro_axes)

    # Correlation between accelerometer and gyroscope magnitudes.
    acc_magnitude = compute_vector_mag(*acc_axes)
    gyro_magnitude = compute_vector_mag(*gyro_axes)
    feats["acc_gyro_corr"] = np.corrcoef(acc_magnitude, gyro_magnitude)[0, 1]

    feats["label"] = df["label"].iloc[0]
    feats["file"] = df["file"].iloc[0]
    return feats
117
+
118
+
119
+ def create_feature_dataset(df_all, fs=50):
120
+ features = [extract_essential_features(g, fs) for _, g in df_all.groupby("file")]
121
+ return pd.DataFrame(features)
122
+
123
+ # ======================== VISUALIZATION FUNCTIONS ========================
124
def plot_pca_clustering(df_features, X_scaled, model):
    """
    Plot a 2-component PCA scatter coloured by label and styled by prediction.

    Parameters:
    - df_features: feature DataFrame (must contain a "label" column)
    - X_scaled: scaled feature data
    - model: trained model, used for model.predict(X_scaled)

    Returns:
    - pca: the fitted PCA object
    - df_plot: copy of df_features with "pca1", "pca2" and "pred" columns
    """
    pca = PCA(n_components=2)
    projected = pca.fit_transform(X_scaled)

    # Build the DataFrame used for plotting.
    df_plot = df_features.copy()
    df_plot["pca1"] = projected[:, 0]
    df_plot["pca2"] = projected[:, 1]
    df_plot["pred"] = model.predict(X_scaled)

    label_colors = {"normal": "#4CAF50", "pd": "#E91E63"}

    plt.figure(figsize=(8, 6))
    sns.scatterplot(
        data=df_plot,
        x="pca1",
        y="pca2",
        hue="label",
        style="pred",
        palette=label_colors,
        s=90,
        alpha=0.9,
    )
    plt.title("🧩 PCA Clustering Visualization (PD vs Normal)", fontsize=14)
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.legend(title="Label / Prediction")
    plt.show()

    return pca, df_plot
161
+
162
def plot_pca_biplot(df_features, X_scaled, X, pca=None):
    """
    Plot a PCA biplot with feature loading vectors.

    Parameters:
    - df_features: feature DataFrame (must contain a "label" column)
    - X_scaled: scaled feature data
    - X: original feature DataFrame; its columns name the loading vectors
    - pca: fitted PCA object (optional); a new 2-component PCA is fitted on
      X_scaled when omitted

    Returns:
    - loadings: DataFrame of loading vectors (columns "PCA1", "PCA2")
    - df_plot: copy of df_features with "pca1" and "pca2" columns
    """
    if pca is None:
        pca = PCA(n_components=2)
        X_pca = pca.fit_transform(X_scaled)
    else:
        X_pca = pca.transform(X_scaled)

    # Build the DataFrame used for plotting.
    df_plot = df_features.copy()
    df_plot["pca1"] = X_pca[:, 0]
    df_plot["pca2"] = X_pca[:, 1]

    loadings = pd.DataFrame(
        pca.components_.T,
        columns=['PCA1', 'PCA2'],
        index=X.columns
    )

    # Show the top features influencing PCA1 and PCA2.
    print("\n📊 Top 10 features influencing PCA1:")
    print(loadings['PCA1'].sort_values(ascending=False).head(10))
    print("\n📊 Top 10 features influencing PCA2:")
    print(loadings['PCA2'].sort_values(ascending=False).head(10))

    # Plot loading vectors (Biplot)
    plt.figure(figsize=(10, 8))
    sns.scatterplot(
        data=df_plot,
        x="pca1", y="pca2",
        hue="label",
        palette={"normal": "#4CAF50", "pd": "#E91E63"},
        s=80, alpha=0.9
    )

    # Add the loading vectors. Iterate over names and values directly rather
    # than indexing the string-labelled Series with integers: that positional
    # fallback is deprecated in pandas 2.x.
    loading_values = loadings[['PCA1', 'PCA2']].to_numpy()
    for feature_name, (load1, load2) in zip(loadings.index, loading_values):
        plt.arrow(0, 0, load1*10, load2*10,
                  color='gray', alpha=0.5, head_width=0.3)
        plt.text(load1*11, load2*11,
                 feature_name, fontsize=8, color='black')

    plt.title("📈 PCA Biplot: Feature Loading Direction", fontsize=13)
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.grid(True, alpha=0.3)
    plt.show()

    return loadings, df_plot
223
+
224
def plot_roc_curve(y_true, y_proba, model_name="Random Forest"):
    """
    Plot the ROC curve and report its AUC.

    Parameters:
    - y_true: true target values
    - y_proba: predicted probabilities
    - model_name: model name shown in the plot title

    Returns:
    - roc_auc: ROC AUC score
    - fpr: False Positive Rates
    - tpr: True Positive Rates
    """
    # The thresholds returned by roc_curve are not needed for the plot.
    fpr, tpr, _ = roc_curve(y_true, y_proba)
    roc_auc = roc_auc_score(y_true, y_proba)

    curve_label = f"ROC curve (AUC = {roc_auc:.2f})"

    plt.figure(figsize=(6, 6))
    plt.plot(fpr, tpr, color="#E91E63", lw=2, label=curve_label)
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], color="gray", linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"🧩 ROC Curve – {model_name} (PD vs Normal)")
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.show()

    return roc_auc, fpr, tpr
252
+
253
def plot_shap_analysis(model, X_scaled, X, plot_type="both"):
    """
    Run SHAP analysis and draw the summary plots.

    Parameters:
    - model: trained model
    - X_scaled: scaled feature data
    - X: original feature data (passed to the plots for feature names/values)
    - plot_type: which plot to draw ("bar", "beeswarm", "both")

    Returns:
    - explainer: SHAP explainer
    - shap_values: SHAP values exactly as returned by the explainer
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_scaled)

    # Pick the values for class index 1. Older SHAP releases return a list
    # with one (samples, features) array per class; newer ones return one
    # (samples, features, classes) array, where shap_values[1] would wrongly
    # select the second *sample* instead of the second class.
    if isinstance(shap_values, list):
        class1_values = shap_values[1]
    else:
        values_array = np.asarray(shap_values)
        if values_array.ndim == 3:
            class1_values = values_array[:, :, 1]
        else:
            # Single-output explanation: already (samples, features).
            class1_values = values_array

    if plot_type in ["bar", "both"]:
        shap.summary_plot(class1_values, X, plot_type="bar", show=False)
        plt.title("SHAP Feature Importance (Bar Plot)")
        plt.tight_layout()
        plt.show()

    if plot_type in ["beeswarm", "both"]:
        shap.summary_plot(class1_values, X, show=False)
        plt.title("SHAP Feature Importance (Beeswarm Plot)")
        plt.tight_layout()
        plt.show()

    return explainer, shap_values
283
+
284
+
285
+ # ======================== MODEL TRAINING ========================
286
def train_random_forest(X, y, n_estimators=300, max_depth=6, random_state=42):
    """
    Train a RandomForest classifier, handling NaN values in y.

    Rows whose label is NaN, and then rows with NaN in any column, are
    dropped before the features are standardized and the model is fitted.

    Parameters:
    - X: feature data
    - y: labels
    - n_estimators, max_depth, random_state: forwarded to
      RandomForestClassifier

    Returns:
    - model: the fitted RandomForestClassifier
    - scaler: the fitted StandardScaler
    - X_scaled: the scaled features of the rows actually used for training
    """
    table = pd.DataFrame(X).copy()
    table["label"] = y
    table = table.dropna(subset=["label"]).dropna(axis=0, how="any")

    y_clean = table["label"].values
    X_clean = table.drop(columns=["label"]).values

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_clean)

    forest_params = {
        "n_estimators": n_estimators,
        "max_depth": max_depth,
        "random_state": random_state,
    }
    model = RandomForestClassifier(**forest_params)
    model.fit(X_scaled, y_clean)

    print(f"✅ Training complete ({len(y_clean)} samples used)")
    return model, scaler, X_scaled
307
+
308
+
309
def evaluate_model(model, X_scaled, y_true):
    """
    Print the confusion matrix and classification report for a trained model.

    Parameters:
    - model: trained classifier exposing predict and predict_proba
    - X_scaled: scaled feature data
    - y_true: true labels

    Returns:
    - y_pred: predicted labels
    - y_proba: second column of predict_proba for every sample
    """
    predicted = model.predict(X_scaled)
    class1_proba = model.predict_proba(X_scaled)[:, 1]

    print("\nConfusion Matrix:")
    matrix = confusion_matrix(y_true, predicted)
    print(matrix)

    print("\nClassification Report:")
    report = classification_report(
        y_true, predicted, target_names=["Normal", "PD"]
    )
    print(report)

    return predicted, class1_proba
319
+
320
+
321
+ # ======================== SAVE MODEL ========================
322
def save_rf_model(model, scaler, feature_names, base_path):
    """
    Save the model, scaler and feature names as a single joblib file.

    Parameters:
    - model: trained model
    - scaler: fitted scaler
    - feature_names: feature names stored under the "features" key
    - base_path: directory in which "tremor_rf_model.joblib" is written

    Returns:
    - save_path: full path of the written file
    """
    save_path = os.path.join(base_path, "tremor_rf_model.joblib")
    bundle = dict(model=model, scaler=scaler, features=feature_names)
    dump(bundle, save_path)
    print(f"💾 Model saved to {save_path}")
    return save_path