Lower-Limb-Similarity-Analysis / numba_overlap.py
jmontp's picture
Updated to new data and multivariate lib api
43ec583
#!/usr/bin/env python3
"""
Ultra-fast Numba JIT-compiled implementation of multivariate Gaussian overlap calculation.
This eliminates all Python overhead and runs at near-C speed.
"""
import numpy as np
try:
import numba
from numba import jit, prange
NUMBA_AVAILABLE = True
except ImportError:
NUMBA_AVAILABLE = False
print("Warning: Numba not installed. Install with: pip install numba")
if NUMBA_AVAILABLE:
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba(means1_batch, vars1_batch, means2_batch, vars2_batch, tol=1e-12):
    """
    Validated Numba kernel for batched Gaussian overlap.

    For every subject s and phase pair (i, j) computes:

        overlap[s, i, j] = exp(-0.5 * sum_f (mu1 - mu2)^2 / (var1 + var2))

    Pairs containing NaN means/variances, or whose per-feature summed
    variance is not above ``tol``, are left at 0.  Overlaps smaller than
    exp(-20) (~2e-9) are flushed to 0 as underflow protection.

    Parameters:
        means1_batch: (n_subjects, n_phases1, n_features) means for task 1
        vars1_batch:  (n_subjects, n_phases1, n_features) variances for task 1
        means2_batch: (n_subjects, n_phases2, n_features) means for task 2
        vars2_batch:  (n_subjects, n_phases2, n_features) variances for task 2
        tol: tolerance for variance validity

    Returns:
        overlap_batch: (n_subjects, n_phases1, n_phases2) float64 array
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    # Phase counts come from the inputs (previously hard-coded to 150), so
    # tasks sampled on different phase grids are handled correctly.
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)

    # Pre-compute per-phase NaN validity once so the O(P1*P2) main loops
    # do not repeat millions of redundant isnan checks.
    valid_phases1 = np.zeros((n_subjects, n_phases1), dtype=numba.boolean)
    valid_phases2 = np.zeros((n_subjects, n_phases2), dtype=numba.boolean)
    for s in prange(n_subjects):
        for i in range(n_phases1):
            valid = True
            for f in range(n_features):
                if np.isnan(means1_batch[s, i, f]) or np.isnan(vars1_batch[s, i, f]):
                    valid = False
                    break  # one bad feature invalidates the whole phase
            valid_phases1[s, i] = valid
        for j in range(n_phases2):
            valid = True
            for f in range(n_features):
                if np.isnan(means2_batch[s, j, f]) or np.isnan(vars2_batch[s, j, f]):
                    valid = False
                    break
            valid_phases2[s, j] = valid

    # Main computation: only valid phase pairs are evaluated.
    for s in prange(n_subjects):
        for i in range(n_phases1):
            if not valid_phases1[s, i]:
                continue
            # Slice once per i to avoid repeated 3-D indexing.
            means1_i = means1_batch[s, i]
            vars1_i = vars1_batch[s, i]
            for j in range(n_phases2):
                if not valid_phases2[s, j]:
                    continue
                means2_j = means2_batch[s, j]
                vars2_j = vars2_batch[s, j]
                # Accumulate the quadratic form in scalars (no temporary
                # arrays), rejecting the pair as soon as any summed
                # variance falls at or below the tolerance.
                quad_sum = 0.0
                valid_variances = True
                for f in range(n_features):
                    var_sum_f = vars1_i[f] + vars2_j[f]
                    if var_sum_f <= tol:
                        valid_variances = False
                        break
                    diff_f = means1_i[f] - means2_j[f]
                    quad_sum += diff_f * diff_f / var_sum_f
                if valid_variances:
                    half_quad = 0.5 * quad_sum
                    # Underflow guard: exp(-20) is negligibly small.
                    if half_quad <= 20.0:
                        overlap_batch[s, i, j] = np.exp(-half_quad)
    return overlap_batch
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba_ultra_fast(means1_batch, vars1_batch, means2_batch, vars2_batch):
    """
    ULTRA-FAST MODE: skip ALL validation for clean data.

    Pure math, no NaN or variance checks — use ONLY when the data is known
    to contain no NaNs and strictly positive variances.

    Computes overlap = exp(-0.5 * sum_f (mu1 - mu2)^2 / (var1 + var2)).

    Returns:
        (n_subjects, n_phases1, n_phases2) float64 array
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    # Derive both phase counts from the inputs (previously hard-coded 150).
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)
    for s in prange(n_subjects):
        for i in range(n_phases1):
            means1_i = means1_batch[s, i]
            vars1_i = vars1_batch[s, i]
            for j in range(n_phases2):
                means2_j = means2_batch[s, j]
                vars2_j = vars2_batch[s, j]
                # Scalar accumulation avoids per-pair temporary arrays.
                quad_sum = 0.0
                for f in range(n_features):
                    diff_f = means1_i[f] - means2_j[f]
                    var_sum_f = vars1_i[f] + vars2_j[f]
                    quad_sum += diff_f * diff_f / var_sum_f
                overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)
    return overlap_batch
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba_vectorized(means1_batch, vars1_batch, means2_batch, vars2_batch):
    """
    VECTORIZED MODE: Numba kernel using whole-feature-vector operations.

    Same unvalidated math as the ultra-fast kernel, but the per-pair work
    is expressed as NumPy vector expressions over the feature axis so
    Numba can emit SIMD code.  No NaN/variance validation is performed.

    Returns:
        (n_subjects, n_phases1, n_phases2) float64 array of
        exp(-0.5 * sum_f (mu1 - mu2)^2 / (var1 + var2))
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    # Phase counts taken from input shapes (previously hard-coded 150).
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)
    for s in prange(n_subjects):
        for i in range(n_phases1):
            means1_i = means1_batch[s, i]  # shape: (n_features,)
            vars1_i = vars1_batch[s, i]
            for j in range(n_phases2):
                means2_j = means2_batch[s, j]
                vars2_j = vars2_batch[s, j]
                # Vectorized elementwise math over the feature axis.
                diff = means1_i - means2_j
                var_sum = vars1_i + vars2_j
                quad_terms = diff * diff / var_sum
                # Explicit reduction loop: fast under Numba.
                quad_sum = 0.0
                for f in range(n_features):
                    quad_sum += quad_terms[f]
                overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)
    return overlap_batch
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba_row_vectorized(means1_batch, vars1_batch, means2_batch, vars2_batch):
    """
    ROW-VECTORIZED MODE: iterate all j phases against each i phase.

    Slices the full per-subject task-2 arrays once per subject for better
    cache utilization, then accumulates the quadratic form in scalars.
    No NaN/variance validation is performed.

    Returns:
        (n_subjects, n_phases1, n_phases2) float64 array of
        exp(-0.5 * sum_f (mu1 - mu2)^2 / (var1 + var2))
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    # Derive phase counts from the inputs (previously hard-coded 150).
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)
    for s in prange(n_subjects):
        # Per-subject views of all task-2 phases, sliced once.
        means2_all = means2_batch[s]  # (n_phases2, n_features)
        vars2_all = vars2_batch[s]
        for i in range(n_phases1):
            means1_i = means1_batch[s, i]  # current phase means (n_features,)
            vars1_i = vars1_batch[s, i]
            for j in range(n_phases2):
                # Scalar accumulation over features: no temporaries.
                quad_sum = 0.0
                for f in range(n_features):
                    diff_f = means1_i[f] - means2_all[j, f]
                    var_sum_f = vars1_i[f] + vars2_all[j, f]
                    quad_sum += diff_f * diff_f / var_sum_f
                overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)
    return overlap_batch
@jit(nopython=True, cache=True)
def apply_biomechanical_filter_numba(overlap_batch, means1_batch, vars1_batch,
                                     means2_batch, vars2_batch, tol=1e-12):
    """
    Apply biomechanical filtering in-place using Numba.

    Modifies ``overlap_batch`` directly (and also returns it).  Only the
    first feature (index 0, the torque channel) drives the filtering:

    - both phases negligible (95% CI inside +/- 0.1): overlap forced to 1
    - one negligible, the other large-amplitude (|mean| > 0.2): unchanged
    - otherwise: overlap penalized by the sign-mismatch probability and a
      linear mean-difference ramp between 0.2 and 0.5

    Parameters:
        overlap_batch: (n_subjects, n_phases1, n_phases2) overlaps, edited in place
        means1_batch/vars1_batch: (n_subjects, n_phases1, n_features) task-1 stats
        means2_batch/vars2_batch: (n_subjects, n_phases2, n_features) task-2 stats
        tol: lower bound applied to standard deviations before dividing

    Returns:
        overlap_batch (same array, filtered in place)
    """
    # Normal CDF approximation, hoisted out of the inner loops (it was
    # previously re-defined inside the j loop).  scipy.stats.norm is not
    # available in nopython mode; this is the Abramowitz & Stegun
    # polynomial approximation — less accurate but Numba-compatible.
    def _norm_cdf_approx(x):
        t = 1.0 / (1.0 + 0.2316419 * np.abs(x))
        d = 0.3989423 * np.exp(-x * x / 2.0)
        prob = d * t * (0.3193815 + t * (-0.3565638 + t * (1.781478 + t * (-1.821256 + t * 1.330274))))
        if x > 0:
            return 1.0 - prob
        else:
            return prob

    n_subjects = overlap_batch.shape[0]
    # Loop bounds from the overlap array itself (previously hard-coded 150).
    n_phases1 = overlap_batch.shape[1]
    n_phases2 = overlap_batch.shape[2]
    negligible_threshold = 0.1
    ampable_threshold = 0.2
    ci_factor = 1.96  # 95% confidence interval half-width in std units
    for s in range(n_subjects):
        # Only the first feature (torque) is used for biomechanical filtering.
        for i in range(n_phases1):
            mean1 = means1_batch[s, i, 0]
            var1 = vars1_batch[s, i, 0]
            if np.isnan(mean1) or np.isnan(var1):
                continue
            std1 = np.sqrt(var1)
            ci_lo1 = mean1 - ci_factor * std1
            ci_hi1 = mean1 + ci_factor * std1
            negligible1 = (ci_lo1 >= -negligible_threshold) and (ci_hi1 <= negligible_threshold)
            ampable1 = np.abs(mean1) > ampable_threshold
            for j in range(n_phases2):
                mean2 = means2_batch[s, j, 0]
                var2 = vars2_batch[s, j, 0]
                if np.isnan(mean2) or np.isnan(var2):
                    continue
                std2 = np.sqrt(var2)
                ci_lo2 = mean2 - ci_factor * std2
                ci_hi2 = mean2 + ci_factor * std2
                negligible2 = (ci_lo2 >= -negligible_threshold) and (ci_hi2 <= negligible_threshold)
                ampable2 = np.abs(mean2) > ampable_threshold
                # Three-level filtering.
                if negligible1 and negligible2:
                    # Both torques negligible: treat phases as identical.
                    overlap_batch[s, i, j] = 1.0
                elif (negligible1 and ampable2) or (negligible2 and ampable1):
                    # Amplitude conflict: keep the original overlap.
                    pass
                else:
                    # Sign-reversal case: probability-based filtering.
                    std1_safe = max(std1, tol)
                    std2_safe = max(std2, tol)
                    z1 = mean1 / std1_safe
                    z2 = mean2 / std2_safe
                    Ppos1 = _norm_cdf_approx(z1)
                    Ppos2 = _norm_cdf_approx(z2)
                    # Probability the two torques have opposite signs.
                    Pdiff_sign = Ppos1 * (1.0 - Ppos2) + (1.0 - Ppos1) * Ppos2
                    # Mean-difference penalty: linear ramp between thresholds
                    # (simplified from a sigmoid).
                    mean_diff = np.abs(mean1 - mean2)
                    s_thresh = 0.2
                    e_thresh = 0.5
                    if mean_diff <= s_thresh:
                        penalty = 0.0
                    elif mean_diff >= e_thresh:
                        penalty = 1.0
                    else:
                        penalty = (mean_diff - s_thresh) / (e_thresh - s_thresh)
                    # Apply the stronger of the two penalties.
                    Pdiff = max(Pdiff_sign, penalty)
                    output_diff = 1.0 - overlap_batch[s, i, j]
                    overlap_batch[s, i, j] = 1.0 - output_diff * Pdiff
    return overlap_batch
def compute_overlap_batch_fallback(means1_batch, vars1_batch, means2_batch, vars2_batch, tol=1e-12):
    """
    NumPy fallback used when Numba is not available.

    Same contract as the validated Numba kernel: for each subject and phase
    pair, overlap = exp(-0.5 * sum_f (mu1 - mu2)^2 / (var1 + var2)).
    Phases containing NaNs, pairs with any summed variance <= ``tol``, and
    overlaps below exp(-20) are left at 0.

    Parameters:
        means1_batch/vars1_batch: (n_subjects, n_phases1, n_features) task-1 stats
        means2_batch/vars2_batch: (n_subjects, n_phases2, n_features) task-2 stats
        tol: tolerance for variance validity

    Returns:
        (n_subjects, n_phases1, n_phases2) float64 array
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    # Phase counts come from the inputs (previously hard-coded to 150).
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)
    for s in range(n_subjects):
        # Pre-compute per-phase NaN masks for this subject.
        has_nan1 = np.any(np.isnan(means1_batch[s]) | np.isnan(vars1_batch[s]), axis=1)
        has_nan2 = np.any(np.isnan(means2_batch[s]) | np.isnan(vars2_batch[s]), axis=1)
        for i in range(n_phases1):
            if has_nan1[i]:
                continue
            for j in range(n_phases2):
                if has_nan2[j]:
                    continue
                diff = means1_batch[s, i] - means2_batch[s, j]
                sum_var = vars1_batch[s, i] + vars2_batch[s, j]
                # Every summed variance must exceed tolerance to be valid.
                if np.all(sum_var > tol):
                    half_quad = 0.5 * np.sum(diff * diff / sum_var)
                    # Underflow guard: exp(-20) is negligibly small.
                    if half_quad <= 20.0:
                        overlap_batch[s, i, j] = np.exp(-half_quad)
    return overlap_batch
# Main interface function
def compute_overlap_batch(means1_batch, vars1_batch, means2_batch, vars2_batch,
                          tol=1e-12, biomechanical_filter=False, ultra_fast=True,
                          vectorized_mode='auto'):
    """
    Main interface for computing batch overlap with multiple vectorization modes.

    Parameters:
        means1_batch/vars1_batch: (n_subjects, n_phases, n_features) task-1 stats
        means2_batch/vars2_batch: matching arrays for task 2
        tol: variance-validity tolerance (used by the validated Numba kernel
            and the NumPy fallback)
        biomechanical_filter: apply torque-based filtering after the overlap
            computation.  NOTE(review): only honored on the Numba path — it is
            silently skipped when Numba is unavailable; confirm this is intended.
        ultra_fast: kept for backward compatibility; it is not consulted —
            kernel selection is controlled by ``vectorized_mode``
        vectorized_mode: 'auto', 'ultra_fast', 'vectorized', or 'row_vectorized'

    Returns:
        (n_subjects, n_phases, n_phases) overlap array, clipped to [0, 1]
    """
    if NUMBA_AVAILABLE:
        # Select the vectorization strategy.
        if vectorized_mode == 'auto':
            # Heuristic based on data size: larger workloads favor the
            # row-vectorized kernel's cache behavior.
            n_subjects, _, n_features = means1_batch.shape
            if n_features >= 10 or n_subjects >= 15:
                mode = 'row_vectorized'
            elif n_features >= 4:
                mode = 'vectorized'
            else:
                mode = 'ultra_fast'
        else:
            mode = vectorized_mode
        try:
            if mode == 'row_vectorized':
                overlap_batch = compute_overlap_batch_numba_row_vectorized(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
            elif mode == 'vectorized':
                overlap_batch = compute_overlap_batch_numba_vectorized(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
            else:  # ultra_fast
                overlap_batch = compute_overlap_batch_numba_ultra_fast(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
        except Exception:
            # Was a bare `except:` — narrowed so Ctrl-C/SystemExit still
            # propagate.  Fall back to the validated (NaN-checking) kernel
            # if any optimized kernel fails to compile or run.
            overlap_batch = compute_overlap_batch_numba(means1_batch, vars1_batch,
                                                        means2_batch, vars2_batch, tol)
        if biomechanical_filter:
            overlap_batch = apply_biomechanical_filter_numba(overlap_batch, means1_batch, vars1_batch,
                                                             means2_batch, vars2_batch, tol)
    else:
        # Pure-NumPy path (biomechanical_filter not available here).
        overlap_batch = compute_overlap_batch_fallback(means1_batch, vars1_batch,
                                                       means2_batch, vars2_batch, tol)
    # Final clipping to the valid probability-like range.
    np.clip(overlap_batch, 0.0, 1.0, out=overlap_batch)
    return overlap_batch
if __name__ == "__main__":
    import time

    # Smoke test: build random Gaussian stats and time the overlap kernel.
    print("Testing Numba overlap calculation...")

    n_subjects = 10
    n_features = 20  # e.g. 10 time windows x 2 sensors
    shape = (n_subjects, 150, n_features)
    means1 = np.random.randn(*shape)
    vars1 = np.abs(np.random.randn(*shape)) + 0.1  # keep variances positive
    means2 = np.random.randn(*shape)
    vars2 = np.abs(np.random.randn(*shape)) + 0.1

    print(f"Numba available: {NUMBA_AVAILABLE}")
    print(f"Computing overlap for {n_subjects} subjects, {n_features} features...")

    # First call: includes JIT compilation cost when Numba is present.
    t0 = time.time()
    result = compute_overlap_batch(means1, vars1, means2, vars2)
    elapsed = time.time() - t0

    print(f"Result shape: {result.shape}")
    print(f"Execution time: {elapsed:.3f} seconds")
    print(f"Non-zero elements: {np.count_nonzero(result)}")
    print(f"Max value: {np.max(result):.4f}")
    print(f"Min value: {np.min(result):.4f}")

    if NUMBA_AVAILABLE:
        print("\n✅ Numba JIT compilation successful!")
        print("The first run compiles the function, subsequent runs will be much faster.")
        # Second call shows steady-state (already-compiled) performance.
        t0 = time.time()
        result = compute_overlap_batch(means1, vars1, means2, vars2)
        elapsed = time.time() - t0
        print(f"Compiled execution time: {elapsed:.3f} seconds")
    else:
        print("\n⚠️ Numba not available, using fallback implementation.")
        print("Install Numba for 10-100x speedup: pip install numba")