Spaces:
Sleeping
Sleeping
| """ | |
| Utility functions for the EEG Motor Imagery Music Composer | |
| """ | |
| import numpy as np | |
| import logging | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import json | |
| from config import LOG_LEVEL, LOG_FILE, CLASS_NAMES, CLASS_DESCRIPTIONS | |
def setup_logging():
    """Configure logging to write to ``LOG_FILE`` and to the console.

    Uses :func:`logging.basicConfig`, so it configures the *root* logger and
    has no effect if the root logger already has handlers.

    ``LOG_LEVEL`` (from ``config``) may be a level name in any case
    (``"INFO"``, ``"debug"``) or a numeric level. An unrecognised value falls
    back to ``logging.INFO`` and a warning is logged, instead of raising while
    this module is being imported.

    Returns:
        logging.Logger: The logger named after this module.
    """
    # FileHandler opens the file immediately, so its directory must exist.
    LOG_FILE.parent.mkdir(parents=True, exist_ok=True)

    if isinstance(LOG_LEVEL, int):
        level = LOG_LEVEL
    else:
        # A plain getattr(logging, LOG_LEVEL) raised AttributeError for
        # lowercase or misspelled names; also guard against non-level
        # attributes such as logging.BASIC_FORMAT.
        level = getattr(logging, str(LOG_LEVEL).upper(), None)
    level_is_valid = isinstance(level, int)
    if not level_is_valid:
        level = logging.INFO

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit UTF-8 so non-ASCII messages can be written regardless
            # of the platform's default file encoding.
            logging.FileHandler(LOG_FILE, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    logger = logging.getLogger(__name__)
    if not level_is_valid:
        logger.warning("Unknown LOG_LEVEL %r; defaulting to INFO", LOG_LEVEL)
    return logger
def validate_eeg_data(data: np.ndarray) -> bool:
    """Check that *data* is a non-empty 2-D or 3-D NumPy array.

    Args:
        data: EEG data array. Two or three dimensions are accepted; the
            meaning of each axis is not checked here (presumably
            ``(channels, samples)`` or ``(trials, channels, samples)`` --
            confirm against callers).

    Returns:
        bool: True if ``data`` is an ``np.ndarray`` with 2 or 3 dimensions
        and no zero-length axis, False otherwise.
    """
    if not isinstance(data, np.ndarray):
        return False
    if data.ndim not in (2, 3):
        return False
    # Reject arrays with *any* empty axis. The previous per-axis checks let
    # shapes such as (8, 0) or (10, 8, 0) through even though they hold no data.
    return data.size > 0
def format_confidence(confidence: float) -> str:
    """Render a 0-1 confidence score as a percentage with one decimal place.

    Example: ``0.8734`` becomes ``"87.3%"``.
    """
    percent = confidence * 100
    return "{:.1f}%".format(percent)
def format_timestamp(timestamp: float) -> str:
    """Return *timestamp* (seconds since the epoch) as local ``HH:MM:SS``."""
    local_time = time.localtime(timestamp)
    return time.strftime("%H:%M:%S", local_time)
def get_class_emoji(class_name: str) -> str:
    """Get an emoji representation for a motor imagery class.

    Args:
        class_name: Motor imagery class name (e.g. ``"left_hand"``).

    Returns:
        str: A single emoji; unknown class names map to a question mark.
    """
    # The literals here previously arrived mis-encoded (UTF-8 bytes shown as
    # legacy code-page characters), so they are written as escape sequences to
    # survive future encoding round-trips.
    # NOTE(review): the exact original glyphs could not be recovered from the
    # garbled bytes; these are best-effort choices -- confirm against the UI.
    emoji_map = {
        "left_hand": "\U0001F448",   # backhand index pointing left
        "right_hand": "\U0001F449",  # backhand index pointing right
        "neutral": "\U0001F610",     # neutral face
        "left_leg": "\U0001F9B5",    # leg
        "tongue": "\U0001F445",      # tongue
        "right_leg": "\U0001F9B5",   # leg (same glyph as left_leg)
    }
    return emoji_map.get(class_name, "\u2753")  # red question mark
def create_classification_summary(
    predicted_class: str,
    confidence: float,
    probabilities: Dict[str, float],
    timestamp: Optional[float] = None
) -> Dict:
    """Bundle one classification result together with display-ready fields.

    Args:
        predicted_class: Name of the predicted motor imagery class.
        confidence: Confidence score, expected in the range 0-1.
        probabilities: Per-class probabilities; stored as the same object.
        timestamp: Seconds since the epoch; the current time is used when None.

    Returns:
        Dict: The inputs plus ``confidence_percent``, ``formatted_time``,
        ``emoji`` and ``description``. ``description`` falls back to the class
        name when ``CLASS_DESCRIPTIONS`` has no entry for it.
    """
    ts = time.time() if timestamp is None else timestamp

    summary = {
        "predicted_class": predicted_class,
        "confidence": confidence,
        "confidence_percent": format_confidence(confidence),
        "probabilities": probabilities,
        "timestamp": ts,
        "formatted_time": format_timestamp(ts),
        "emoji": get_class_emoji(predicted_class),
        "description": CLASS_DESCRIPTIONS.get(predicted_class, predicted_class),
    }
    return summary
def save_session_data(session_data: Dict, filepath: str) -> bool:
    """Save session data to a JSON file.

    The data is serialised to a string *before* the file is opened, so a
    serialisation failure (e.g. a circular reference) leaves any existing
    file at ``filepath`` untouched instead of truncating it. Values JSON
    cannot represent natively are converted with ``str()`` (``default=str``).

    Args:
        session_data: Dictionary containing session information.
        filepath: Path to save the file.

    Returns:
        bool: True if the file was written; False if serialisation or the
        write failed (the error is logged).
    """
    try:
        payload = json.dumps(session_data, indent=2, default=str)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)
        return True
    except Exception as e:
        # Best-effort API: callers get False rather than an exception.
        logging.error(f"Error saving session data: {e}")
        return False
def load_session_data(filepath: str) -> Optional[Dict]:
    """Load session data from a JSON file.

    Args:
        filepath: Path to the JSON file.

    Returns:
        Dict or None: The decoded JSON document, or None if the file could
        not be read or parsed (the error is logged). The top-level value is
        returned as decoded; it is not checked to be a dict.
    """
    try:
        # Read explicitly as UTF-8 (the JSON interchange encoding) rather than
        # relying on the platform's locale-dependent default.
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        # Best-effort API: callers get None rather than an exception.
        logging.error(f"Error loading session data: {e}")
        return None
def calculate_classification_statistics(history: List[Dict]) -> Dict:
    """Calculate summary statistics from a classification history.

    Args:
        history: List of classification result dicts. A missing
            ``predicted_class`` counts as ``"unknown"``; a missing
            ``confidence`` counts as ``0.0``.

    Returns:
        Dict with keys ``total``, ``class_counts`` (class name -> count),
        ``average_confidence`` and ``most_common_class``. The same keys are
        present for an empty history, with ``most_common_class`` set to None.
    """
    from collections import Counter

    if not history:
        # Same schema as the non-empty case so callers can index
        # "most_common_class" unconditionally.
        return {
            "total": 0,
            "class_counts": {},
            "average_confidence": 0.0,
            "most_common_class": None,
        }

    counts = Counter()
    total_confidence = 0.0
    for item in history:
        counts[item.get("predicted_class", "unknown")] += 1
        total_confidence += item.get("confidence", 0.0)

    return {
        "total": len(history),
        "class_counts": dict(counts),
        "average_confidence": total_confidence / len(history),
        # most_common(1) returns the first-seen class on ties.
        "most_common_class": counts.most_common(1)[0][0],
    }
def create_progress_bar(value: float, max_value: float = 1.0, width: int = 20) -> str:
    """Create a text-based progress bar.

    The filled fraction is ``value / max_value`` clamped to ``[0, 1]``. A
    non-positive ``max_value`` is shown as 0% progress instead of raising
    ``ZeroDivisionError``.

    Args:
        value: Current value.
        max_value: Value that corresponds to a full bar.
        width: Width of the bar in characters.

    Returns:
        str: ``"[<bar>] <pct>%"``, where the bar uses full-block (U+2588)
        characters for the filled part and light-shade (U+2591) characters
        for the rest.
    """
    # Escape sequences keep these glyphs safe from encoding round-trips;
    # previously both characters were mis-encoded as the same letter.
    full_char = "\u2588"
    empty_char = "\u2591"

    ratio = value / max_value if max_value > 0 else 0.0
    # Clamp below as well as above: a negative ratio used to produce a
    # negative fill count and a bar wider than `width`.
    ratio = max(0.0, min(ratio, 1.0))
    filled = int(width * ratio)
    bar = full_char * filled + empty_char * (width - filled)
    return f"[{bar}] {ratio * 100:.1f}%"
def validate_audio_file(file_path: str) -> bool:
    """Report whether *file_path* names an existing audio file.

    A path is accepted when it exists, is a regular file (not a directory),
    and its extension is one of ``.wav``, ``.mp3``, ``.flac`` or ``.ogg``
    (compared case-insensitively). File contents and read permissions are
    not inspected.

    Args:
        file_path: Path to the candidate audio file.

    Returns:
        bool: True when every check passes, False otherwise.
    """
    candidate = Path(file_path)
    supported_suffixes = ('.wav', '.mp3', '.flac', '.ogg')
    return (
        candidate.exists()
        and candidate.is_file()
        and candidate.suffix.lower() in supported_suffixes
    )
def generate_composition_filename(prefix: str = "composition") -> str:
    """Build an export filename of the form ``<prefix>_YYYYmmdd_HHMMSS.wav``.

    The stamp uses local time at one-second resolution, so two calls made
    within the same second return the same name.

    Args:
        prefix: Leading part of the filename.

    Returns:
        str: Filename made of the prefix, a timestamp and a ``.wav`` extension.
    """
    stamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    return "{}_{}.wav".format(prefix, stamp)
# Initialize logger when module is imported.
# NOTE: this is an import-time side effect. setup_logging() creates the log
# file's parent directory and calls logging.basicConfig(), which configures
# the root logger (unless it already has handlers). Simply importing this
# module therefore sets up logging for the whole process.
logger: logging.Logger = setup_logging()