TunisianCoder committed on
Commit
c01118e
·
verified ·
1 Parent(s): 567c441

Upload 5 files

Browse files
Files changed (5) hide show
  1. DEPLOY-GUIDE.md +210 -0
  2. README.md +39 -6
  3. app.py +330 -0
  4. requirements.txt +5 -0
  5. sleep_stage_cnn.pth +3 -0
DEPLOY-GUIDE.md ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deploy to Hugging Face Space
2
+
3
+ ## Step 1: Create the Space
4
+
5
+ 1. Go to **https://huggingface.co/spaces**
6
+ 2. Click **Create new Space**
7
+ 3. Fill in:
8
+ - **Space name**: `sleep-stage-classifier`
9
+ - **License**: MIT
10
+ - **SDK**: **Gradio**
11
+ - **Visibility**: Public
12
+ 4. Click **Create Space**
13
+
14
+ ---
15
+
16
+ ## Step 2: Upload Files
17
+
18
+ In your new Space, go to the **Files** tab β†’ **Add file** β†’ **Upload files**.
19
+
20
+ Upload these **3 files** from the `hf-space/` folder:
21
+
22
+ | File | Purpose |
23
+ |------|---------|
24
+ | `requirements.txt` | Python dependencies |
25
+ | `app.py` | Gradio web app + model inference |
26
+ | `sleep_stage_cnn.pth` | Your trained model weights |
27
+
28
+ > ⚠️ Make sure you upload **all 3 files**. The Space won't work without the `.pth` model file.
29
+
30
+ ---
31
+
32
+ ## Step 3: Wait for Build
33
+
34
+ The Space will automatically build. You'll see a **Building** status, then a green ✅ when ready.
35
+
36
+ This takes ~2-5 minutes (downloading PyTorch + other deps).
37
+
38
+ ---
39
+
40
+ ## Step 4: Test It
41
+
42
+ Once the green checkmark appears, the Space URL is:
43
+
44
+ ```
45
+ https://<your-username>-sleep-stage-classifier.hf.space
46
+ ```
47
+
48
+ Upload a CSV file with EEG data to test classification.
49
+
50
+ ---
51
+
52
+ # Integrate with Lovable
53
+
54
+ ## Method 1: Gradio Client (Recommended)
55
+
56
+ In your Lovable project, install the Gradio JS client:
57
+
58
+ ```bash
59
+ npm install @gradio/client
60
+ ```
61
+
62
+ Then add this code to your component:
63
+
64
+ ```tsx
65
+ import { Client } from "@gradio/client";
66
+
67
+ // Replace with your actual Space URL
68
+ const SPACE_URL = "https://<your-username>-sleep-stage-classifier.hf.space";
69
+
70
+ async function classifyEEG(file: File) {
71
+ const client = await Client.connect(SPACE_URL);
72
+ const result = await client.predict("/predict", { file });
73
+ return result.data; // [textOutput, jsonOutput]
74
+ }
75
+ ```
76
+
77
+ ### Full Example Component
78
+
79
+ ```tsx
80
+ import { useState } from "react";
81
+ import { Client } from "@gradio/client";
82
+
83
+ const SPACE_URL = "https://<your-username>-sleep-stage-classifier.hf.space";
84
+
85
+ export default function SleepClassifier() {
86
+ const [loading, setLoading] = useState(false);
87
+ const [results, setResults] = useState<any>(null);
88
+
89
+ const handleUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
90
+ const file = e.target.files?.[0];
91
+ if (!file) return;
92
+
93
+ setLoading(true);
94
+ try {
95
+ const client = await Client.connect(SPACE_URL);
96
+ const { data: [textOutput, jsonOutput] } = await client.predict("/predict", { file });
97
+ setResults(jsonOutput);
98
+ } catch (err) {
99
+ console.error("Classification failed:", err);
100
+ }
101
+ setLoading(false);
102
+ };
103
+
104
+ return (
105
+ <div className="p-6 max-w-2xl mx-auto">
106
+ <h1 className="text-2xl font-bold mb-4">😴 Sleep Stage Classifier</h1>
107
+
108
+ <input type="file" accept=".csv,.txt,.npy" onChange={handleUpload} />
109
+
110
+ {loading && <p className="mt-4">Classifying...</p>}
111
+
112
+ {results && (
113
+ <div className="mt-4">
114
+ <h2 className="text-lg font-semibold">Results</h2>
115
+ <p>{results.epochs.length} epochs classified</p>
116
+
117
+ <div className="mt-2 space-y-1">
118
+ {Object.entries(results.summary).map(([stage, stats]: [string, any]) => (
119
+ <div key={stage} className="flex items-center gap-2">
120
+ <span className="w-12">{stage}</span>
121
+ <span>{stats.count} ({stats.percentage}%)</span>
122
+ <div className="flex-1 bg-gray-200 rounded h-4">
123
+ <div
124
+ className="bg-blue-500 h-4 rounded"
125
+ style={{ width: `${stats.percentage}%` }}
126
+ />
127
+ </div>
128
+ </div>
129
+ ))}
130
+ </div>
131
+ </div>
132
+ )}
133
+ </div>
134
+ );
135
+ }
136
+ ```
137
+
138
+ ---
139
+
140
+ ## Method 2: Direct HTTP Fetch
141
+
142
+ If you prefer not to install `@gradio/client`:
143
+
144
+ ```typescript
145
+ async function classifyEEG(file: File) {
146
+ const formData = new FormData();
147
+ formData.append("files", file);
148
+
149
+ const res = await fetch(
150
+ `${SPACE_URL}/gradio_api/call/predict`,
151
+ {
152
+ method: "POST",
153
+ body: formData,
154
+ }
155
+ );
156
+
157
+ const result = await res.json();
158
+ return result;
159
+ }
160
+ ```
161
+
162
+ ---
163
+
164
+ ## Expected JSON Response Format
165
+
166
+ ```json
167
+ {
168
+ "epochs": [
169
+ {
170
+ "epoch": 1,
171
+ "stage": "Wake",
172
+ "confidence": 0.92,
173
+ "probabilities": {
174
+ "Wake": 0.92,
175
+ "N1": 0.03,
176
+ "N2": 0.02,
177
+ "N3": 0.01,
178
+ "N4": 0.01,
179
+ "REM": 0.01
180
+ }
181
+ }
182
+ ],
183
+ "summary": {
184
+ "Wake": { "count": 50, "percentage": 45.5 },
185
+ "N1": { "count": 10, "percentage": 9.1 },
186
+ "N2": { "count": 25, "percentage": 22.7 },
187
+ "N3": { "count": 15, "percentage": 13.6 },
188
+ "N4": { "count": 5, "percentage": 4.5 },
189
+ "REM": { "count": 5, "percentage": 4.5 }
190
+ }
191
+ }
192
+ ```
193
+
194
+ ---
195
+
196
+ ## File Format
197
+
198
+ Your input file should be:
199
+ - **CSV/TXT**: Single column of numbers (EEG amplitude values), no header needed
200
+ - **NPY**: NumPy array, 1D
201
+ - **Sampling rate**: Assumed 100 Hz
202
+ - **Minimum length**: 3000 samples (30 seconds)
203
+
204
+ Example CSV:
205
+ ```csv
206
+ 0.023
207
+ -0.015
208
+ 0.042
209
+ ...
210
+ ```
README.md CHANGED
@@ -1,13 +1,46 @@
1
  ---
2
- title: EEG Sleep
3
- emoji: πŸ“š
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.12.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Sleep Stage Classifier
3
+ emoji: 😴
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.0.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: Classify sleep stages from raw EEG signals
12
  ---
13
 
14
+ # 😴 Sleep Stage Classification
15
+
16
+ Upload a CSV, TXT, or NPY file containing raw EEG signal data (100 Hz sampling rate).
17
+ The model classifies the signal into 30-second epochs across 6 sleep stages: **Wake, N1, N2, N3, N4, REM**.
18
+
19
+ ## Model Architecture
20
+
21
+ - **Type**: 1D Convolutional Neural Network
22
+ - **Framework**: PyTorch
23
+ - **Input**: Single-channel EEG, 3000 samples per epoch (30s at 100 Hz)
24
+ - **Output**: 6-class classification logits β†’ softmax probabilities
25
+
26
+ ## API Usage
27
+
28
+ ```python
29
+ from gradio_client import Client, handle_file
+
+ client = Client("<your-username>/sleep-stage-classifier")
+ result = client.predict(file=handle_file("path/to/eeg.csv"), api_name="/predict")
33
+ print(result)
34
+ ```
35
+
36
+ ## Lovable / Frontend Integration
37
+
38
+ ```javascript
39
+ import { Client } from "@gradio/client";
40
+
41
+ const client = await Client.connect(
42
+ "https://<your-username>-sleep-stage-classifier.hf.space"
43
+ );
44
+ const result = await client.predict("/predict", { file: yourFile });
45
+ console.log(result.data);
46
+ ```
app.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Space β€” Sleep Stage Classification
3
+ ================================================
4
+ Gradio app that serves the pre-trained CNN model for inference.
5
+ Callable from any frontend via the Gradio API.
6
+
7
+ Space URL: https://<your-username>-sleep-stage-classifier.hf.space
8
+ """
9
+
10
+ import io
11
+ import os
12
+ import json
13
+ import numpy as np
14
+ import pandas as pd
15
+ import gradio as gr
16
+ import torch
17
+ import torch.nn as nn
18
+ from collections import Counter
19
+
20
# ────────────────────────────────────────────────────────────────
# Constants
# ────────────────────────────────────────────────────────────────
SFREQ = 100                  # assumed sampling rate of uploaded signals (Hz)
EPOCH_SAMPLES = 30 * SFREQ   # one scoring epoch = 30 seconds = 3000 samples
STAGES = ["Wake", "N1", "N2", "N3", "N4", "REM"]
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sleep_stage_cnn.pth")
27
+
28
+
29
+ # ────────────────────────────────────────────────────────────────
30
+ # Model Definition (must match training architecture exactly)
31
+ # ────────────────────────────────────────────────────────────────
32
+
33
class SleepStageCNN(nn.Module):
    """1D CNN mapping one 30 s single-channel EEG epoch to 6 sleep-stage logits.

    The layer sequence (and hence the ``network.<idx>`` state-dict keys) must
    stay identical to the training notebook so the saved checkpoint loads.
    """

    def __init__(self, n_channels=1, n_classes=6):
        super().__init__()
        layers = [
            # Block 1: wide kernel + large stride captures slow-wave features
            nn.Conv1d(n_channels, 32, kernel_size=50, stride=6),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(8),
            # Block 2: narrower kernel refines local features
            nn.Conv1d(32, 64, kernel_size=8),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(8),
            # Classifier head (64 channels x 6 timesteps after pooling)
            nn.Flatten(),
            nn.Linear(64 * 6, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        # x: (batch, channels, samples) -> (batch, n_classes) raw logits
        return self.network(x)
64
+
65
+
66
+ # ────────────────────────────────────────────────────────────────
67
+ # Load Model at startup
68
+ # ────────────────────────────────────────────────────────────────
69
+
70
+ device = torch.device("cpu")
71
+ model = SleepStageCNN(n_channels=1, n_classes=6)
72
+
73
+ if os.path.exists(MODEL_PATH):
74
+ checkpoint = torch.load(
75
+ MODEL_PATH, map_location=device, weights_only=False
76
+ )
77
+ if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
78
+ model.load_state_dict(checkpoint["model_state_dict"])
79
+ else:
80
+ model.load_state_dict(checkpoint)
81
+ model.eval().to(device)
82
+ print(f"βœ… Model loaded from {MODEL_PATH}")
83
+ else:
84
+ raise FileNotFoundError(
85
+ f"Model file not found at {MODEL_PATH}. "
86
+ "Upload sleep_stage_cnn.pth to this Space."
87
+ )
88
+
89
+
90
+ # ────────────────────────────────────────────────────────────────
91
+ # Inference Function
92
+ # ────────────────────────────────────────────────────────────────
93
+
94
def classify_eeg(signal: np.ndarray) -> dict:
    """
    Classify a raw 1D EEG signal into 30-second sleep-stage epochs.

    Parameters
    ----------
    signal : np.ndarray
        Raw EEG samples (1D array, assumed 100 Hz sampling rate).

    Returns
    -------
    dict
        On success:
          - "epochs": list of {epoch, stage, confidence, probabilities}
          - "summary": {stage: {"count": int, "percentage": float}} for every stage
        If the signal is shorter than one epoch: {"error": message}.
    """
    if len(signal) < EPOCH_SAMPLES:
        return {
            "error": (
                f"Signal too short. Need at least {EPOCH_SAMPLES} samples "
                f"(30s at 100 Hz), got {len(signal)}."
            )
        }

    predictions = []
    # Only complete epochs are scored; any trailing partial epoch is dropped.
    n_epochs = len(signal) // EPOCH_SAMPLES
    for idx in range(n_epochs):
        start = idx * EPOCH_SAMPLES
        segment = signal[start:start + EPOCH_SAMPLES]

        # Z-score normalize; guard against a flat (zero-variance) segment.
        mu = segment.mean()
        sigma = segment.std()
        if sigma == 0:
            sigma = 1.0
        normalized = (segment - mu) / sigma

        # Shape to (batch=1, channels=1, samples) for the 1D CNN.
        batch = torch.tensor(normalized, dtype=torch.float32).reshape(1, 1, -1).to(device)

        with torch.no_grad():
            logits = model(batch)
            probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        winner = int(logits.argmax().item())

        predictions.append({
            "epoch": idx + 1,
            "stage": STAGES[winner],
            "confidence": round(float(probs.max()), 4),
            "probabilities": {
                name: round(float(p), 4) for name, p in zip(STAGES, probs)
            },
        })

    # Per-stage counts and percentages over all scored epochs.
    stage_counts = Counter(p["stage"] for p in predictions)
    total = len(predictions)
    summary = {
        name: {
            "count": stage_counts.get(name, 0),
            "percentage": round(stage_counts.get(name, 0) / total * 100, 1),
        }
        for name in STAGES
    }

    return {"epochs": predictions, "summary": summary}
162
+
163
+
164
+ # ────────────────────────────────────────────────────────────────
165
+ # File Processor (called by Gradio UI)
166
+ # ────────────────────────────────────────────────────────────────
167
+
168
def process_file(file) -> tuple:
    """
    Handle a file uploaded through the Gradio UI.

    Parameters
    ----------
    file : file-like, str path, or None
        Uploaded CSV / TXT / NPY file (None when nothing was uploaded).

    Returns
    -------
    tuple
        (human-readable text report, result dict for the JSON component);
        the second element is None when input is missing or invalid.
    """
    if file is None:
        return "⚠️ Please upload a file.", None

    try:
        # Dispatch on the file extension to load the raw 1D signal.
        filename = file.name.lower() if hasattr(file, "name") else str(file).lower()

        if filename.endswith(".npy"):
            signal = np.load(file)
            if signal.ndim > 1:
                signal = signal.flatten()
        else:
            # CSV / TXT: take the first column; let pandas sniff the delimiter.
            frame = pd.read_csv(file, header=None, sep=None, engine="python")
            signal = frame.iloc[:, 0].values.astype(np.float64)

        result = classify_eeg(signal)
        if "error" in result:
            return f"❌ {result['error']}", None

        # Human-readable report: distribution bars + first 20 epoch details.
        report = [
            f"📊 Total epochs classified: {len(result['epochs'])}",
            "",
            "📋 Stage Distribution:",
            "-" * 40,
        ]
        for stage, stats in result["summary"].items():
            bar = "█" * int(stats["percentage"] / 2)
            report.append(f"  {stage:6s}: {stats['count']:4d} ({stats['percentage']:5.1f}%) {bar}")

        report.extend(["", "📝 Epoch Details (first 20):", "-" * 40])
        for ep in result["epochs"][:20]:
            report.append(
                f"  Epoch {ep['epoch']:>3d}: {ep['stage']:5s} "
                f"confidence {ep['confidence']*100:.1f}%"
            )

        # Gradio auto-serializes the dict for the JSON component.
        return "\n".join(report), result

    except Exception as e:
        return f"❌ Error: {str(e)}", None
229
+
230
+
231
+ # ────────────────────────────────────────────────────────────────
232
+ # Gradio Interface
233
+ # ────────────────────────────────────────────────────────────────
234
+
235
# ────────────────────────────────────────────────────────────────
# Gradio Interface
# ────────────────────────────────────────────────────────────────

with gr.Blocks(
    title="Sleep Stage Classifier",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="slate",
    ),
) as demo:

    gr.Markdown(
        """
        # 😴 Sleep Stage Classification

        Upload a **CSV**, **TXT**, or **NPY** file containing raw EEG signal data.
        The model assumes a **100 Hz sampling rate** and classifies the signal
        into 30-second epochs.

        | Stage | Description |
        |-------|-------------|
        | **Wake** | Awake, eyes open/closed |
        | **N1** | Light sleep, transition |
        | **N2** | Deeper sleep, spindles + K-complexes |
        | **N3** | Slow-wave sleep (deep) |
        | **N4** | Very deep slow-wave sleep |
        | **REM** | Rapid eye movement (dreaming) |
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload EEG file",
                file_types=[".csv", ".txt", ".npy"],
            )
            btn = gr.Button("🔍 Classify", variant="primary", size="lg")

            # FIX: the previous gr.Examples(label="Tip", examples=[...], inputs=[])
            # was API misuse — gr.Examples expects example values matching its
            # (non-empty) input components, not a free-text hint. A Markdown note
            # conveys the same tip without breaking the Examples contract.
            gr.Markdown("*Tip: upload a single-column CSV with EEG amplitude values.*")

        with gr.Column(scale=2):
            text_output = gr.Textbox(
                label="Results",
                lines=20,
                interactive=False,
            )
            json_output = gr.JSON(
                label="Raw JSON (for API integration)",
            )

    # Wire the button to the inference pipeline.
    btn.click(
        fn=process_file,
        inputs=[file_input],
        outputs=[text_output, json_output],
    )

    gr.Markdown(
        """
        ---
        ### 🔌 API Access

        You can call this Space programmatically from any frontend:

        ```bash
        pip install gradio_client
        ```

        ```python
        from gradio_client import Client

        client = Client("<your-username>/sleep-stage-classifier")
        result = client.predict(file="path/to/eeg.csv")
        print(result)
        ```

        Or from JavaScript in your Lovable app:

        ```javascript
        import { Client } from "@gradio/client";

        const client = await Client.connect(
            "https://<your-username>-sleep-stage-classifier.hf.space"
        );
        const result = await client.predict("/predict", { file: yourFile });
        ```
        """
    )


# ────────────────────────────────────────────────────────────────
# Launch
# ────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ scipy
4
+ pandas
5
+ gradio
sleep_stage_cnn.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:590c0fce8a200a67c6b277bd56e2c45b49238c2a73e56562fa6fd56fcd1ba80d
3
+ size 282749