Spaces:
Runtime error
Runtime error
| """ | |
| Hugging Face Space β Sleep Stage Classification | |
| ================================================ | |
| Gradio app that serves the pre-trained CNN model for inference. | |
| Callable from any frontend via the Gradio API. | |
| Space URL: https://<your-username>-sleep-stage-classifier.hf.space | |
| """ | |
| import io | |
| import os | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from collections import Counter | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Constants | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| SFREQ = 100 | |
| EPOCH_SAMPLES = 3000 # 30 seconds Γ 100 Hz | |
| STAGES = ["Wake", "N1", "N2", "N3", "N4", "REM"] | |
| MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sleep_stage_cnn.pth") | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Model Definition (must match training architecture exactly) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class SleepStageCNN(nn.Module): | |
| """ | |
| 1D Convolutional Neural Network for Sleep Stage Classification. | |
| Architecture matches the training notebook. | |
| """ | |
| def __init__(self, n_channels=1, n_classes=6): | |
| super().__init__() | |
| self.network = nn.Sequential( | |
| # Block 1: large receptive field for slow-wave features | |
| nn.Conv1d(n_channels, 32, kernel_size=50, stride=6), | |
| nn.BatchNorm1d(32), | |
| nn.ReLU(), | |
| nn.MaxPool1d(8), | |
| # Block 2: finer feature extraction | |
| nn.Conv1d(32, 64, kernel_size=8), | |
| nn.BatchNorm1d(64), | |
| nn.ReLU(), | |
| nn.MaxPool1d(8), | |
| # Classifier head | |
| nn.Flatten(), | |
| nn.Linear(64 * 6, 128), | |
| nn.ReLU(), | |
| nn.Dropout(0.5), | |
| nn.Linear(128, n_classes), | |
| ) | |
| def forward(self, x): | |
| return self.network(x) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Load Model at startup | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| device = torch.device("cpu") | |
| model = SleepStageCNN(n_channels=1, n_classes=6) | |
| if os.path.exists(MODEL_PATH): | |
| checkpoint = torch.load( | |
| MODEL_PATH, map_location=device, weights_only=False | |
| ) | |
| if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: | |
| state_dict = checkpoint["model_state_dict"] | |
| else: | |
| state_dict = checkpoint | |
| # Remap bare Sequential keys (e.g. "0.weight") β "network.0.weight" | |
| if any(k.split(".")[0].isdigit() for k in state_dict.keys()): | |
| state_dict = {"network." + k: v for k, v in state_dict.items()} | |
| model.load_state_dict(state_dict) | |
| model.eval().to(device) | |
| print(f"β Model loaded from {MODEL_PATH}") | |
| else: | |
| raise FileNotFoundError( | |
| f"Model file not found at {MODEL_PATH}. " | |
| "Upload sleep_stage_cnn.pth to this Space." | |
| ) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Inference Function | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def classify_eeg(signal: np.ndarray) -> dict: | |
| """ | |
| Run inference on a 1D EEG signal. | |
| Parameters | |
| ---------- | |
| signal : np.ndarray | |
| Raw EEG data (1D array, assumed 100 Hz sampling rate). | |
| Returns | |
| ------- | |
| dict with keys: | |
| - epochs: list of {epoch, stage, confidence} | |
| - summary: dict of stage β "count (percentage%)" | |
| """ | |
| if len(signal) < EPOCH_SAMPLES: | |
| return { | |
| "error": ( | |
| f"Signal too short. Need at least {EPOCH_SAMPLES} samples " | |
| f"(30s at 100 Hz), got {len(signal)}." | |
| ) | |
| } | |
| predictions = [] | |
| for i in range(0, len(signal) - EPOCH_SAMPLES + 1, EPOCH_SAMPLES): | |
| epoch = signal[i: i + EPOCH_SAMPLES] | |
| # Z-score normalize | |
| mean = epoch.mean() | |
| std = epoch.std() | |
| if std == 0: | |
| std = 1.0 | |
| epoch_norm = (epoch - mean) / std | |
| # Forward pass | |
| x = torch.tensor( | |
| epoch_norm, dtype=torch.float32 | |
| ).unsqueeze(0).unsqueeze(0).to(device) | |
| with torch.no_grad(): | |
| logits = model(x) | |
| probs = torch.softmax(logits, dim=1).cpu().numpy()[0] | |
| pred_idx = int(logits.argmax().item()) | |
| predictions.append({ | |
| "epoch": len(predictions) + 1, | |
| "stage": STAGES[pred_idx], | |
| "confidence": round(float(max(probs)), 4), | |
| "probabilities": { | |
| STAGES[j]: round(float(probs[j]), 4) | |
| for j in range(len(STAGES)) | |
| }, | |
| }) | |
| # Summary statistics | |
| counts = Counter(p["stage"] for p in predictions) | |
| total = len(predictions) | |
| return { | |
| "epochs": predictions, | |
| "summary": { | |
| stage: { | |
| "count": counts.get(stage, 0), | |
| "percentage": round(counts.get(stage, 0) / total * 100, 1) | |
| } | |
| for stage in STAGES | |
| }, | |
| } | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # File Processor (called by Gradio UI) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def process_file(file) -> tuple: | |
| """ | |
| Process uploaded EEG file and return readable results + raw JSON. | |
| Parameters | |
| ---------- | |
| file : file-like or str path | |
| Uploaded CSV / TXT / NPY file. | |
| Returns | |
| ------- | |
| (text_output, json_output) | |
| """ | |
| if file is None: | |
| return "β οΈ Please upload a file.", None | |
| try: | |
| # Determine file type and load signal | |
| name = file.name.lower() if hasattr(file, "name") else str(file).lower() | |
| if name.endswith(".npy"): | |
| signal = np.load(file) | |
| if signal.ndim > 1: | |
| signal = signal.flatten() | |
| else: | |
| # CSV or TXT β first column | |
| df = pd.read_csv(file, header=None, sep=None, engine="python") | |
| signal = df.iloc[:, 0].values.astype(np.float64) | |
| # Run inference | |
| result = classify_eeg(signal) | |
| if "error" in result: | |
| return f"β {result['error']}", None | |
| # Build readable text output | |
| lines = [] | |
| lines.append(f"π Total epochs classified: {len(result['epochs'])}") | |
| lines.append("") | |
| lines.append("π Stage Distribution:") | |
| lines.append("-" * 40) | |
| for stage, stats in result["summary"].items(): | |
| bar = "β" * int(stats["percentage"] / 2) | |
| lines.append(f" {stage:6s}: {stats['count']:4d} ({stats['percentage']:5.1f}%) {bar}") | |
| lines.append("") | |
| lines.append("π Epoch Details (first 20):") | |
| lines.append("-" * 40) | |
| for ep in result["epochs"][:20]: | |
| lines.append( | |
| f" Epoch {ep['epoch']:>3d}: {ep['stage']:5s} " | |
| f"confidence {ep['confidence']*100:.1f}%" | |
| ) | |
| text_output = "\n".join(lines) | |
| json_output = result # Gradio will auto-serialize to JSON | |
| return text_output, json_output | |
| except Exception as e: | |
| return f"β Error: {str(e)}", None | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Gradio Interface | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Blocks( | |
| title="Sleep Stage Classifier", | |
| theme=gr.themes.Soft( | |
| primary_hue="blue", | |
| secondary_hue="slate", | |
| ), | |
| ) as demo: | |
| gr.Markdown( | |
| """ | |
| # π΄ Sleep Stage Classification | |
| Upload a **CSV**, **TXT**, or **NPY** file containing raw EEG signal data. | |
| The model assumes a **100 Hz sampling rate** and classifies the signal | |
| into 30-second epochs. | |
| | Stage | Description | | |
| |-------|-------------| | |
| | **Wake** | Awake, eyes open/closed | | |
| | **N1** | Light sleep, transition | | |
| | **N2** | Deeper sleep, spindles + K-complexes | | |
| | **N3** | Slow-wave sleep (deep) | | |
| | **N4** | Very deep slow-wave sleep | | |
| | **REM** | Rapid eye movement (dreaming) | | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| file_input = gr.File( | |
| label="Upload EEG file", | |
| file_types=[".csv", ".txt", ".npy"], | |
| ) | |
| btn = gr.Button("π Classify", variant="primary", size="lg") | |
| gr.Markdown("π‘ **Tip:** Upload a single-column CSV with EEG amplitude values (100 Hz).") | |
| with gr.Column(scale=2): | |
| text_output = gr.Textbox( | |
| label="Results", | |
| lines=20, | |
| interactive=False, | |
| ) | |
| json_output = gr.JSON( | |
| label="Raw JSON (for API integration)", | |
| ) | |
| btn.click( | |
| fn=process_file, | |
| inputs=[file_input], | |
| outputs=[text_output, json_output], | |
| ) | |
| gr.Markdown( | |
| """ | |
| --- | |
| ### π API Access | |
| You can call this Space programmatically from any frontend: | |
| ```bash | |
| pip install gradio_client | |
| ``` | |
| ```python | |
| from gradio_client import Client | |
| client = Client("<your-username>/sleep-stage-classifier") | |
| result = client.predict(file="path/to/eeg.csv") | |
| print(result) | |
| ``` | |
| Or from JavaScript in your Lovable app: | |
| ```javascript | |
| import { Client } from "@gradio/client"; | |
| const client = await Client.connect( | |
| "https://<your-username>-sleep-stage-classifier.hf.space" | |
| ); | |
| const result = await client.predict("/predict", { file: yourFile }); | |
| ``` | |
| """ | |
| ) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Launch | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| demo.launch() | |