"""
Hugging Face Space — Sleep Stage Classification
================================================

Gradio app that serves the pre-trained CNN model for inference.
Callable from any frontend via the Gradio API (endpoint ``/predict``).

Space URL: https://YOUR_USERNAME-sleep-stage-classifier.hf.space
(replace YOUR_USERNAME with the account that owns the Space)
"""

import csv
import io  # noqa: F401  (not used below; kept for compatibility)
import json  # noqa: F401  (not used below; kept for compatibility)
import logging
import os
from collections import Counter

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

# ────────────────────────────────────────────────────────────────
# Constants
# ────────────────────────────────────────────────────────────────

SFREQ = 100  # expected sampling rate in Hz
EPOCH_SAMPLES = 3000  # 30 seconds × 100 Hz
STAGES = ["Wake", "N1", "N2", "N3", "N4", "REM"]
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "sleep_stage_cnn.pth"
)

# Number of 30-second epochs sent through the network per forward pass.
# Bounds peak memory on long recordings while avoiding one call per epoch.
INFERENCE_BATCH_SIZE = 256

# Characters accepted as delimiters when sniffing CSV/TXT uploads.
_CSV_DELIMITERS = ",;\t "


# ────────────────────────────────────────────────────────────────
# Model Definition (must match training architecture exactly)
# ────────────────────────────────────────────────────────────────

class SleepStageCNN(nn.Module):
    """
    1D Convolutional Neural Network for Sleep Stage Classification.
    Architecture matches the training notebook.

    Expects input of shape (batch, n_channels, EPOCH_SAMPLES). The head's
    ``64 * 6`` input size is only valid for 3000-sample epochs:
    3000 -> conv(k=50, s=6) 492 -> pool(8) 61 -> conv(k=8) 54 -> pool(8) 6.
    The layer order defines the state-dict keys ("network.<idx>.*"), so it
    must not change without retraining.
    """

    def __init__(self, n_channels=1, n_classes=6):
        super().__init__()
        self.network = nn.Sequential(
            # Block 1: large receptive field for slow-wave features
            nn.Conv1d(n_channels, 32, kernel_size=50, stride=6),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(8),
            # Block 2: finer feature extraction
            nn.Conv1d(32, 64, kernel_size=8),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(8),
            # Classifier head
            nn.Flatten(),
            nn.Linear(64 * 6, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes),
        )

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, n_classes)."""
        return self.network(x)


# ────────────────────────────────────────────────────────────────
# Load Model at startup
# ────────────────────────────────────────────────────────────────

device = torch.device("cpu")


def _load_model(path):
    """Build a SleepStageCNN in eval mode with the trained weights stored at *path*.

    Accepts either a bare state dict or a checkpoint dict that holds it under
    "model_state_dict". Raises FileNotFoundError if *path* does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Model file not found at {path}. "
            "Upload sleep_stage_cnn.pth to this Space."
        )

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # That is acceptable only because the checkpoint ships with the Space
    # itself; never point this at a user-supplied file.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Strip the "module." prefix carried by nn.DataParallel / DDP checkpoints.
    prefix = "module."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    # Remap bare Sequential keys (e.g. "0.weight") → "network.0.weight"
    if any(k.split(".")[0].isdigit() for k in state_dict):
        state_dict = {"network." + k: v for k, v in state_dict.items()}

    net = SleepStageCNN(n_channels=1, n_classes=len(STAGES))
    net.load_state_dict(state_dict)
    net.eval().to(device)
    print(f"✅ Model loaded from {path}")
    return net


model = _load_model(MODEL_PATH)


# ────────────────────────────────────────────────────────────────
# Inference Function
# ────────────────────────────────────────────────────────────────

def classify_eeg(signal: np.ndarray) -> dict:
    """
    Run inference on a 1D EEG signal.

    The signal is cut into consecutive, non-overlapping epochs of
    EPOCH_SAMPLES samples (a trailing partial epoch is ignored). Each epoch
    is z-score normalized on its own before classification.

    Parameters
    ----------
    signal : np.ndarray
        Raw EEG data (1D array-like, assumed 100 Hz sampling rate).

    Returns
    -------
    dict
        On success, with keys:
        - "epochs": list of {"epoch", "stage", "confidence", "probabilities"}
          (epoch numbers start at 1; probabilities rounded to 4 decimals)
        - "summary": {stage: {"count": int, "percentage": float}} for every
          stage in STAGES (percentage rounded to 1 decimal)
        On invalid input: {"error": message}.
    """
    signal = np.asarray(signal, dtype=np.float64)
    if signal.ndim != 1:
        return {
            "error": f"Expected a 1-D signal, got an array of shape {signal.shape}."
        }
    if len(signal) < EPOCH_SAMPLES:
        return {
            "error": (
                f"Signal too short. Need at least {EPOCH_SAMPLES} samples "
                f"(30s at 100 Hz), got {len(signal)}."
            )
        }

    n_epochs = len(signal) // EPOCH_SAMPLES
    epochs = signal[: n_epochs * EPOCH_SAMPLES].reshape(n_epochs, EPOCH_SAMPLES)
    if not np.isfinite(epochs).all():
        # NaN/inf would propagate through normalization into the logits and
        # yield meaningless predictions.
        return {
            "error": "Signal contains NaN or infinite values; clean the data and retry."
        }

    # Z-score normalize each epoch independently; flat epochs keep std = 1.
    mean = epochs.mean(axis=1, keepdims=True)
    std = epochs.std(axis=1, keepdims=True)
    std[std == 0] = 1.0
    normed = ((epochs - mean) / std).astype(np.float32)

    # Batched forward pass. The model is in eval mode (set in _load_model), so
    # BatchNorm uses running statistics and epochs do not influence each other.
    prob_chunks = []
    with torch.no_grad():
        for start in range(0, n_epochs, INFERENCE_BATCH_SIZE):
            x = torch.from_numpy(normed[start:start + INFERENCE_BATCH_SIZE])
            logits = model(x.unsqueeze(1).to(device))  # (batch, 1, EPOCH_SAMPLES)
            prob_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())
    probs = np.concatenate(prob_chunks)
    pred_idx = probs.argmax(axis=1)

    predictions = [
        {
            "epoch": i + 1,
            "stage": STAGES[int(k)],
            "confidence": round(float(row[k]), 4),
            "probabilities": {
                stage: round(float(prob), 4) for stage, prob in zip(STAGES, row)
            },
        }
        for i, (k, row) in enumerate(zip(pred_idx, probs))
    ]

    # Summary statistics
    counts = Counter(ep["stage"] for ep in predictions)
    return {
        "epochs": predictions,
        "summary": {
            stage: {
                "count": counts[stage],
                "percentage": round(counts[stage] / n_epochs * 100, 1),
            }
            for stage in STAGES
        },
    }


# ────────────────────────────────────────────────────────────────
# File Processor (called by Gradio UI)
# ────────────────────────────────────────────────────────────────

def _upload_path(file):
    """Return the filesystem path of an uploaded file.

    Gradio passes either a path string or, in older versions, a tempfile-like
    object exposing the path as ``.name``.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _read_first_column(path):
    """Read the first column of a delimited text file (CSV / TXT) as float64.

    ``pd.read_csv(sep=None)`` is not used on purpose: pandas would then sniff
    the delimiter from the first row only, and on a single-column file (e.g. a
    first row of "12.5") csv.Sniffer treats every character of that row as an
    equally consistent candidate and picks one of them as the separator,
    corrupting or rejecting the data. Sniffing is therefore restricted to real
    delimiters, and a file without any is read as a single column.
    """
    with open(path, encoding="utf-8-sig") as fh:
        sample = "".join(fh.readline() for _ in range(50))
    try:
        delimiter = csv.Sniffer().sniff(sample, delimiters=_CSV_DELIMITERS).delimiter
    except csv.Error:
        delimiter = ","  # no delimiter found: the file holds a single column
    # Collapse runs of spaces/tabs (and ignore leading whitespace).
    sep = r"\s+" if delimiter in (" ", "\t") else delimiter
    df = pd.read_csv(
        path, header=None, sep=sep, skipinitialspace=True, encoding="utf-8-sig"
    )
    return df.iloc[:, 0].to_numpy(dtype=np.float64)


def _load_signal(file):
    """Load raw EEG samples from an uploaded .npy, .csv or .txt file."""
    path = _upload_path(file)
    if path.lower().endswith(".npy"):
        # np.load keeps allow_pickle at its default (False), so an uploaded
        # file cannot execute code. Squeeze (1, T) / (T, 1) arrays to 1-D;
        # genuinely multi-channel arrays are rejected by classify_eeg instead
        # of being silently concatenated.
        return np.squeeze(np.load(path))
    # CSV or TXT — first column
    return _read_first_column(path)


def _format_report(result):
    """Render a successful classify_eeg result as the text shown in the UI."""
    lines = [
        f"📊 Total epochs classified: {len(result['epochs'])}",
        "",
        "📋 Stage Distribution:",
        "-" * 40,
    ]
    for stage, stats in result["summary"].items():
        bar = "█" * int(stats["percentage"] / 2)  # one block per 2 %
        lines.append(f" {stage:6s}: {stats['count']:4d} ({stats['percentage']:5.1f}%) {bar}")
    lines += ["", "📝 Epoch Details (first 20):", "-" * 40]
    for ep in result["epochs"][:20]:
        lines.append(
            f" Epoch {ep['epoch']:>3d}: {ep['stage']:5s} "
            f"confidence {ep['confidence']*100:.1f}%"
        )
    return "\n".join(lines)


def process_file(file) -> tuple:
    """
    Process uploaded EEG file and return readable results + raw JSON.

    Parameters
    ----------
    file : str, os.PathLike, or object with a ``.name`` path
        Uploaded CSV / TXT / NPY file.

    Returns
    -------
    (text_output, json_output)
        The human-readable report and the classify_eeg result dict, or an
        error message and None.
    """
    if file is None:
        return "⚠️ Please upload a file.", None

    try:
        result = classify_eeg(_load_signal(file))
    except Exception as e:
        # UI boundary: show the message to the user, keep the traceback in logs.
        logger.exception("Failed to classify uploaded file")
        return f"❌ Error: {str(e)}", None

    if "error" in result:
        return f"❌ {result['error']}", None
    return _format_report(result), result  # Gradio serializes the dict to JSON


# ────────────────────────────────────────────────────────────────
# Gradio Interface
# ────────────────────────────────────────────────────────────────

with gr.Blocks(
    title="Sleep Stage Classifier",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="slate",
    ),
) as demo:
    gr.Markdown(
        """
        # 😴 Sleep Stage Classification

        Upload a **CSV**, **TXT**, or **NPY** file containing raw EEG signal data.
        The model assumes a **100 Hz sampling rate** and classifies the signal into 30-second epochs.

        | Stage | Description |
        |-------|-------------|
        | **Wake** | Awake, eyes open/closed |
        | **N1** | Light sleep, transition |
        | **N2** | Deeper sleep, spindles + K-complexes |
        | **N3** | Slow-wave sleep (deep) |
        | **N4** | Very deep slow-wave sleep |
        | **REM** | Rapid eye movement (dreaming) |
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload EEG file",
                file_types=[".csv", ".txt", ".npy"],
            )
            btn = gr.Button("🔍 Classify", variant="primary", size="lg")
            gr.Markdown("💡 **Tip:** Upload a single-column CSV with EEG amplitude values (100 Hz).")

        with gr.Column(scale=2):
            text_output = gr.Textbox(
                label="Results",
                lines=20,
                interactive=False,
            )
            json_output = gr.JSON(
                label="Raw JSON (for API integration)",
            )

    # api_name="predict" publishes the "/predict" endpoint used by the API
    # examples below; otherwise Gradio derives the endpoint name from the
    # function ("/process_file").
    btn.click(
        fn=process_file,
        inputs=[file_input],
        outputs=[text_output, json_output],
        api_name="predict",
    )

    gr.Markdown(
        """
        ---
        ### 🔌 API Access

        You can call this Space programmatically from any frontend
        (replace `YOUR_USERNAME` with the Space owner's account):

        ```bash
        pip install gradio_client
        ```

        ```python
        from gradio_client import Client, handle_file

        client = Client("YOUR_USERNAME/sleep-stage-classifier")
        result = client.predict(file=handle_file("path/to/eeg.csv"), api_name="/predict")
        print(result)
        ```

        Or from JavaScript in your Lovable app:

        ```javascript
        import { Client } from "@gradio/client";

        const client = await Client.connect(
            "https://YOUR_USERNAME-sleep-stage-classifier.hf.space"
        );
        const result = await client.predict("/predict", { file: yourFile });
        ```
        """
    )


# ────────────────────────────────────────────────────────────────
# Launch
# ────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    demo.launch()