TunisianCoder commited on
Commit
2504dcc
Β·
verified Β·
1 Parent(s): 81d2721

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +326 -0
app.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Space β€” Sleep Stage Classification
3
+ ================================================
4
+ Gradio app that serves the pre-trained CNN model for inference.
5
+ Callable from any frontend via the Gradio API.
6
+
7
+ Space URL: https://<your-username>-sleep-stage-classifier.hf.space
8
+ """
9
+
10
+ import io
11
+ import os
12
+ import json
13
+ import numpy as np
14
+ import pandas as pd
15
+ import gradio as gr
16
+ import torch
17
+ import torch.nn as nn
18
+ from collections import Counter
19
+
20
+ # ────────────────────────────────────────────────────────────────
21
+ # Constants
22
+ # ────────────────────────────────────────────────────────────────
23
+ SFREQ = 100
24
+ EPOCH_SAMPLES = 3000 # 30 seconds Γ— 100 Hz
25
+ STAGES = ["Wake", "N1", "N2", "N3", "N4", "REM"]
26
+ MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sleep_stage_cnn.pth")
27
+
28
+
29
+ # ────────────────────────────────────────────────────────────────
30
+ # Model Definition (must match training architecture exactly)
31
+ # ────────────────────────────────────────────────────────────────
32
+
33
+ class SleepStageCNN(nn.Module):
34
+ """
35
+ 1D Convolutional Neural Network for Sleep Stage Classification.
36
+ Architecture matches the training notebook.
37
+ """
38
+
39
+ def __init__(self, n_channels=1, n_classes=6):
40
+ super().__init__()
41
+ self.network = nn.Sequential(
42
+ # Block 1: large receptive field for slow-wave features
43
+ nn.Conv1d(n_channels, 32, kernel_size=50, stride=6),
44
+ nn.BatchNorm1d(32),
45
+ nn.ReLU(),
46
+ nn.MaxPool1d(8),
47
+
48
+ # Block 2: finer feature extraction
49
+ nn.Conv1d(32, 64, kernel_size=8),
50
+ nn.BatchNorm1d(64),
51
+ nn.ReLU(),
52
+ nn.MaxPool1d(8),
53
+
54
+ # Classifier head
55
+ nn.Flatten(),
56
+ nn.Linear(64 * 6, 128),
57
+ nn.ReLU(),
58
+ nn.Dropout(0.5),
59
+ nn.Linear(128, n_classes),
60
+ )
61
+
62
+ def forward(self, x):
63
+ return self.network(x)
64
+
65
+
66
+ # ────────────────────────────────────────────────────────────────
67
+ # Load Model at startup
68
+ # ────────────────────────────────────────────────────────────────
69
+
70
+ device = torch.device("cpu")
71
+ model = SleepStageCNN(n_channels=1, n_classes=6)
72
+
73
+ if os.path.exists(MODEL_PATH):
74
+ checkpoint = torch.load(
75
+ MODEL_PATH, map_location=device, weights_only=False
76
+ )
77
+ if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
78
+ model.load_state_dict(checkpoint["model_state_dict"])
79
+ else:
80
+ model.load_state_dict(checkpoint)
81
+ model.eval().to(device)
82
+ print(f"βœ… Model loaded from {MODEL_PATH}")
83
+ else:
84
+ raise FileNotFoundError(
85
+ f"Model file not found at {MODEL_PATH}. "
86
+ "Upload sleep_stage_cnn.pth to this Space."
87
+ )
88
+
89
+
90
+ # ────────────────────────────────────────────────────────────────
91
+ # Inference Function
92
+ # ────────────────────────────────────────────────────────────────
93
+
94
+ def classify_eeg(signal: np.ndarray) -> dict:
95
+ """
96
+ Run inference on a 1D EEG signal.
97
+
98
+ Parameters
99
+ ----------
100
+ signal : np.ndarray
101
+ Raw EEG data (1D array, assumed 100 Hz sampling rate).
102
+
103
+ Returns
104
+ -------
105
+ dict with keys:
106
+ - epochs: list of {epoch, stage, confidence}
107
+ - summary: dict of stage β†’ "count (percentage%)"
108
+ """
109
+ if len(signal) < EPOCH_SAMPLES:
110
+ return {
111
+ "error": (
112
+ f"Signal too short. Need at least {EPOCH_SAMPLES} samples "
113
+ f"(30s at 100 Hz), got {len(signal)}."
114
+ )
115
+ }
116
+
117
+ predictions = []
118
+ for i in range(0, len(signal) - EPOCH_SAMPLES + 1, EPOCH_SAMPLES):
119
+ epoch = signal[i: i + EPOCH_SAMPLES]
120
+
121
+ # Z-score normalize
122
+ mean = epoch.mean()
123
+ std = epoch.std()
124
+ if std == 0:
125
+ std = 1.0
126
+ epoch_norm = (epoch - mean) / std
127
+
128
+ # Forward pass
129
+ x = torch.tensor(
130
+ epoch_norm, dtype=torch.float32
131
+ ).unsqueeze(0).unsqueeze(0).to(device)
132
+
133
+ with torch.no_grad():
134
+ logits = model(x)
135
+ probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
136
+ pred_idx = int(logits.argmax().item())
137
+
138
+ predictions.append({
139
+ "epoch": len(predictions) + 1,
140
+ "stage": STAGES[pred_idx],
141
+ "confidence": round(float(max(probs)), 4),
142
+ "probabilities": {
143
+ STAGES[j]: round(float(probs[j]), 4)
144
+ for j in range(len(STAGES))
145
+ },
146
+ })
147
+
148
+ # Summary statistics
149
+ counts = Counter(p["stage"] for p in predictions)
150
+ total = len(predictions)
151
+
152
+ return {
153
+ "epochs": predictions,
154
+ "summary": {
155
+ stage: {
156
+ "count": counts.get(stage, 0),
157
+ "percentage": round(counts.get(stage, 0) / total * 100, 1)
158
+ }
159
+ for stage in STAGES
160
+ },
161
+ }
162
+
163
+
164
+ # ────────────────────────────────────────────────────────────────
165
+ # File Processor (called by Gradio UI)
166
+ # ────────────────────────────────────────────────────────────────
167
+
168
+ def process_file(file) -> tuple:
169
+ """
170
+ Process uploaded EEG file and return readable results + raw JSON.
171
+
172
+ Parameters
173
+ ----------
174
+ file : file-like or str path
175
+ Uploaded CSV / TXT / NPY file.
176
+
177
+ Returns
178
+ -------
179
+ (text_output, json_output)
180
+ """
181
+ if file is None:
182
+ return "⚠️ Please upload a file.", None
183
+
184
+ try:
185
+ # Determine file type and load signal
186
+ name = file.name.lower() if hasattr(file, "name") else str(file).lower()
187
+
188
+ if name.endswith(".npy"):
189
+ signal = np.load(file)
190
+ if signal.ndim > 1:
191
+ signal = signal.flatten()
192
+ else:
193
+ # CSV or TXT β€” first column
194
+ df = pd.read_csv(file, header=None, sep=None, engine="python")
195
+ signal = df.iloc[:, 0].values.astype(np.float64)
196
+
197
+ # Run inference
198
+ result = classify_eeg(signal)
199
+
200
+ if "error" in result:
201
+ return f"❌ {result['error']}", None
202
+
203
+ # Build readable text output
204
+ lines = []
205
+ lines.append(f"πŸ“Š Total epochs classified: {len(result['epochs'])}")
206
+ lines.append("")
207
+ lines.append("πŸ“‹ Stage Distribution:")
208
+ lines.append("-" * 40)
209
+ for stage, stats in result["summary"].items():
210
+ bar = "β–ˆ" * int(stats["percentage"] / 2)
211
+ lines.append(f" {stage:6s}: {stats['count']:4d} ({stats['percentage']:5.1f}%) {bar}")
212
+
213
+ lines.append("")
214
+ lines.append("πŸ“ Epoch Details (first 20):")
215
+ lines.append("-" * 40)
216
+ for ep in result["epochs"][:20]:
217
+ lines.append(
218
+ f" Epoch {ep['epoch']:>3d}: {ep['stage']:5s} "
219
+ f"confidence {ep['confidence']*100:.1f}%"
220
+ )
221
+
222
+ text_output = "\n".join(lines)
223
+ json_output = result # Gradio will auto-serialize to JSON
224
+
225
+ return text_output, json_output
226
+
227
+ except Exception as e:
228
+ return f"❌ Error: {str(e)}", None
229
+
230
+
231
+ # ────────────────────────────────────────────────────────────────
232
+ # Gradio Interface
233
+ # ────────────────────────────────────────────────────────────────
234
+
235
+ with gr.Blocks(
236
+ title="Sleep Stage Classifier",
237
+ theme=gr.themes.Soft(
238
+ primary_hue="blue",
239
+ secondary_hue="slate",
240
+ ),
241
+ ) as demo:
242
+
243
+ gr.Markdown(
244
+ """
245
+ # 😴 Sleep Stage Classification
246
+
247
+ Upload a **CSV**, **TXT**, or **NPY** file containing raw EEG signal data.
248
+ The model assumes a **100 Hz sampling rate** and classifies the signal
249
+ into 30-second epochs.
250
+
251
+ | Stage | Description |
252
+ |-------|-------------|
253
+ | **Wake** | Awake, eyes open/closed |
254
+ | **N1** | Light sleep, transition |
255
+ | **N2** | Deeper sleep, spindles + K-complexes |
256
+ | **N3** | Slow-wave sleep (deep) |
257
+ | **N4** | Very deep slow-wave sleep |
258
+ | **REM** | Rapid eye movement (dreaming) |
259
+ """
260
+ )
261
+
262
+ with gr.Row():
263
+ with gr.Column(scale=1):
264
+ file_input = gr.File(
265
+ label="Upload EEG file",
266
+ file_types=[".csv", ".txt", ".npy"],
267
+ )
268
+ btn = gr.Button("πŸ” Classify", variant="primary", size="lg")
269
+
270
+ gr.Markdown("πŸ’‘ **Tip:** Upload a single-column CSV with EEG amplitude values (100 Hz).")
271
+
272
+ with gr.Column(scale=2):
273
+ text_output = gr.Textbox(
274
+ label="Results",
275
+ lines=20,
276
+ interactive=False,
277
+ )
278
+ json_output = gr.JSON(
279
+ label="Raw JSON (for API integration)",
280
+ )
281
+
282
+ btn.click(
283
+ fn=process_file,
284
+ inputs=[file_input],
285
+ outputs=[text_output, json_output],
286
+ )
287
+
288
+ gr.Markdown(
289
+ """
290
+ ---
291
+ ### πŸ”Œ API Access
292
+
293
+ You can call this Space programmatically from any frontend:
294
+
295
+ ```bash
296
+ pip install gradio_client
297
+ ```
298
+
299
+ ```python
300
+ from gradio_client import Client
301
+
302
+ client = Client("<your-username>/sleep-stage-classifier")
303
+ result = client.predict(file="path/to/eeg.csv")
304
+ print(result)
305
+ ```
306
+
307
+ Or from JavaScript in your Lovable app:
308
+
309
+ ```javascript
310
+ import { Client } from "@gradio/client";
311
+
312
+ const client = await Client.connect(
313
+ "https://<your-username>-sleep-stage-classifier.hf.space"
314
+ );
315
+ const result = await client.predict("/predict", { file: yourFile });
316
+ ```
317
+ """
318
+ )
319
+
320
+
321
+ # ────────────────────────────────────────────────────────────────
322
+ # Launch
323
+ # ────────────────────────────────────────────────────────────────
324
+
325
+ if __name__ == "__main__":
326
+ demo.launch()