Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Ultra-fast Numba JIT-compiled implementation of multivariate Gaussian overlap calculation. | |
| This eliminates all Python overhead and runs at near-C speed. | |
| """ | |
| import numpy as np | |
| try: | |
| import numba | |
| from numba import jit, prange | |
| NUMBA_AVAILABLE = True | |
| except ImportError: | |
| NUMBA_AVAILABLE = False | |
| print("Warning: Numba not installed. Install with: pip install numba") | |
| if NUMBA_AVAILABLE: | |
| def compute_overlap_batch_numba(means1_batch, vars1_batch, means2_batch, vars2_batch, tol=1e-12): | |
| """ | |
| ULTIMATE PERFORMANCE: "It's just differences, divisions, and exponentials!" | |
| Eliminates ALL overhead and just does the core mathematical operations: | |
| overlap = exp(-0.5 * sum((μ1 - μ2)² / (σ1² + σ2²))) | |
| Parameters: | |
| means1_batch: (n_subjects, 150, n_features) array of means for task 1 | |
| vars1_batch: (n_subjects, 150, n_features) array of variances for task 1 | |
| means2_batch: (n_subjects, 150, n_features) array of means for task 2 | |
| vars2_batch: (n_subjects, 150, n_features) array of variances for task 2 | |
| tol: Tolerance for variance validity | |
| Returns: | |
| overlap_batch: (n_subjects, 150, 150) array of overlap values | |
| """ | |
| n_subjects, n_phases, n_features = means1_batch.shape | |
| # Pre-allocate output | |
| overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64) | |
| # CRITICAL OPTIMIZATION: Pre-compute ALL validation outside the main loops | |
| # This eliminates millions of redundant NaN checks | |
| valid_phases1 = np.zeros((n_subjects, 150), dtype=numba.boolean) | |
| valid_phases2 = np.zeros((n_subjects, 150), dtype=numba.boolean) | |
| # Pre-compute phase validity for all subjects at once | |
| for s in prange(n_subjects): | |
| for i in range(150): | |
| # Check phase validity once per phase | |
| valid1 = True | |
| valid2 = True | |
| for f in range(n_features): | |
| if np.isnan(means1_batch[s, i, f]) or np.isnan(vars1_batch[s, i, f]): | |
| valid1 = False | |
| if np.isnan(means2_batch[s, i, f]) or np.isnan(vars2_batch[s, i, f]): | |
| valid2 = False | |
| valid_phases1[s, i] = valid1 | |
| valid_phases2[s, i] = valid2 | |
| # MAIN COMPUTATION: Process only valid phase pairs | |
| for s in prange(n_subjects): | |
| for i in range(150): | |
| if not valid_phases1[s, i]: | |
| continue | |
| # Extract data for phase i once (avoid repeated indexing) | |
| means1_i = means1_batch[s, i] | |
| vars1_i = vars1_batch[s, i] | |
| for j in range(150): | |
| if not valid_phases2[s, j]: | |
| continue | |
| # Extract data for phase j once | |
| means2_j = means2_batch[s, j] | |
| vars2_j = vars2_batch[s, j] | |
| # VECTORIZED CORE COMPUTATION - "It's just math!" | |
| # Calculate: sum((μ1 - μ2)² / (σ1² + σ2²)) | |
| # Step 1: Vector operations (no loops!) | |
| diff = means1_i - means2_j # Vector subtraction | |
| var_sum = vars1_i + vars2_j # Vector addition | |
| # Step 2: Check variance validity (vectorized) | |
| valid_variances = True | |
| for f in range(n_features): | |
| if var_sum[f] <= tol: | |
| valid_variances = False | |
| break | |
| if valid_variances: | |
| # Step 3: Quadratic form (vectorized) | |
| quad_terms = diff * diff / var_sum # Element-wise operations | |
| quad_sum = 0.0 | |
| for f in range(n_features): # Fast accumulation | |
| quad_sum += quad_terms[f] | |
| # Step 4: Exponential with underflow protection | |
| half_quad = 0.5 * quad_sum | |
| if half_quad <= 20.0: | |
| overlap_batch[s, i, j] = np.exp(-half_quad) | |
| return overlap_batch | |
| def compute_overlap_batch_numba_ultra_fast(means1_batch, vars1_batch, means2_batch, vars2_batch): | |
| """ | |
| ULTRA-FAST MODE: Skip ALL validation for clean data. | |
| This is the absolute fastest possible implementation - just pure math! | |
| Use ONLY when you're certain the data has no NaN values. | |
| Returns overlap = exp(-0.5 * sum((μ1 - μ2)² / (σ1² + σ2²))) | |
| """ | |
| n_subjects, n_phases, n_features = means1_batch.shape | |
| overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64) | |
| for s in prange(n_subjects): | |
| for i in range(150): | |
| means1_i = means1_batch[s, i] | |
| vars1_i = vars1_batch[s, i] | |
| for j in range(150): | |
| means2_j = means2_batch[s, j] | |
| vars2_j = vars2_batch[s, j] | |
| # Pure mathematical computation - no checks, no validation | |
| diff = means1_i - means2_j | |
| var_sum = vars1_i + vars2_j | |
| quad_terms = diff * diff / var_sum | |
| quad_sum = 0.0 | |
| for f in range(n_features): | |
| quad_sum += quad_terms[f] | |
| overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum) | |
| return overlap_batch | |
| def compute_overlap_batch_numba_vectorized(means1_batch, vars1_batch, means2_batch, vars2_batch): | |
| """ | |
| VECTORIZED MODE: Enhanced Numba with better vectorization. | |
| Processes entire rows at once to minimize inner loops and maximize cache efficiency. | |
| This is the enhanced version that "throws more in" while staying on CPU. | |
| """ | |
| n_subjects, n_phases, n_features = means1_batch.shape | |
| overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64) | |
| for s in prange(n_subjects): | |
| # Process entire row at once for better vectorization | |
| for i in range(150): | |
| means1_i = means1_batch[s, i] # Shape: (n_features,) | |
| vars1_i = vars1_batch[s, i] | |
| # OPTIMIZATION: Vectorize the inner j loop by processing all j at once | |
| # Create arrays for all phase_j comparisons | |
| for j in range(150): | |
| means2_j = means2_batch[s, j] | |
| vars2_j = vars2_batch[s, j] | |
| # Vectorized operations over features | |
| diff = means1_i - means2_j | |
| var_sum = vars1_i + vars2_j | |
| quad_terms = diff * diff / var_sum | |
| # Fast sum over features | |
| quad_sum = 0.0 | |
| for f in range(n_features): | |
| quad_sum += quad_terms[f] | |
| overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum) | |
| return overlap_batch | |
| def compute_overlap_batch_numba_row_vectorized(means1_batch, vars1_batch, means2_batch, vars2_batch): | |
| """ | |
| ROW-VECTORIZED MODE: Process entire rows of phase pairs at once. | |
| This minimizes the innermost loops by computing all j phases for each i. | |
| Better cache utilization and more vectorization opportunities. | |
| """ | |
| n_subjects, n_phases, n_features = means1_batch.shape | |
| overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64) | |
| for s in prange(n_subjects): | |
| for i in range(150): | |
| means1_i = means1_batch[s, i] # Current phase means (n_features,) | |
| vars1_i = vars1_batch[s, i] # Current phase variances | |
| # Process all j phases for this i in one go | |
| means2_all = means2_batch[s] # All phase means (150, n_features) | |
| vars2_all = vars2_batch[s] # All phase variances | |
| # Compute differences and sums for all j at once | |
| for j in range(150): | |
| # Fast vectorized computation over features | |
| quad_sum = 0.0 | |
| for f in range(n_features): | |
| diff_f = means1_i[f] - means2_all[j, f] | |
| var_sum_f = vars1_i[f] + vars2_all[j, f] | |
| quad_sum += diff_f * diff_f / var_sum_f | |
| overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum) | |
| return overlap_batch | |
| def apply_biomechanical_filter_numba(overlap_batch, means1_batch, vars1_batch, | |
| means2_batch, vars2_batch, tol=1e-12): | |
| """ | |
| Apply biomechanical filtering in-place using Numba. | |
| This modifies the overlap_batch array directly for maximum efficiency. | |
| """ | |
| n_subjects = overlap_batch.shape[0] | |
| negligible_threshold = 0.1 | |
| ampable_threshold = 0.2 | |
| ci_factor = 1.96 | |
| for s in range(n_subjects): | |
| # Only process first feature (torque) for biomechanical filtering | |
| for i in range(150): | |
| mean1 = means1_batch[s, i, 0] | |
| var1 = vars1_batch[s, i, 0] | |
| if np.isnan(mean1) or np.isnan(var1): | |
| continue | |
| std1 = np.sqrt(var1) | |
| ci_lo1 = mean1 - ci_factor * std1 | |
| ci_hi1 = mean1 + ci_factor * std1 | |
| negligible1 = (ci_lo1 >= -negligible_threshold) and (ci_hi1 <= negligible_threshold) | |
| ampable1 = np.abs(mean1) > ampable_threshold | |
| for j in range(150): | |
| mean2 = means2_batch[s, j, 0] | |
| var2 = vars2_batch[s, j, 0] | |
| if np.isnan(mean2) or np.isnan(var2): | |
| continue | |
| std2 = np.sqrt(var2) | |
| ci_lo2 = mean2 - ci_factor * std2 | |
| ci_hi2 = mean2 + ci_factor * std2 | |
| negligible2 = (ci_lo2 >= -negligible_threshold) and (ci_hi2 <= negligible_threshold) | |
| ampable2 = np.abs(mean2) > ampable_threshold | |
| # Three-level filtering | |
| if negligible1 and negligible2: | |
| # Both negligible - set to 1 | |
| overlap_batch[s, i, j] = 1.0 | |
| elif (negligible1 and ampable2) or (negligible2 and ampable1): | |
| # Amplitude conflict - keep original | |
| pass | |
| else: | |
| # Sign reversal case - apply probability-based filtering | |
| std1_safe = max(std1, tol) | |
| std2_safe = max(std2, tol) | |
| # Normal CDF approximation (simplified for Numba) | |
| # Using a simple approximation since scipy.stats.norm is not available in nopython mode | |
| z1 = mean1 / std1_safe | |
| z2 = mean2 / std2_safe | |
| # Simple normal CDF approximation | |
| # This is less accurate but much faster and Numba-compatible | |
| def norm_cdf_approx(x): | |
| # Approximation of normal CDF | |
| t = 1.0 / (1.0 + 0.2316419 * np.abs(x)) | |
| d = 0.3989423 * np.exp(-x * x / 2.0) | |
| prob = d * t * (0.3193815 + t * (-0.3565638 + t * (1.781478 + t * (-1.821256 + t * 1.330274)))) | |
| if x > 0: | |
| return 1.0 - prob | |
| else: | |
| return prob | |
| Ppos1 = norm_cdf_approx(z1) | |
| Ppos2 = norm_cdf_approx(z2) | |
| # Sign-mismatch probability | |
| Pdiff_sign = Ppos1 * (1.0 - Ppos2) + (1.0 - Ppos1) * Ppos2 | |
| # Mean-difference penalty | |
| mean_diff = np.abs(mean1 - mean2) | |
| s_thresh = 0.2 | |
| e_thresh = 0.5 | |
| if mean_diff <= s_thresh: | |
| penalty = 0.0 | |
| elif mean_diff >= e_thresh: | |
| penalty = 1.0 | |
| else: | |
| # Linear ramp (simplified from sigmoid) | |
| penalty = (mean_diff - s_thresh) / (e_thresh - s_thresh) | |
| # Apply combined penalty | |
| Pdiff = max(Pdiff_sign, penalty) | |
| output_diff = 1.0 - overlap_batch[s, i, j] | |
| overlap_batch[s, i, j] = 1.0 - output_diff * Pdiff | |
| return overlap_batch | |
def compute_overlap_batch_fallback(means1_batch, vars1_batch, means2_batch, vars2_batch, tol=1e-12):
    """
    NumPy fallback used when Numba is not available.

    Vectorizes the phase-pair computation per subject with broadcasting:
        overlap[s, i, j] = exp(-0.5 * sum_f((mu1 - mu2)^2 / (var1 + var2)))
    Pairs involving a NaN phase, a per-feature variance sum <= ``tol``, or
    an exponent whose overlap would be negligible (half_quad > 20) are
    left at 0.

    Returns an (n_subjects, n_phases1, n_phases2) float64 array.
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)
    for s in range(n_subjects):
        # A phase is valid iff none of its means/variances are NaN.
        valid1 = ~np.any(np.isnan(means1_batch[s]) | np.isnan(vars1_batch[s]), axis=1)
        valid2 = ~np.any(np.isnan(means2_batch[s]) | np.isnan(vars2_batch[s]), axis=1)
        # Broadcast all phase pairs at once: (n_phases1, n_phases2, n_features).
        diff = means1_batch[s][:, None, :] - means2_batch[s][None, :, :]
        var_sum = vars1_batch[s][:, None, :] + vars2_batch[s][None, :, :]
        var_ok = np.all(var_sum > tol, axis=2)
        # Guard the division; cells with degenerate variances are masked out.
        safe_var = np.where(var_sum > tol, var_sum, 1.0)
        half_quad = 0.5 * np.sum(diff * diff / safe_var, axis=2)
        valid = valid1[:, None] & valid2[None, :] & var_ok & (half_quad <= 20.0)
        overlap_batch[s][valid] = np.exp(-half_quad[valid])
    return overlap_batch
| # Main interface function | |
def compute_overlap_batch(means1_batch, vars1_batch, means2_batch, vars2_batch,
                          tol=1e-12, biomechanical_filter=False, ultra_fast=True,
                          vectorized_mode='auto'):
    """
    Main interface for computing batch overlap with multiple vectorization modes.

    Parameters
    ----------
    means1_batch, vars1_batch, means2_batch, vars2_batch :
        (n_subjects, n_phases, n_features) arrays.
    tol : float
        Variance-validity tolerance (validated and fallback paths only).
    biomechanical_filter : bool
        Apply torque-based filtering. Only applied on the Numba path; the
        NumPy fallback currently skips it -- TODO confirm intentional.
    ultra_fast : bool
        Kept for backward compatibility; mode selection is actually driven
        by ``vectorized_mode``.
    vectorized_mode : str
        'auto', 'ultra_fast', 'vectorized', or 'row_vectorized'.

    Returns
    -------
    (n_subjects, n_phases, n_phases) float64 array, clipped to [0, 1].
    """
    if NUMBA_AVAILABLE:
        # Select the vectorization strategy.
        if vectorized_mode == 'auto':
            # Auto-select based on data size.
            n_subjects, _, n_features = means1_batch.shape
            if n_features >= 10 or n_subjects >= 15:
                mode = 'row_vectorized'  # best for larger feature sets
            elif n_features >= 4:
                mode = 'vectorized'      # good for medium feature sets
            else:
                mode = 'ultra_fast'      # simple and fast for small feature sets
        else:
            mode = vectorized_mode
        try:
            if mode == 'row_vectorized':
                overlap_batch = compute_overlap_batch_numba_row_vectorized(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
            elif mode == 'vectorized':
                overlap_batch = compute_overlap_batch_numba_vectorized(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
            else:  # ultra_fast
                overlap_batch = compute_overlap_batch_numba_ultra_fast(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
        except Exception:
            # Fall back to the validated kernel if an optimized (unvalidated)
            # kernel fails. Narrowed from a bare ``except:`` so that
            # KeyboardInterrupt/SystemExit still propagate.
            overlap_batch = compute_overlap_batch_numba(means1_batch, vars1_batch,
                                                        means2_batch, vars2_batch, tol)
        if biomechanical_filter:
            overlap_batch = apply_biomechanical_filter_numba(overlap_batch, means1_batch, vars1_batch,
                                                             means2_batch, vars2_batch, tol)
    else:
        overlap_batch = compute_overlap_batch_fallback(means1_batch, vars1_batch,
                                                       means2_batch, vars2_batch, tol)
    # Final clipping to the valid overlap range.
    np.clip(overlap_batch, 0.0, 1.0, out=overlap_batch)
    return overlap_batch
| if __name__ == "__main__": | |
| # Simple test to verify it works | |
| print("Testing Numba overlap calculation...") | |
| # Create test data | |
| n_subjects = 10 | |
| n_features = 20 # e.g., 10 time windows × 2 sensors | |
| means1 = np.random.randn(n_subjects, 150, n_features) | |
| vars1 = np.abs(np.random.randn(n_subjects, 150, n_features)) + 0.1 | |
| means2 = np.random.randn(n_subjects, 150, n_features) | |
| vars2 = np.abs(np.random.randn(n_subjects, 150, n_features)) + 0.1 | |
| # Time the calculation | |
| import time | |
| print(f"Numba available: {NUMBA_AVAILABLE}") | |
| print(f"Computing overlap for {n_subjects} subjects, {n_features} features...") | |
| start = time.time() | |
| result = compute_overlap_batch(means1, vars1, means2, vars2) | |
| end = time.time() | |
| print(f"Result shape: {result.shape}") | |
| print(f"Execution time: {end - start:.3f} seconds") | |
| print(f"Non-zero elements: {np.count_nonzero(result)}") | |
| print(f"Max value: {np.max(result):.4f}") | |
| print(f"Min value: {np.min(result):.4f}") | |
| if NUMBA_AVAILABLE: | |
| print("\n✅ Numba JIT compilation successful!") | |
| print("The first run compiles the function, subsequent runs will be much faster.") | |
| # Run again to show compiled performance | |
| start = time.time() | |
| result = compute_overlap_batch(means1, vars1, means2, vars2) | |
| end = time.time() | |
| print(f"Compiled execution time: {end - start:.3f} seconds") | |
| else: | |
| print("\n⚠️ Numba not available, using fallback implementation.") | |
| print("Install Numba for 10-100x speedup: pip install numba") |