# EEG_sleep / app.py — Hugging Face Space by TunisianCoder (revision dc21fc5)
"""
Hugging Face Space β€” Sleep Stage Classification
================================================
Gradio app that serves the pre-trained CNN model for inference.
Callable from any frontend via the Gradio API.
Space URL: https://<your-username>-sleep-stage-classifier.hf.space
"""
import io
import os
import json
import numpy as np
import pandas as pd
import gradio as gr
import torch
import torch.nn as nn
from collections import Counter
# ────────────────────────────────────────────────────────────────
# Constants
# ────────────────────────────────────────────────────────────────
SFREQ = 100
EPOCH_SAMPLES = 3000 # 30 seconds Γ— 100 Hz
STAGES = ["Wake", "N1", "N2", "N3", "N4", "REM"]
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sleep_stage_cnn.pth")
# ────────────────────────────────────────────────────────────────
# Model Definition (must match training architecture exactly)
# ────────────────────────────────────────────────────────────────
class SleepStageCNN(nn.Module):
    """
    1D CNN that maps a single-channel 30 s EEG epoch to sleep-stage logits.

    Input is expected as (batch, n_channels, 3000); the output is
    (batch, n_classes) raw logits. The sequential layer order — and hence
    the state-dict keys ``network.<i>.*`` — must stay identical to the
    training architecture so the saved checkpoint keeps loading.

    Shape trace for a 3000-sample epoch:
    conv(k=50, s=6) -> 492, pool(8) -> 61, conv(k=8) -> 54, pool(8) -> 6,
    so the flattened feature vector is 64 channels * 6 steps = 384.
    """

    def __init__(self, n_channels=1, n_classes=6):
        super().__init__()
        layers = [
            # Block 1 — wide kernel for slow-wave (low-frequency) features
            nn.Conv1d(n_channels, 32, kernel_size=50, stride=6),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(8),
            # Block 2 — narrower kernel for finer local features
            nn.Conv1d(32, 64, kernel_size=8),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(8),
            # Classifier head over the flattened 64 x 6 feature map
            nn.Flatten(),
            nn.Linear(64 * 6, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for ``x`` of shape (B, C, 3000)."""
        return self.network(x)
# ────────────────────────────────────────────────────────────────
# Load Model at startup
# ────────────────────────────────────────────────────────────────
# Inference always runs on CPU (HF Spaces free tier has no GPU).
device = torch.device("cpu")
model = SleepStageCNN(n_channels=1, n_classes=6)
if os.path.exists(MODEL_PATH):
    # weights_only=True restricts unpickling to tensors and plain
    # containers instead of executing arbitrary pickled code — safe here
    # because both supported checkpoint layouts (a bare state_dict, or a
    # dict holding one under "model_state_dict") are tensor containers.
    checkpoint = torch.load(
        MODEL_PATH, map_location=device, weights_only=True
    )
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    # Checkpoints saved from a bare nn.Sequential have keys like
    # "0.weight"; remap them onto this module's "network.0.weight" layout.
    if any(k.split(".")[0].isdigit() for k in state_dict.keys()):
        state_dict = {"network." + k: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval().to(device)  # eval mode: freeze BatchNorm stats, disable Dropout
    print(f"✅ Model loaded from {MODEL_PATH}")
else:
    # Fail fast at startup — serving an untrained model would be worse.
    raise FileNotFoundError(
        f"Model file not found at {MODEL_PATH}. "
        "Upload sleep_stage_cnn.pth to this Space."
    )
# ────────────────────────────────────────────────────────────────
# Inference Function
# ────────────────────────────────────────────────────────────────
def classify_eeg(signal: np.ndarray) -> dict:
    """
    Score a raw 1-D EEG trace into 30-second sleep-stage epochs.

    Parameters
    ----------
    signal : np.ndarray
        One-dimensional EEG samples, assumed recorded at 100 Hz.

    Returns
    -------
    dict
        On success, ``{"epochs": [...], "summary": {...}}`` where each
        epoch entry carries the predicted stage, its confidence and the
        full per-stage probability map; ``{"error": message}`` when the
        signal is shorter than one epoch.
    """
    n_samples = len(signal)
    if n_samples < EPOCH_SAMPLES:
        return {
            "error": (
                f"Signal too short. Need at least {EPOCH_SAMPLES} samples "
                f"(30s at 100 Hz), got {n_samples}."
            )
        }

    predictions = []
    for epoch_no in range(1, n_samples // EPOCH_SAMPLES + 1):
        start = (epoch_no - 1) * EPOCH_SAMPLES
        window = signal[start:start + EPOCH_SAMPLES]

        # Z-score each epoch independently; a flat segment (std == 0)
        # divides by 1 instead so we never divide by zero.
        scale = window.std()
        if scale == 0:
            scale = 1.0
        normalised = (window - window.mean()) / scale

        # Shape the epoch as (batch=1, channels=1, samples) for the CNN.
        batch = torch.as_tensor(normalised, dtype=torch.float32)
        batch = batch.reshape(1, 1, -1).to(device)
        with torch.no_grad():
            probs = torch.softmax(model(batch), dim=1).cpu().numpy()[0]

        # softmax is monotonic, so argmax over probs equals argmax over logits.
        best = int(np.argmax(probs))
        predictions.append({
            "epoch": epoch_no,
            "stage": STAGES[best],
            "confidence": round(float(probs[best]), 4),
            "probabilities": {
                name: round(float(p), 4) for name, p in zip(STAGES, probs)
            },
        })

    # Per-stage counts and percentages over all scored epochs.
    stage_counts = Counter(entry["stage"] for entry in predictions)
    total = len(predictions)
    return {
        "epochs": predictions,
        "summary": {
            name: {
                "count": stage_counts.get(name, 0),
                "percentage": round(stage_counts.get(name, 0) / total * 100, 1),
            }
            for name in STAGES
        },
    }
# ────────────────────────────────────────────────────────────────
# File Processor (called by Gradio UI)
# ────────────────────────────────────────────────────────────────
def process_file(file) -> tuple:
    """
    Process an uploaded EEG file and return readable results + raw JSON.

    Parameters
    ----------
    file : file-like or str path
        Uploaded CSV / TXT / NPY file.

    Returns
    -------
    tuple
        (text_output, json_output); json_output is None whenever the
        input is missing or any step fails.
    """
    if file is None:
        return "⚠️ Please upload a file.", None
    try:
        # Pick a loader from the file extension; default to delimited text.
        name = file.name.lower() if hasattr(file, "name") else str(file).lower()
        if name.endswith(".npy"):
            signal = np.load(file)
            if signal.ndim > 1:
                # Collapse multi-channel arrays into a single 1-D trace.
                signal = signal.flatten()
        else:
            # CSV or TXT — sep=None lets pandas sniff the delimiter;
            # only the first column is treated as the signal.
            df = pd.read_csv(file, header=None, sep=None, engine="python")
            signal = df.iloc[:, 0].values.astype(np.float64)

        result = classify_eeg(signal)
        if "error" in result:
            return f"❌ {result['error']}", None

        # Build the human-readable report. (The emoji below were
        # mojibake-encoded in the original source, e.g. "πŸ“Š"; they are
        # restored here to the intended glyphs.)
        lines = []
        lines.append(f"📊 Total epochs classified: {len(result['epochs'])}")
        lines.append("")
        lines.append("📋 Stage Distribution:")
        lines.append("-" * 40)
        for stage, stats in result["summary"].items():
            bar = "█" * int(stats["percentage"] / 2)  # one bar cell per 2%
            lines.append(f" {stage:6s}: {stats['count']:4d} ({stats['percentage']:5.1f}%) {bar}")
        lines.append("")
        lines.append("📝 Epoch Details (first 20):")
        lines.append("-" * 40)
        for ep in result["epochs"][:20]:
            lines.append(
                f" Epoch {ep['epoch']:>3d}: {ep['stage']:5s} "
                f"confidence {ep['confidence']*100:.1f}%"
            )
        text_output = "\n".join(lines)
        json_output = result  # Gradio will auto-serialize to JSON
        return text_output, json_output
    except Exception as e:
        # Surface load/parse/inference failures to the UI instead of crashing.
        return f"❌ Error: {str(e)}", None
# ────────────────────────────────────────────────────────────────
# Gradio Interface
# ────────────────────────────────────────────────────────────────
# Declarative UI: one file input + classify button on the left, text report
# and raw JSON on the right. `demo` is launched by the __main__ guard below.
# NOTE(review): several emoji literals in this block (e.g. "πŸ”", "πŸ’‘",
# "πŸ”Œ") look mojibake-encoded — verify the file's encoding and restore the
# intended glyphs.
with gr.Blocks(
    title="Sleep Stage Classifier",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="slate",
    ),
) as demo:
    # Header: usage instructions plus a legend of the six output stages.
    gr.Markdown(
        """
        # 😴 Sleep Stage Classification
        Upload a **CSV**, **TXT**, or **NPY** file containing raw EEG signal data.
        The model assumes a **100 Hz sampling rate** and classifies the signal
        into 30-second epochs.
        | Stage | Description |
        |-------|-------------|
        | **Wake** | Awake, eyes open/closed |
        | **N1** | Light sleep, transition |
        | **N2** | Deeper sleep, spindles + K-complexes |
        | **N3** | Slow-wave sleep (deep) |
        | **N4** | Very deep slow-wave sleep |
        | **REM** | Rapid eye movement (dreaming) |
        """
    )
    with gr.Row():
        # Left column: input controls.
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload EEG file",
                file_types=[".csv", ".txt", ".npy"],
            )
            btn = gr.Button("πŸ” Classify", variant="primary", size="lg")
            gr.Markdown("πŸ’‘ **Tip:** Upload a single-column CSV with EEG amplitude values (100 Hz).")
        # Right column: outputs (readable report + machine-readable JSON).
        with gr.Column(scale=2):
            text_output = gr.Textbox(
                label="Results",
                lines=20,
                interactive=False,
            )
            json_output = gr.JSON(
                label="Raw JSON (for API integration)",
            )
    # Wire the button to the file processor.
    btn.click(
        fn=process_file,
        inputs=[file_input],
        outputs=[text_output, json_output],
    )
    # Footer: programmatic-access examples for Python and JavaScript clients.
    gr.Markdown(
        """
        ---
        ### πŸ”Œ API Access
        You can call this Space programmatically from any frontend:
        ```bash
        pip install gradio_client
        ```
        ```python
        from gradio_client import Client
        client = Client("<your-username>/sleep-stage-classifier")
        result = client.predict(file="path/to/eeg.csv")
        print(result)
        ```
        Or from JavaScript in your Lovable app:
        ```javascript
        import { Client } from "@gradio/client";
        const client = await Client.connect(
            "https://<your-username>-sleep-stage-classifier.hf.space"
        );
        const result = await client.predict("/predict", { file: yourFile });
        ```
        """
    )
# ────────────────────────────────────────────────────────────────
# Launch
# ────────────────────────────────────────────────────────────────
# Start the Gradio server when run directly (as Hugging Face Spaces does).
if __name__ == "__main__":
    demo.launch()