Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| EEG脑电信号分类系统 - Hugging Face部署版本 | |
| 基于深度学习的EEG信号自动分析,用于神经系统疾病检测 | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import warnings | |
| from typing import Tuple, Dict, Any | |
| import time | |
# Silence Python warnings globally and lower TensorFlow's native log
# verbosity (level '2') so the Space logs stay readable.
warnings.simplefilter('ignore')
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
# Import TensorFlow if available. If both attempts fail, bind ``tf`` to a small
# stub whose ``keras.models.load_model`` always raises, so model loading fails
# through the normal exception path and the app can run in demo mode.
try:
    import tensorflow as tf
except ImportError as e:
    print(f"❌ TensorFlow import failed: {e}")
    print("🔄 Trying alternative import methods...")
    try:
        import sys

        # Retry after appending a hard-coded Python 3.9 site-packages directory.
        # NOTE(review): this path is environment-specific; confirm it is still needed.
        sys.path.append('/usr/local/lib/python3.9/site-packages')
        import tensorflow as tf
    except ImportError:
        print("❌ All TensorFlow import methods failed")

        class MockTensorFlow:
            """Minimal stand-in exposing only ``keras.models.load_model``."""

            class keras:
                class models:
                    @staticmethod
                    def load_model(path):
                        raise Exception("TensorFlow not available - running in demo mode")

        tf = MockTensorFlow()
    else:
        print("✅ TensorFlow imported with alternative method")
else:
    print("✅ TensorFlow imported successfully")
class EEGClassifier:
    """EEG signal classifier used to screen for dementia-related abnormalities.

    Wraps an optional Keras model loaded from ``eeg_class.h5``. When no model
    is loaded, :meth:`predict` falls back to a heuristic "demo mode" so the
    UI keeps working.
    """

    # Fixed input geometry fed to the model: (batch=1, channels, time_points).
    N_CHANNELS = 19
    N_TIMEPOINTS = 1000

    def __init__(self):
        self.model = None
        # Labels for output index 0 ("healthy control") and 1 ("abnormal state").
        self.class_names = ['健康对照', '异常状态']
        self.is_loaded = False

    def load_model(self):
        """Load the pretrained classifier from ``eeg_class.h5``.

        Returns:
            True if the model was loaded. False if the file is missing, the
            ``tf`` module lacks ``keras.models``, or loading raised. In that
            case the classifier stays in demo mode.
        """
        try:
            model_path = "eeg_class.h5"
            if not os.path.exists(model_path):
                print(f"❌ Model file not found: {model_path}")
                return False
            print("🔄 Loading EEG classification model...")
            # Guard against a TensorFlow module without a usable keras API.
            if not hasattr(tf.keras, 'models'):
                print("❌ TensorFlow not properly loaded")
                return False
            self.model = tf.keras.models.load_model(model_path)
            self.is_loaded = True
            print("✅ Model loaded successfully!")
            print(f"📊 Input shape: {self.model.input_shape}")
            print(f"📊 Output shape: {self.model.output_shape}")
            return True
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("🔄 Running in demo mode...")
            return False

    def preprocess_eeg_data(self, eeg_data: np.ndarray) -> np.ndarray:
        """Coerce raw EEG into the model's ``(1, 19, 1000)`` input tensor.

        Steps:
            1. Convert to a float64 array.
            2. Replace NaN/inf samples with 0.
            3. Truncate or zero-pad to 19 channels and 1000 time points.
            4. Z-score with the global mean/std.

        The caller's array is never modified.

        Args:
            eeg_data: Raw EEG as a 2-D array-like shaped (channels, time_points).

        Returns:
            A float64 array of shape (1, 19, 1000). Returns None if the input
            cannot be processed, e.g. wrong rank or non-numeric data; the
            reason is printed.
        """
        n_ch, n_t = self.N_CHANNELS, self.N_TIMEPOINTS
        try:
            eeg_data = np.asarray(eeg_data, dtype=np.float64)
            if eeg_data.ndim != 2:
                raise ValueError(f"Expected 2D data (channels, time_points), got shape: {eeg_data.shape}")

            # A single NaN makes mean/std NaN. ``std_val > 0`` is then False,
            # normalization is skipped, and NaNs reach the classifier, which
            # yields a label with a NaN confidence. Treat non-finite samples
            # as missing (0), the same value used for padding below.
            finite = np.isfinite(eeg_data)
            if not finite.all():
                n_bad = int(eeg_data.size - np.count_nonzero(finite))
                eeg_data = np.where(finite, eeg_data, 0.0)
                print(f"⚠️ Replaced {n_bad} non-finite values with 0")

            channels, time_points = eeg_data.shape

            # Channels: keep the first 19, or append zero rows.
            if channels > n_ch:
                eeg_data = eeg_data[:n_ch, :]
                print(f"📊 Channels reduced from {channels} to {n_ch}")
            elif channels < n_ch:
                padding = np.zeros((n_ch - channels, time_points))
                eeg_data = np.vstack([eeg_data, padding])
                print(f"📊 Channels padded from {channels} to {n_ch}")

            # Time axis: keep the first 1000 samples, or append zero columns.
            if time_points > n_t:
                eeg_data = eeg_data[:, :n_t]
                print(f"📊 Time points reduced from {time_points} to {n_t}")
            elif time_points < n_t:
                padding = np.zeros((n_ch, n_t - time_points))
                eeg_data = np.hstack([eeg_data, padding])
                print(f"📊 Time points padded from {time_points} to {n_t}")

            # Global z-score; constant input (std == 0) is left as-is.
            mean_val = np.mean(eeg_data)
            std_val = np.std(eeg_data)
            if std_val > 0:
                eeg_data = (eeg_data - mean_val) / std_val

            processed_data = eeg_data.reshape(1, n_ch, n_t)
            print("✅ Data preprocessing completed")
            print(f" Final shape: {processed_data.shape}")
            print(f" Data range: [{processed_data.min():.3f}, {processed_data.max():.3f}]")
            return processed_data
        except Exception as e:
            print(f"❌ Error in preprocessing: {e}")
            return None

    def predict(self, eeg_data: np.ndarray) -> Dict[str, Any]:
        """Classify one EEG recording.

        Args:
            eeg_data: Raw EEG, 2-D (channels, time_points); see
                :meth:`preprocess_eeg_data`.

        Returns:
            On success, a dict with these keys:
                ``predicted_class``: the label.
                ``confidence``: the score of that label.
                ``probabilities``: a label -> score mapping.
                ``raw_predictions``: a list of scores.
                ``demo_mode``: a bool.
            On failure, a dict with a single ``error`` key.
        """
        try:
            processed_data = self.preprocess_eeg_data(eeg_data)
            if processed_data is None:
                return {"error": "Data preprocessing failed"}
            if self.is_loaded and self.model is not None:
                return self._predict_with_model(processed_data)
            return self._predict_demo(processed_data)
        except Exception as e:
            return {"error": f"Prediction failed: {str(e)}"}

    def _predict_with_model(self, processed_data: np.ndarray) -> Dict[str, Any]:
        """Run the loaded Keras model and package its first output row."""
        predictions = self.model.predict(processed_data, verbose=0)
        scores = predictions[0]
        predicted_class = np.argmax(scores)
        return {
            "predicted_class": self.class_names[predicted_class],
            "confidence": float(scores[predicted_class]),
            "probabilities": {
                self.class_names[0]: float(scores[0]),
                self.class_names[1]: float(scores[1]),
            },
            "raw_predictions": scores.tolist(),
            "demo_mode": False,
        }

    def _predict_demo(self, processed_data: np.ndarray) -> Dict[str, Any]:
        """Heuristic stand-in used when no model is loaded; not a diagnosis.

        Draws a little noise from NumPy's global RNG, so repeated calls can
        differ unless that RNG was seeded beforehand.
        """
        print("🎭 Running in demo mode...")
        data_mean = np.mean(processed_data)
        data_std = np.std(processed_data)

        # After the global z-score in preprocessing, any non-constant input has
        # mean ~0 and std ~1, so this first branch is the usual outcome.
        if abs(data_mean) < 0.1 and data_std > 0.5:
            prob_normal = 0.75 + np.random.uniform(-0.1, 0.1)
            prob_abnormal = 1.0 - prob_normal
            predicted_class = 0
        else:
            prob_abnormal = 0.65 + np.random.uniform(-0.1, 0.1)
            prob_normal = 1.0 - prob_abnormal
            predicted_class = 1

        # Clamp into [0, 1] defensively.
        prob_normal = max(0.0, min(1.0, prob_normal))
        prob_abnormal = max(0.0, min(1.0, prob_abnormal))

        return {
            "predicted_class": self.class_names[predicted_class],
            "confidence": max(prob_normal, prob_abnormal),
            "probabilities": {
                self.class_names[0]: prob_normal,
                self.class_names[1]: prob_abnormal,
            },
            "raw_predictions": [prob_normal, prob_abnormal],
            "demo_mode": True,
        }
# Module-level classifier instance shared by the Gradio callbacks in this file.
classifier: EEGClassifier = EEGClassifier()
def load_eeg_file(file_path: str) -> Tuple[np.ndarray, str]:
    """Load an EEG recording, dispatching on the file extension.

    Supported extensions are ``.set`` (EEGLAB, read with MNE), ``.csv``,
    ``.txt`` and ``.mat``. Anything else is parsed as comma-separated numbers.

    Args:
        file_path: Path to the file on disk.

    Returns:
        ``(eeg_data, message)``: a 2-D array and a user-facing status message.
        The .csv/.txt/.mat loaders orient the array channels-first. On failure,
        ``eeg_data`` is None and ``message`` describes the error.
    """
    loaders = {
        '.set': _load_set_file,
        '.csv': _load_csv_file,
        '.txt': _load_txt_file,
        '.mat': _load_mat_file,
    }
    try:
        file_ext = os.path.splitext(file_path)[1].lower()
        return loaders.get(file_ext, _load_generic_file)(file_path)
    except Exception as e:
        return None, f"❌ Error loading file: {str(e)}"


def _to_channels_first(data: np.ndarray) -> np.ndarray:
    """Orient *data* as (channels, time_points).

    1-D or scalar input becomes a single channel. For 2-D input this is a
    heuristic: a matrix with more rows than columns is assumed to be stored
    as (time_points, channels) and is transposed.
    """
    data = np.asarray(data)
    if data.ndim < 2:
        return data.reshape(1, -1)
    return data.T if data.shape[0] > data.shape[1] else data


def _load_set_file(file_path: str) -> Tuple[np.ndarray, str]:
    """Load an EEGLAB .set file via MNE; simulate data only if MNE is missing."""
    try:
        import mne
    except ImportError:
        print("⚠️ MNE not available, trying alternative method...")
        return load_set_file_simple(file_path)

    try:
        print("🔄 Loading .set file with MNE...")
        raw = mne.io.read_raw_eeglab(file_path, preload=True, verbose=False)
    except Exception as e:
        print(f"❌ MNE loading failed: {e}")
        # MNE *is* installed, so the user's file genuinely failed to parse.
        # Do not substitute simulated data: classifying it would be reported
        # as a result for the uploaded file.
        return None, f"❌ Error loading .set file: {str(e)}"

    eeg_data = raw.get_data()  # (channels, time_points)
    ch_names = raw.ch_names
    sfreq = raw.info['sfreq']
    n_channels, n_timepoints = eeg_data.shape
    message = (
        "✅ Successfully loaded .set file\n"
        f"📊 Shape: {eeg_data.shape}\n"
        f"📊 Channels: {n_channels} ({ch_names[:5]}{'...' if len(ch_names) > 5 else ''})\n"
        f"📊 Sampling rate: {sfreq} Hz\n"
        f"📊 Duration: {n_timepoints/sfreq:.2f} seconds"
    )
    return eeg_data, message


def _csv_has_header(file_path: str) -> bool:
    """Return True if the first CSV row has a non-numeric cell (e.g. channel names)."""
    import csv

    with open(file_path, newline='', encoding='utf-8-sig', errors='replace') as fh:
        first_row = next(csv.reader(fh), [])
    for cell in first_row:
        cell = cell.strip()
        if not cell:
            continue  # tolerate empty cells such as a trailing comma
        try:
            float(cell)
        except ValueError:
            return True
    return False


def _load_csv_file(file_path: str) -> Tuple[np.ndarray, str]:
    """Load a numeric CSV matrix, skipping a header row, oriented channels-first.

    Accepts both the documented layout (one row per channel) and the
    one-column-per-channel layout; see ``_to_channels_first``.
    """
    # Every row is data unless the first one is clearly a header. A plain
    # ``read_csv`` would silently consume the first data row as column names.
    header_rows = 1 if _csv_has_header(file_path) else 0
    try:
        import pandas as pd
    except ImportError:
        eeg_data = _to_channels_first(
            np.loadtxt(file_path, delimiter=',', skiprows=header_rows)
        )
        return eeg_data, f"✅ Successfully loaded .csv file (using numpy)\n📊 Shape: {eeg_data.shape}"

    df = pd.read_csv(file_path, header=None, skiprows=header_rows)
    eeg_data = _to_channels_first(df.to_numpy(dtype=float))
    return eeg_data, f"✅ Successfully loaded .csv file\n📊 Shape: {eeg_data.shape}"


def _load_txt_file(file_path: str) -> Tuple[np.ndarray, str]:
    """Load whitespace-delimited numbers; 1-D data is treated as one channel."""
    eeg_data = _to_channels_first(np.loadtxt(file_path))
    return eeg_data, f"✅ Successfully loaded .txt file\n📊 Shape: {eeg_data.shape}"


def _load_mat_file(file_path: str) -> Tuple[np.ndarray, str]:
    """Load a MATLAB .mat file, picking the EEG matrix by common variable names."""
    try:
        from scipy.io import loadmat
    except ImportError:
        return None, "❌ scipy not available for .mat files"

    try:
        mat_data = loadmat(file_path)
        possible_keys = ['data', 'eeg_data', 'EEG', 'signal', 'X', 'eeg']
        eeg_data = next(
            (mat_data[k] for k in possible_keys if isinstance(mat_data.get(k), np.ndarray)),
            None,
        )
        if eeg_data is None:
            # Fall back to the first non-metadata variable ('__header__' etc. are skipped).
            data_keys = [k for k in mat_data if not k.startswith('__')]
            if data_keys:
                eeg_data = mat_data[data_keys[0]]
        if eeg_data is None:
            raise ValueError("Could not find EEG data in .mat file")
        eeg_data = _to_channels_first(eeg_data)
        return eeg_data, f"✅ Successfully loaded .mat file\n📊 Shape: {eeg_data.shape}"
    except Exception as e:
        return None, f"❌ Error loading .mat file: {str(e)}"


def _load_generic_file(file_path: str) -> Tuple[np.ndarray, str]:
    """Fallback for unknown extensions: parse as comma-separated numbers."""
    eeg_data = np.loadtxt(file_path, delimiter=',')
    if eeg_data.ndim == 1:
        eeg_data = eeg_data.reshape(1, -1)
    return eeg_data, f"✅ Successfully loaded file\n📊 Shape: {eeg_data.shape}"
def load_set_file_simple(file_path: str) -> Tuple[np.ndarray, str]:
    """Fallback for .set files when MNE is unavailable: return simulated EEG.

    ``file_path`` is not read. The data is a deterministic synthetic
    19 x 1000 signal built from Gaussian noise plus 10 Hz and 20 Hz
    sinusoids. NumPy's *global* RNG is reseeded with 42, which also affects
    any later ``np.random`` calls.

    Returns:
        ``(eeg_data, message)``, or ``(None, error_message)`` on failure.
    """
    try:
        print("🎭 Generating demo EEG data for .set file...")
        np.random.seed(42)
        n_channels, n_timepoints = 19, 1000

        signal = np.random.randn(n_channels, n_timepoints) * 0.1
        t = np.linspace(0, 1, n_timepoints)  # identical for every channel
        signal += 0.05 * np.sin(2 * np.pi * 10 * t)  # "alpha" component, broadcast over rows
        signal += 0.03 * np.sin(2 * np.pi * 20 * t)  # "beta" component
        # One (19, 1000) draw consumes the RNG in the same order as 19 row-wise draws.
        signal += np.random.randn(n_channels, n_timepoints) * 0.02

        summary = (
            "✅ Loaded .set file (demo mode)\n"
            f"📊 Shape: {signal.shape}\n"
            "📊 Channels: 19 (standard 10-20 system)\n"
            "📊 Sampling rate: 1000 Hz (simulated)\n"
            "📊 Duration: 1.00 seconds\n"
            "⚠️ Note: Using simulated data as MNE library not available"
        )
        return signal, summary
    except Exception as err:
        return None, f"❌ Error in simple .set loading: {str(err)}"
def predict_eeg(file) -> str:
    """Gradio callback: load an uploaded EEG file, classify it, format a report.

    Args:
        file: The value from the ``gr.File`` input. This is either a path
            string (the UI declares ``type="filepath"``) or a file-like object
            exposing ``.name``. None when nothing was uploaded.

    Returns:
        A Markdown-formatted report, or an error message prefixed with "❌".
    """
    if file is None:
        return "❌ Please upload an EEG file"

    try:
        # With type="filepath" Gradio hands over a plain path string, which has
        # no ``.name`` attribute. Only fall back to ``.name`` for file wrappers.
        file_path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name

        eeg_data, load_message = load_eeg_file(file_path)
        if eeg_data is None:
            return load_message

        result = classifier.predict(eeg_data)
        if "error" in result:
            return f"❌ {result['error']}"

        demo_notice = ""
        if result.get('demo_mode', False):
            demo_notice = """
🎭 **演示模式提示:**
当前系统运行在演示模式下,预测结果基于简单的统计规则生成。
在实际部署环境中,系统将使用训练好的深度学习模型。
"""

        output = f"""
🧠 **EEG脑电信号分析结果**
📁 **文件信息:**
{load_message}
{demo_notice}🎯 **预测结果:**
• **诊断结果**: {result['predicted_class']}
• **置信度**: {result['confidence']:.4f} ({result['confidence']*100:.2f}%)
📊 **详细概率分布:**
• 健康对照: {result['probabilities']['健康对照']:.4f} ({result['probabilities']['健康对照']*100:.2f}%)
• 异常状态: {result['probabilities']['异常状态']:.4f} ({result['probabilities']['异常状态']*100:.2f}%)
⚠️ **重要提示:**
此结果仅供参考,不能替代专业医学诊断。
如有疑虑,请咨询专业医生。
"""
        return output.strip()
    except Exception as e:
        return f"❌ Error during prediction: {str(e)}"
def generate_demo_data() -> str:
    """Classify a synthetic 19 x 1000 EEG-like signal and format the report.

    Reseeds NumPy's global RNG with 42 so the synthetic signal is identical on
    every call. Returns a Markdown report, or an error string prefixed "❌".
    """
    try:
        np.random.seed(42)  # fixed seed -> consistent results
        samples = np.random.randn(19, 1000) * 0.1
        t = np.linspace(0, 1, 1000)
        samples += 0.05 * np.sin(2 * np.pi * 10 * t)  # 10 Hz tone on every channel
        # One (19, 1000) noise draw consumes the RNG exactly like 19 row draws.
        samples += np.random.randn(19, 1000) * 0.02

        result = classifier.predict(samples)
        if "error" in result:
            return f"❌ {result['error']}"

        probs = result['probabilities']
        conf = result['confidence']
        demo_notice = (
            "\n🎭 **演示模式提示:**\n当前系统运行在演示模式下,预测结果基于简单的统计规则生成。\n"
            if result.get('demo_mode', False)
            else ""
        )

        report = f"""
🧠 **演示数据分析结果**
📊 **数据信息:**
• 数据形状: (19, 1000)
• 数据类型: 模拟EEG信号
• 包含特征: Alpha波 (10Hz) + 随机噪声
{demo_notice}🎯 **预测结果:**
• **诊断结果**: {result['predicted_class']}
• **置信度**: {conf:.4f} ({conf*100:.2f}%)
📊 **详细概率分布:**
• 健康对照: {probs['健康对照']:.4f} ({probs['健康对照']*100:.2f}%)
• 异常状态: {probs['异常状态']:.4f} ({probs['异常状态']*100:.2f}%)
💡 **说明:**
这是使用随机生成的演示数据进行的测试预测。
"""
        return report.strip()
    except Exception as err:
        return f"❌ Error generating demo: {str(err)}"
# Load the trained model once at import time. Failure is not fatal:
# EEGClassifier.predict() falls back to demo mode when no model is loaded.
print("🚀 Initializing EEG Classification System...")
if classifier.load_model():
    print("✅ System ready!")
else:
    # The app keeps serving (demo-mode) predictions, so report exactly that
    # rather than claiming initialization failed.
    print("⚠️ Model not loaded - running in demo mode")
# Static Markdown rendered above and below the interactive widgets.
_INTRO_MARKDOWN = """
# 🧠 EEG脑电信号分类系统
基于深度学习的EEG脑电信号自动分析系统,用于辅助神经系统疾病检测。
## 📋 支持的文件格式:
- **.set** - EEGLAB格式 (推荐)
- **.csv** - CSV格式 (逗号分隔值)
- **.txt** - 文本格式
- **.mat** - MATLAB格式
- **其他** - 支持numpy可读取的数值文件
## 🎯 分类类别:
- **健康对照** - 正常脑电活动
- **异常状态** - 可能的神经系统异常
## 📊 数据要求:
- **输入格式**: 19通道 × 1000时间点
- **数据类型**: 数值型EEG信号
- **文件大小**: 建议小于10MB
"""

_FOOTER_MARKDOWN = """
---
## ⚠️ 免责声明
本系统仅供研究和教育目的使用,**不能替代专业医学诊断**。
如果您有健康方面的担忧,请咨询合格的医疗专业人员。
## 🔬 技术信息
- **模型架构**: 深度卷积神经网络
- **输入格式**: 19通道 × 1000时间点
- **训练数据**: 多中心EEG数据集
- **性能指标**: 准确率 > 85%
- **参数量**: 1,942,210个参数
## 📖 使用说明
1. **上传文件**: 点击"上传EEG文件"按钮选择您的EEG数据文件
2. **开始分析**: 点击"开始分析"按钮进行预测
3. **查看结果**: 在右侧查看详细的分析结果
4. **演示测试**: 点击"演示数据测试"体验系统功能
## 🎯 数据格式示例
**EEGLAB .set格式** (推荐):
- 标准EEGLAB导出的.set文件
- 自动识别通道布局和采样率
- 支持多种EEG设备数据
**CSV格式** (19行×1000列):
```
channel1_t1, channel1_t2, ..., channel1_t1000
channel2_t1, channel2_t2, ..., channel2_t1000
...
channel19_t1, channel19_t2, ..., channel19_t1000
```
**MATLAB .mat格式**:
- 包含名为'data', 'eeg_data', 'EEG'等的数据矩阵
- 自动检测数据键名
"""

# Page layout: upload control and buttons on the left, text report on the right.
with gr.Blocks(title="EEG脑电信号分类系统", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                type="filepath",
                label="📁 上传EEG文件",
                file_types=[".set", ".csv", ".txt", ".mat"],
            )
            predict_btn = gr.Button("🔍 开始分析", variant="primary", size="lg")
            demo_btn = gr.Button("🧪 演示数据测试", variant="secondary")
        with gr.Column():
            output = gr.Textbox(
                show_copy_button=True,
                max_lines=25,
                lines=20,
                label="📊 分析结果",
            )

    # Both buttons render into the same output box.
    predict_btn.click(fn=predict_eeg, inputs=[file_input], outputs=[output])
    demo_btn.click(fn=generate_demo_data, inputs=[], outputs=[output])

    gr.Markdown(_FOOTER_MARKDOWN)
if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    # NOTE(review): share=True asks Gradio for a public share link when run
    # outside a hosted environment — confirm that exposure is intended.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=True)
    demo.launch(**launch_options)