jmontp commited on
Commit
43ec583
·
1 Parent(s): d1a6270

Updated to new data and multivariate lib api

Browse files
__pycache__/shared_styling.cpython-312.pyc DELETED
Binary file (32.5 kB)
 
cached_data/precalculated_stats.pkl.gz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:030a03b477c02ead69f0b0c83bfab2133d7f8bb9e2f81ab3ae09cb21b27fbd93
3
- size 5573257
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be3355088c6913cdb552802f2e3aa139abed2a3a79441895d31d08ff5f9cd23a
3
+ size 25443689
config.py CHANGED
@@ -1,18 +1,22 @@
1
  # Configuration file for dashboard and preprocessing scripts
2
 
 
 
 
 
3
  AVAILABLE_SENSORS = [
 
4
  'hip_angle_s_r', 'hip_angle_s_l', 'hip_vel_s_r', 'hip_vel_s_l',
5
  'knee_angle_s_r', 'knee_angle_s_l', 'knee_vel_s_r', 'knee_vel_s_l',
6
  'ankle_angle_s_r', 'ankle_angle_s_l', 'ankle_vel_s_r', 'ankle_vel_s_l',
7
  'foot_angle_s_r', 'foot_angle_s_l', 'foot_vel_s_r', 'foot_vel_s_l',
8
- # Add segment angles and velocities
9
- 'pelvis_angle_s_r', 'pelvis_angle_s_l', 'pelvis_vel_s_r', 'pelvis_vel_s_l',
10
- 'hip_angle_f_r', 'hip_angle_f_l', 'hip_vel_f_r', 'hip_vel_f_l',
11
- 'knee_angle_f_r', 'knee_angle_f_l', 'knee_vel_f_r', 'knee_vel_f_l',
12
- 'ankle_angle_f_r', 'ankle_angle_f_l', 'ankle_vel_f_r', 'ankle_vel_f_l',
13
- 'hip_angle_t_r', 'hip_angle_t_l', 'hip_vel_t_r', 'hip_vel_t_l',
14
- 'knee_angle_t_r', 'knee_angle_t_l', 'knee_vel_t_r', 'knee_vel_t_l',
15
- 'ankle_angle_t_r', 'ankle_angle_t_l', 'ankle_vel_t_r', 'ankle_vel_t_l'
16
  ]
17
 
18
  AVAILABLE_TASKS = ['decline_walking', 'level_walking', 'incline_walking',
@@ -26,24 +30,63 @@ ANALYSIS_ABSTRACTION_LEVELS = ['High', 'Medium/Low']
26
 
27
  # Task configurations for pre-calculation and analysis
28
  LOW_LEVEL_TASKS = [
29
- ('stair_descent', None, None),
30
- ('stair_ascent', None, None),
31
- ('sit_to_stand', None, None),
 
 
 
 
 
 
 
 
 
 
 
32
  ('level_walking', 0.0, 0.8),
33
  ('level_walking', 0.0, 1.0),
34
  ('level_walking', 0.0, 1.2),
35
- # Gtech variants
 
 
 
36
  ('level_walking', 0.0, 0.6),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ('level_walking', 0.0, 1.8),
 
 
 
38
  ('incline_walking', 5.0, 0.8),
39
  ('incline_walking', 5.0, 1.0),
40
  ('incline_walking', 5.0, 1.2),
41
- ('decline_walking', -5.0, 0.8),
42
- ('decline_walking', -5.0, 1.0),
43
- ('decline_walking', -5.0, 1.2),
44
  ('incline_walking', 10.0, 0.8),
45
  ('incline_walking', 10.0, 1.0),
46
  ('incline_walking', 10.0, 1.2),
 
 
 
 
 
47
  ('decline_walking', -10.0, 0.8),
48
  ('decline_walking', -10.0, 1.0),
49
  ('decline_walking', -10.0, 1.2),
 
1
  # Configuration file for dashboard and preprocessing scripts
2
 
3
+ # Data source control
4
+ USE_FILTERED_DATA = True # If True: use *_filtered.parquet files (standard speeds only)
5
+ # If False: use original .parquet files (all speeds, more data)
6
+
7
  AVAILABLE_SENSORS = [
8
+ # Core sagittal plane sensors (available in both UMich and GTech)
9
  'hip_angle_s_r', 'hip_angle_s_l', 'hip_vel_s_r', 'hip_vel_s_l',
10
  'knee_angle_s_r', 'knee_angle_s_l', 'knee_vel_s_r', 'knee_vel_s_l',
11
  'ankle_angle_s_r', 'ankle_angle_s_l', 'ankle_vel_s_r', 'ankle_vel_s_l',
12
  'foot_angle_s_r', 'foot_angle_s_l', 'foot_vel_s_r', 'foot_vel_s_l',
13
+ # Pelvis and trunk angles (available in both datasets)
14
+ 'pelvis_angle_s_r', 'pelvis_angle_s_l',
15
+ # Calculated segment sensors (computed by loader.py)
16
+ 'shank_angle_s_r', 'shank_angle_s_l', 'shank_vel_s_r', 'shank_vel_s_l',
17
+ 'thigh_angle_s_r', 'thigh_angle_s_l', 'thigh_vel_s_r', 'thigh_vel_s_l',
18
+ # Ground reaction forces (keeping only vertical right GRF)
19
+ 'grf_y_r',
 
20
  ]
21
 
22
  AVAILABLE_TASKS = ['decline_walking', 'level_walking', 'incline_walking',
 
30
 
31
  # Task configurations for pre-calculation and analysis
32
  LOW_LEVEL_TASKS = [
33
+ # Stair tasks - using incline field for height in mm, speed is None (self-selected)
34
+ # Original GTech stairs (102mm)
35
+ ('stair_descent', 102.0, None),
36
+ ('stair_ascent', 102.0, None),
37
+
38
+ # Additional stair heights (new datasets)
39
+ ('stair_descent', 127.0, None),
40
+ ('stair_ascent', 127.0, None),
41
+ ('stair_descent', 152.0, None),
42
+ ('stair_ascent', 152.0, None),
43
+ ('stair_descent', 178.0, None),
44
+ ('stair_ascent', 178.0, None),
45
+
46
+ # Level walking - UMich steady-state speeds
47
  ('level_walking', 0.0, 0.8),
48
  ('level_walking', 0.0, 1.0),
49
  ('level_walking', 0.0, 1.2),
50
+
51
+ # Level walking - All GTech speeds
52
+ ('level_walking', 0.0, 0.5),
53
+ ('level_walking', 0.0, 0.55),
54
  ('level_walking', 0.0, 0.6),
55
+ ('level_walking', 0.0, 0.65),
56
+ ('level_walking', 0.0, 0.7),
57
+ ('level_walking', 0.0, 0.75),
58
+ ('level_walking', 0.0, 0.85),
59
+ ('level_walking', 0.0, 0.9),
60
+ ('level_walking', 0.0, 0.95),
61
+ ('level_walking', 0.0, 1.05),
62
+ ('level_walking', 0.0, 1.1),
63
+ ('level_walking', 0.0, 1.15),
64
+ ('level_walking', 0.0, 1.25),
65
+ ('level_walking', 0.0, 1.3),
66
+ ('level_walking', 0.0, 1.35),
67
+ ('level_walking', 0.0, 1.4),
68
+ ('level_walking', 0.0, 1.45),
69
+ ('level_walking', 0.0, 1.5),
70
+ ('level_walking', 0.0, 1.55),
71
+ ('level_walking', 0.0, 1.6),
72
+ ('level_walking', 0.0, 1.65),
73
+ ('level_walking', 0.0, 1.7),
74
+ ('level_walking', 0.0, 1.75),
75
  ('level_walking', 0.0, 1.8),
76
+ ('level_walking', 0.0, 1.85),
77
+
78
+ # Incline walking - UMich speeds
79
  ('incline_walking', 5.0, 0.8),
80
  ('incline_walking', 5.0, 1.0),
81
  ('incline_walking', 5.0, 1.2),
 
 
 
82
  ('incline_walking', 10.0, 0.8),
83
  ('incline_walking', 10.0, 1.0),
84
  ('incline_walking', 10.0, 1.2),
85
+
86
+ # Decline walking - UMich speeds
87
+ ('decline_walking', -5.0, 0.8),
88
+ ('decline_walking', -5.0, 1.0),
89
+ ('decline_walking', -5.0, 1.2),
90
  ('decline_walking', -10.0, 0.8),
91
  ('decline_walking', -10.0, 1.0),
92
  ('decline_walking', -10.0, 1.2),
gpu_overlap.py ADDED
@@ -0,0 +1,892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
GPU-accelerated implementation of multivariate Gaussian overlap calculation using CuPy.
This provides massive speedup for large-scale analyses by processing all task pairs simultaneously.
"""

import numpy as np
import warnings
from typing import Optional

# Try to import CuPy for GPU acceleration.
# GPU_AVAILABLE: CuPy importable; GPU_READY: a CUDA device actually responds.
try:
    import cupy as cp
    GPU_AVAILABLE = True
    print("✅ CuPy GPU acceleration available")
except ImportError:
    GPU_AVAILABLE = False
    cp = None
    print("⚠️ CuPy not available. Install with: pip install cupy-cuda12x")

# Check for CUDA availability: CuPy can be installed without a working driver.
if GPU_AVAILABLE:
    try:
        # Probing Device(0) raises if no CUDA runtime/driver/device is present.
        device = cp.cuda.Device(0)
        GPU_READY = True
        print(f"🚀 GPU ready: Device {device.id} (RTX detected)")
    except Exception:
        # BUG FIX: was a bare `except:`, which would also swallow
        # SystemExit/KeyboardInterrupt raised during the probe.
        GPU_READY = False
        GPU_AVAILABLE = False
        print("⚠️ CUDA not available, disabling GPU acceleration")
else:
    GPU_READY = False
35
+
36
def compute_overlap_batch_gpu(means1_batch, vars1_batch, means2_batch, vars2_batch,
                              tol=1e-12, biomechanical_filter=False):
    """
    GPU-accelerated batch overlap computation using CuPy.

    Processes all subjects simultaneously with full GPU vectorization, computing
    exp(-0.5 * sum_f (mu1 - mu2)^2 / (var1 + var2)) for every phase pair.

    Parameters:
        means1_batch: np.ndarray shape (n_subjects, n_phases, n_features)
        vars1_batch: np.ndarray shape (n_subjects, n_phases, n_features)
        means2_batch: np.ndarray shape (n_subjects, n_phases, n_features)
        vars2_batch: np.ndarray shape (n_subjects, n_phases, n_features)
        tol: float, tolerance for variance validity
        biomechanical_filter: bool, apply biomechanical filtering

    Returns:
        np.ndarray shape (n_subjects, n_phases, n_phases) - overlap values in [0, 1]

    Raises:
        RuntimeError: if CuPy/CUDA is not available.
    """
    if not GPU_AVAILABLE:
        raise RuntimeError("CuPy not available for GPU computation")

    n_subjects, n_phases, n_features = means1_batch.shape

    # Transfer to GPU - single transfer for all data (float32 halves memory/bandwidth)
    means1_gpu = cp.asarray(means1_batch, dtype=cp.float32)
    vars1_gpu = cp.asarray(vars1_batch, dtype=cp.float32)
    means2_gpu = cp.asarray(means2_batch, dtype=cp.float32)
    vars2_gpu = cp.asarray(vars2_batch, dtype=cp.float32)

    # BUG FIX: removed a dead cp.zeros((n_subjects, 150, 150)) pre-allocation
    # that was immediately overwritten by the cp.where below, and that
    # hard-coded 150 phases. The output shape now follows n_phases.

    # Broadcasting computes ALL phase pairs at once:
    #   means1: (n_subjects, n_phases, 1, F) for phase_i
    #   means2: (n_subjects, 1, n_phases, F) for phase_j
    means1_exp = means1_gpu[:, :, cp.newaxis, :]
    vars1_exp = vars1_gpu[:, :, cp.newaxis, :]
    means2_exp = means2_gpu[:, cp.newaxis, :, :]
    vars2_exp = vars2_gpu[:, cp.newaxis, :, :]

    # All pairwise differences and pooled variances: (n_subjects, n_phases, n_phases, F)
    diff = means1_exp - means2_exp
    var_sum = vars1_exp + vars2_exp

    # A phase pair is valid only if no feature is NaN and every pooled
    # variance exceeds tol (guards the division below).
    valid_mask = (~cp.isnan(diff).any(axis=3) &
                  ~cp.isnan(var_sum).any(axis=3) &
                  (var_sum > tol).all(axis=3))

    # Quadratic form; cp.where avoids division blow-ups on invalid entries
    quad_terms = cp.where(valid_mask[:, :, :, cp.newaxis],
                          diff * diff / var_sum,
                          0.0)

    # Sum over features -> (n_subjects, n_phases, n_phases)
    quad_sum = cp.sum(quad_terms, axis=3)

    # exp(-q/2) underflows float32 long before q/2 = 20; clamp those to 0
    safe_exp_mask = valid_mask & (quad_sum * 0.5 <= 20.0)
    overlap_batch_gpu = cp.where(safe_exp_mask,
                                 cp.exp(-0.5 * quad_sum),
                                 0.0)

    # Optional post-hoc filtering on the first (torque) feature
    if biomechanical_filter:
        overlap_batch_gpu = _apply_biomechanical_filter_gpu(
            overlap_batch_gpu, means1_gpu, vars1_gpu, means2_gpu, vars2_gpu, tol
        )

    # Transfer back to CPU - single transfer
    result = cp.asnumpy(overlap_batch_gpu).astype(np.float64)

    # Final clipping on CPU (filtering may nudge values past [0, 1] in float32)
    np.clip(result, 0.0, 1.0, out=result)
    return result
118
+
119
+
120
def _apply_biomechanical_filter_gpu(overlap_batch, means1_batch, vars1_batch,
                                    means2_batch, vars2_batch, tol):
    """Apply biomechanical filtering on GPU using vectorized operations.

    Post-processes the (n_subjects, 150, 150) overlap matrix using only the
    first feature (assumed to be torque — TODO confirm with callers):
      * pairs where both torques are negligible are forced to full overlap (1.0);
      * remaining pairs (neither negligible-negligible nor an amplitude
        conflict) have their dissimilarity scaled by a sign-mismatch
        probability and a mean-difference ramp penalty.
    Returns the modified overlap_batch (modified in place for the m2 cases).
    """
    # NOTE(review): n_subjects is assigned but never used below.
    n_subjects = overlap_batch.shape[0]
    negligible_threshold = 0.1   # torque magnitude treated as "no torque"
    ampable_threshold = 0.2      # torque magnitude treated as meaningful
    ci_factor = 1.96             # 95% confidence interval half-width factor

    # Only process first feature (torque) for biomechanical filtering
    means1_torque = means1_batch[:, :, 0]  # Shape: (n_subjects, 150)
    means2_torque = means2_batch[:, :, 0]
    vars1_torque = vars1_batch[:, :, 0]
    vars2_torque = vars2_batch[:, :, 0]

    # Vectorized std and 95% CI calculations per phase
    std1 = cp.sqrt(vars1_torque)
    std2 = cp.sqrt(vars2_torque)

    ci_lo1 = means1_torque - ci_factor * std1
    ci_hi1 = means1_torque + ci_factor * std1
    ci_lo2 = means2_torque - ci_factor * std2
    ci_hi2 = means2_torque + ci_factor * std2

    # "Negligible": the whole CI fits inside [-threshold, +threshold]
    negligible1 = ((ci_lo1 >= -negligible_threshold) &
                   (ci_hi1 <= negligible_threshold))  # Shape: (n_subjects, 150)
    negligible2 = ((ci_lo2 >= -negligible_threshold) &
                   (ci_hi2 <= negligible_threshold))
    # "Ampable": mean torque magnitude exceeds the amplitude threshold
    ampable1 = cp.abs(means1_torque) > ampable_threshold
    ampable2 = cp.abs(means2_torque) > ampable_threshold

    # Broadcast to phase pair dimensions using newaxis
    neg1_exp = negligible1[:, :, cp.newaxis]  # (n_subjects, 150, 1)
    neg2_exp = negligible2[:, cp.newaxis, :]  # (n_subjects, 1, 150)
    amp1_exp = ampable1[:, :, cp.newaxis]
    amp2_exp = ampable2[:, cp.newaxis, :]

    # Three-level filtering masks
    # Negligible-negligible: Both torques are negligible
    m0 = neg1_exp & neg2_exp  # Shape: (n_subjects, 150, 150)

    # Amplitude conflicts: One negligible, other ampable
    # NOTE(review): m1 pairs receive no adjustment below — m1 only serves to
    # exclude them from m2. Confirm whether leaving their raw overlap is intended.
    m1 = ((neg1_exp & amp2_exp) | (neg2_exp & amp1_exp))

    # Sign reversal cases: Neither negligible-negligible nor amplitude conflict
    m2 = ~(m0 | m1)

    # Apply negligible-negligible rule (set to 1.0)
    overlap_batch = cp.where(m0, 1.0, overlap_batch)

    # Apply sign reversal filtering for m2 cases
    if cp.any(m2):
        # Get indices where filtering is needed
        s_idx, i_idx, j_idx = cp.where(m2)

        if len(s_idx) > 0:
            # Guard against zero std before dividing
            std1_safe = cp.maximum(std1, tol)
            std2_safe = cp.maximum(std2, tol)
            # z-scores: distance of mean from zero in std units
            z1 = means1_torque / std1_safe
            z2 = means2_torque / std2_safe

            # Normal CDF approximation (vectorized)
            def norm_cdf_gpu(x):
                # Abramowitz and Stegun 26.2.17 polynomial approximation
                t = 1.0 / (1.0 + 0.2316419 * cp.abs(x))
                d = 0.3989423 * cp.exp(-x * x / 2.0)
                prob = d * t * (0.3193815 + t * (-0.3565638 +
                       t * (1.781478 + t * (-1.821256 + t * 1.330274))))
                return cp.where(x > 0, 1.0 - prob, prob)

            # P(torque > 0) per phase for each side
            Ppos1 = norm_cdf_gpu(z1)
            Ppos2 = norm_cdf_gpu(z2)

            # Sign-mismatch probability for selected indices:
            # P(sign1 != sign2) = P1(1-P2) + (1-P1)P2
            Pdiff_sign = (Ppos1[s_idx, i_idx] * (1.0 - Ppos2[s_idx, j_idx]) +
                          (1.0 - Ppos1[s_idx, i_idx]) * Ppos2[s_idx, j_idx])

            # Mean-difference penalty (vectorized ramp function)
            mean_diff = cp.abs(means1_torque[s_idx, i_idx] - means2_torque[s_idx, j_idx])
            s_thresh, e_thresh = 0.2, 0.5

            # Linear ramp: 0 below s_thresh, 1 above e_thresh
            penalty = cp.clip((mean_diff - s_thresh) / (e_thresh - s_thresh), 0.0, 1.0)

            # Combine penalties: take the stronger of the two
            Pdiff = cp.maximum(Pdiff_sign, penalty)

            # Apply penalty to overlaps: shrink overlap toward 0 by scaling
            # the dissimilarity (1 - overlap) up... wait, scaling (1-overlap)
            # by Pdiff <= 1 keeps overlap >= original; it pulls overlap
            # toward 1 when Pdiff is small and leaves it unchanged at Pdiff=1.
            current_overlaps = overlap_batch[s_idx, i_idx, j_idx]
            output_diff = 1.0 - current_overlaps
            scaled_output_diff = output_diff * Pdiff
            overlap_batch[s_idx, i_idx, j_idx] = 1.0 - scaled_output_diff

    return overlap_batch
215
+
216
+
217
def compute_overlap_batch_gpu_chunked(means1_batch, vars1_batch, means2_batch, vars2_batch,
                                      chunk_size=None, **kwargs):
    """
    Chunked GPU processing for very large datasets that don't fit in GPU memory.

    Splits the subject axis into chunks, runs compute_overlap_batch_gpu on each
    chunk, and concatenates the results. When chunk_size is None, it is
    estimated automatically from free device memory.

    Parameters:
        means1_batch, vars1_batch, means2_batch, vars2_batch:
            np.ndarray shape (n_subjects, n_phases, n_features)
        chunk_size: int or None — subjects per GPU call; None = auto-estimate
        **kwargs: forwarded to compute_overlap_batch_gpu (tol, biomechanical_filter)

    Returns:
        np.ndarray shape (n_subjects, n_phases, n_phases)

    Raises:
        RuntimeError: if CuPy/CUDA is not available.
    """
    if not GPU_AVAILABLE:
        raise RuntimeError("CuPy not available for GPU computation")

    n_subjects = means1_batch.shape[0]

    if chunk_size is None:
        # BUG FIX: previously used cp.get_default_memory_pool().free_bytes(),
        # which reports bytes cached *inside* CuPy's pool (≈0 before any
        # allocation has been made) — not free device memory — so the auto
        # chunk size collapsed to 1. Query the device directly instead.
        available_memory = cp.cuda.Device().mem_info[0]  # free bytes on device

        # Rough estimate: each subject needs ~150*150*4 bytes for overlap + input arrays
        bytes_per_subject = 150 * 150 * 4 * 6  # 6 arrays (means1, vars1, means2, vars2, overlap, temp)
        estimated_chunk_size = max(1, int(available_memory * 0.8 // bytes_per_subject))
        chunk_size = min(estimated_chunk_size, n_subjects)

        print(f"🔧 Auto-determined GPU chunk size: {chunk_size} subjects")

    if chunk_size >= n_subjects:
        # Process all at once
        return compute_overlap_batch_gpu(means1_batch, vars1_batch,
                                         means2_batch, vars2_batch, **kwargs)

    # Process in chunks along the subject axis
    results = []
    for i in range(0, n_subjects, chunk_size):
        end_idx = min(i + chunk_size, n_subjects)

        chunk_result = compute_overlap_batch_gpu(
            means1_batch[i:end_idx],
            vars1_batch[i:end_idx],
            means2_batch[i:end_idx],
            vars2_batch[i:end_idx],
            **kwargs
        )
        results.append(chunk_result)

    return np.concatenate(results, axis=0)
261
+
262
+
263
def benchmark_gpu_vs_cpu():
    """Benchmark GPU vs CPU performance on sample data.

    Generates random test batches, times compute_overlap_batch_gpu against the
    Numba CPU implementation (if importable), and prints the speedup and the
    maximum element-wise difference. Output is printed; nothing is returned.
    """
    if not GPU_AVAILABLE:
        print("GPU not available for benchmarking")
        return

    import time

    # Create test data
    n_subjects = 10
    n_features = 4

    print(f"🔧 Benchmarking with {n_subjects} subjects, {n_features} features...")

    means1 = np.random.randn(n_subjects, 150, n_features).astype(np.float32)
    vars1 = np.abs(np.random.randn(n_subjects, 150, n_features)).astype(np.float32) + 0.1
    means2 = np.random.randn(n_subjects, 150, n_features).astype(np.float32)
    vars2 = np.abs(np.random.randn(n_subjects, 150, n_features)).astype(np.float32) + 0.1

    # BUG FIX: the early return above guarantees GPU_AVAILABLE is True here,
    # so the original's repeated `if GPU_AVAILABLE:` guards (and their
    # unreachable `else: result_gpu = None` branches) have been removed.

    # Warm up GPU — first call pays one-time kernel/allocation costs
    _ = compute_overlap_batch_gpu(means1[:2], vars1[:2], means2[:2], vars2[:2])

    # Benchmark GPU
    start = time.time()
    result_gpu = compute_overlap_batch_gpu(means1, vars1, means2, vars2)
    gpu_time = time.time() - start
    print(f"🚀 GPU time: {gpu_time:.4f} seconds")

    # Benchmark CPU (Numba fallback)
    try:
        from .numba_overlap import compute_overlap_batch
        start = time.time()
        result_cpu = compute_overlap_batch(means1, vars1, means2, vars2)
        cpu_time = time.time() - start
        print(f"🔧 CPU time: {cpu_time:.4f} seconds")

        speedup = cpu_time / gpu_time
        print(f"📈 GPU Speedup: {speedup:.1f}x")

        # Check accuracy between the two implementations
        max_diff = np.max(np.abs(result_gpu.astype(np.float64) - result_cpu))
        print(f"🎯 Max difference: {max_diff:.2e}")

    except ImportError:
        print("❌ Numba not available for CPU comparison")
314
+
315
+
316
def compute_overlap_batch_gpu_mega(all_means1_batch, all_vars1_batch, all_means2_batch, all_vars2_batch,
                                   valid_mask, tol=1e-12, biomechanical_filter=False):
    """
    MEGA-BATCH GPU computation: Process ALL task pairs simultaneously.

    This is the ultimate "throw everything in" approach for maximum GPU utilization.
    Processes hundreds of task pairs × subjects × phase pairs in a single GPU call.

    Parameters:
        all_means1_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
        all_vars1_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
        all_means2_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
        all_vars2_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
        valid_mask: np.ndarray shape (n_task_pairs, n_subjects_max) - bool mask for valid subjects
        tol: float, tolerance for variance validity
        biomechanical_filter: bool, apply biomechanical filtering

    Returns:
        np.ndarray shape (n_task_pairs, n_subjects_max, 150, 150) - overlap values

    Raises:
        RuntimeError: if CuPy/CUDA is not available.
    """
    if not GPU_AVAILABLE:
        raise RuntimeError("CuPy not available for mega-batch GPU computation")

    n_task_pairs, n_subjects_max, n_phases, n_features = all_means1_batch.shape

    # NOTE: the progress messages assume the conventional 150-phase grid
    print(f"🚀 GPU Mega-batch: Processing {n_task_pairs} task pairs × {n_subjects_max} subjects × {150*150} phase pairs")
    print(f"📊 Total computations: {n_task_pairs * n_subjects_max * 150 * 150:,}")

    # Transfer ALL data to GPU in single transfer
    means1_gpu = cp.asarray(all_means1_batch, dtype=cp.float32)
    vars1_gpu = cp.asarray(all_vars1_batch, dtype=cp.float32)
    means2_gpu = cp.asarray(all_means2_batch, dtype=cp.float32)
    vars2_gpu = cp.asarray(all_vars2_batch, dtype=cp.float32)
    valid_gpu = cp.asarray(valid_mask, dtype=cp.bool_)

    # BUG FIX: removed a dead cp.zeros((n_task_pairs, n_subjects_max, 150, 150))
    # pre-allocation that was immediately overwritten by the cp.where below.
    # The output shape now follows n_phases via broadcasting.

    # MEGA BROADCASTING: Process ALL task pairs and subjects simultaneously.
    # (n_task_pairs, n_subjects_max, 150, 1, F) vs (n_task_pairs, n_subjects_max, 1, 150, F)
    means1_exp = means1_gpu[:, :, :, cp.newaxis, :]  # Add phase_j dimension
    vars1_exp = vars1_gpu[:, :, :, cp.newaxis, :]
    means2_exp = means2_gpu[:, :, cp.newaxis, :, :]  # Add phase_i dimension
    vars2_exp = vars2_gpu[:, :, cp.newaxis, :, :]

    # All differences and pooled variances at once:
    # Shape: (n_task_pairs, n_subjects_max, 150, 150, n_features)
    diff = means1_exp - means2_exp
    var_sum = vars1_exp + vars2_exp

    # Validity mask broadcast over all phase pairs
    subject_valid = valid_gpu[:, :, cp.newaxis, cp.newaxis]

    # NaN and variance validity for ALL data simultaneously
    nan_valid = (~cp.isnan(diff).any(axis=4) &
                 ~cp.isnan(var_sum).any(axis=4) &
                 (var_sum > tol).all(axis=4))

    # Combined validity mask: subject present AND data valid
    full_valid_mask = subject_valid & nan_valid

    # Quadratic form for ALL valid entries (guarded division)
    quad_terms = cp.where(full_valid_mask[:, :, :, :, cp.newaxis],
                          diff * diff / var_sum,
                          0.0)

    # Sum over features: (n_task_pairs, n_subjects_max, 150, 150)
    quad_sum = cp.sum(quad_terms, axis=4)

    # exp(-q/2) underflows float32 long before q/2 = 20; clamp those to 0
    safe_exp_mask = full_valid_mask & (quad_sum * 0.5 <= 20.0)
    overlap_batch_gpu = cp.where(safe_exp_mask,
                                 cp.exp(-0.5 * quad_sum),
                                 0.0)

    # Optional post-hoc filtering on the first (torque) feature
    if biomechanical_filter:
        overlap_batch_gpu = _apply_biomechanical_filter_gpu_mega(
            overlap_batch_gpu, means1_gpu, vars1_gpu, means2_gpu, vars2_gpu, valid_gpu, tol
        )

    # Transfer back to CPU - single transfer for ALL results
    print("📥 Transferring results from GPU...")
    result = cp.asnumpy(overlap_batch_gpu).astype(np.float64)

    # Final clipping (filtering may nudge values past [0, 1] in float32)
    np.clip(result, 0.0, 1.0, out=result)

    print(f"✅ Mega-batch GPU computation complete!")
    return result
409
+
410
+
411
def _apply_biomechanical_filter_gpu_mega(overlap_batch, means1_batch, vars1_batch,
                                         means2_batch, vars2_batch, valid_mask, tol):
    """Apply biomechanical filtering for mega-batch on GPU.

    Mega-batch analogue of _apply_biomechanical_filter_gpu, operating on the
    first feature (assumed torque — TODO confirm) of
    (n_task_pairs, n_subjects_max, 150, F) inputs.
    NOTE(review): unlike the per-batch variant, the m2 penalty here uses only
    the mean-difference ramp and omits the sign-mismatch probability term
    (deliberate simplification per the inline comment below).
    Returns the modified overlap_batch.
    """
    negligible_threshold = 0.1   # torque magnitude treated as "no torque"
    ampable_threshold = 0.2      # torque magnitude treated as meaningful
    ci_factor = 1.96             # 95% confidence interval half-width factor

    # NOTE(review): assigned but never used below.
    n_task_pairs, n_subjects_max = overlap_batch.shape[:2]

    # Only process first feature (torque) for biomechanical filtering
    means1_torque = means1_batch[:, :, :, 0]  # Shape: (n_task_pairs, n_subjects_max, 150)
    means2_torque = means2_batch[:, :, :, 0]
    vars1_torque = vars1_batch[:, :, :, 0]
    vars2_torque = vars2_batch[:, :, :, 0]

    # Vectorized std and 95% CI calculations for ALL task pairs
    std1 = cp.sqrt(vars1_torque)
    std2 = cp.sqrt(vars2_torque)

    ci_lo1 = means1_torque - ci_factor * std1
    ci_hi1 = means1_torque + ci_factor * std1
    ci_lo2 = means2_torque - ci_factor * std2
    ci_hi2 = means2_torque + ci_factor * std2

    # "Negligible": the whole CI fits inside [-threshold, +threshold]
    negligible1 = ((ci_lo1 >= -negligible_threshold) &
                   (ci_hi1 <= negligible_threshold))
    negligible2 = ((ci_lo2 >= -negligible_threshold) &
                   (ci_hi2 <= negligible_threshold))
    # "Ampable": mean torque magnitude exceeds the amplitude threshold
    ampable1 = cp.abs(means1_torque) > ampable_threshold
    ampable2 = cp.abs(means2_torque) > ampable_threshold

    # Broadcast to phase pair dimensions
    # Shape: (n_task_pairs, n_subjects_max, 150, 1)
    neg1_exp = negligible1[:, :, :, cp.newaxis]
    amp1_exp = ampable1[:, :, :, cp.newaxis]

    # Shape: (n_task_pairs, n_subjects_max, 1, 150)
    neg2_exp = negligible2[:, :, cp.newaxis, :]
    amp2_exp = ampable2[:, :, cp.newaxis, :]

    # Restrict every mask to subjects that actually exist in the padded batch
    valid_exp = valid_mask[:, :, cp.newaxis, cp.newaxis]

    # Three-level filtering masks for ALL task pairs
    m0 = (neg1_exp & neg2_exp) & valid_exp  # Negligible-negligible
    # NOTE(review): m1 pairs receive no adjustment — m1 only excludes them from m2.
    m1 = ((neg1_exp & amp2_exp) | (neg2_exp & amp1_exp)) & valid_exp  # Amplitude conflicts
    m2 = ~(m0 | m1) & valid_exp  # Sign reversal cases

    # Apply negligible-negligible rule (force full overlap)
    overlap_batch = cp.where(m0, 1.0, overlap_batch)

    # Apply sign reversal filtering for m2 cases (if any exist)
    if cp.any(m2):
        # For mega-batch, we'll use a simplified linear ramp for performance
        # (Full probability calculation would be too expensive for this scale)

        # Get phase indices for m2 cases
        t_idx, s_idx, i_idx, j_idx = cp.where(m2)

        if len(t_idx) > 0:
            # Mean-difference penalty (vectorized)
            mean_diff = cp.abs(means1_torque[t_idx, s_idx, i_idx] -
                               means2_torque[t_idx, s_idx, j_idx])

            # Linear ramp: 0 below s_thresh, 1 above e_thresh
            s_thresh, e_thresh = 0.2, 0.5
            penalty = cp.clip((mean_diff - s_thresh) / (e_thresh - s_thresh), 0.0, 1.0)

            # Scale the dissimilarity (1 - overlap) by the penalty; penalty=1
            # leaves the overlap unchanged, penalty=0 forces overlap to 1.0
            current_overlaps = overlap_batch[t_idx, s_idx, i_idx, j_idx]
            output_diff = 1.0 - current_overlaps
            scaled_output_diff = output_diff * penalty
            overlap_batch[t_idx, s_idx, i_idx, j_idx] = 1.0 - scaled_output_diff

    return overlap_batch
487
+
488
+
489
def estimate_mega_batch_memory(n_task_pairs, n_subjects_max, n_features):
    """
    Estimate GPU memory requirements (in GB) for mega-batch processing.

    Accounts for the 5D broadcasting performed during GPU computation:
    inputs of shape (n_task_pairs, n_subjects_max, 150, n_features) are
    broadcast to (n_task_pairs, n_subjects_max, 150, 150, n_features), and
    that 150x150 expansion is what dominates for large feature counts.
    """
    FLOAT32 = 4                                      # bytes per float32 element
    base = n_task_pairs * n_subjects_max * 150       # elements per phase column

    # Four float32 input arrays, pre-broadcasting.
    input_bytes = 4 * base * n_features * FLOAT32
    # One (..., 150, 150) float32 output array.
    output_bytes = base * 150 * FLOAT32
    # Broadcast 5D intermediates (diff, var_sum, quad_terms, ...):
    # conservatively assume four such tensors are alive at once.
    one_5d_bytes = base * 150 * n_features * FLOAT32
    intermediate_bytes = one_5d_bytes * 4

    return (input_bytes + output_bytes + intermediate_bytes) / (1024**3)
514
+
515
+
516
def get_available_gpu_memory_gb():
    """Get available GPU memory in GB.

    Queries free device memory via CuPy and applies a 70% safety margin.
    Returns 0.0 when CuPy is unavailable, at least 0.5 GB otherwise, and a
    5 GB fallback if the device query fails.
    """
    if not GPU_AVAILABLE:
        return 0.0

    try:
        # Device.mem_info is (free_bytes, total_bytes).
        # BUG FIX: removed unused total/used computations; only free is needed.
        free_bytes = cp.cuda.Device().mem_info[0]

        # Use 70% of free memory as safety margin
        available_gb = (free_bytes * 0.7) / (1024**3)

        return max(0.5, available_gb)  # Ensure at least 0.5GB for minimal chunking
    except Exception:
        # BUG FIX: was a bare `except:`; now only catches ordinary errors.
        # Fallback: assume 5GB available for RTX series
        return 5.0
535
+
536
+
537
def calculate_optimal_chunk_size(total_pairs, n_subjects_max, n_features, target_memory_gb=None):
    """Calculate optimal chunk size based on available GPU memory.

    Binary-searches for the largest number of task pairs whose estimated
    mega-batch footprint fits within target_memory_gb (auto-detected from the
    device when None). Returns 1 when no GPU is available.
    """
    if not GPU_AVAILABLE:
        return 1

    if target_memory_gb is None:
        target_memory_gb = get_available_gpu_memory_gb()

    # Binary search over [1, total_pairs] for the largest chunk that fits.
    lo, hi = 1, total_pairs
    best = 1
    while lo <= hi:
        candidate = (lo + hi) // 2
        if estimate_mega_batch_memory(candidate, n_subjects_max, n_features) <= target_memory_gb:
            best = candidate       # fits — try something larger
            lo = candidate + 1
        else:
            hi = candidate - 1     # too big — shrink

    # Ensure at least 1 task pair per chunk
    return max(1, best)
562
+
563
+
564
def get_available_ram_gb():
    """Get available system RAM in GB.

    Uses psutil when installed; otherwise falls back to a conservative
    16 GB assumption.
    """
    try:
        import psutil
    except ImportError:
        # Fallback: assume 16GB available (conservative)
        return 16.0
    return psutil.virtual_memory().available / (1024**3)
573
+
574
+
575
def calculate_ram_max_chunk_size(n_subjects_max, n_features, available_ram_gb):
    """Calculate maximum chunk size based on available RAM for numpy arrays.

    A chunk holds four float32 arrays (means1/vars1/means2/vars2), each of
    shape (chunk_size, n_subjects_max, 150, n_features); the budget is 70%
    of available_ram_gb. Always returns at least 1.
    """
    # Bytes consumed per task pair: 4 arrays x (n_subjects_max * 150 * n_features) x 4 bytes
    per_pair_bytes = 4 * n_subjects_max * 150 * n_features * 4

    # Spend at most 70% of the reported RAM as a safety margin
    ram_budget_bytes = available_ram_gb * 0.7 * (1024**3)

    return max(1, int(ram_budget_bytes / per_pair_bytes))
587
+
588
+
589
def calculate_optimal_chunk_size_dual_constraint(total_pairs, n_subjects_max, n_features):
    """
    Calculate optimal chunk size considering BOTH GPU memory and system RAM constraints.

    This prevents out-of-memory errors by respecting both:
    1. GPU memory limits (for CuPy processing)
    2. System RAM limits (for numpy array allocation)

    CRITICAL: For very large feature counts (>100), the 5D broadcasting becomes
    prohibitively expensive, so we use much more conservative estimates.

    Returns the minimum chunk size that satisfies both constraints (>= 1).
    """
    if not GPU_AVAILABLE:
        return 1

    # Get available memory for both constraints
    gpu_memory_gb = get_available_gpu_memory_gb()
    ram_memory_gb = get_available_ram_gb()

    # CRITICAL: For large feature counts, the 5D broadcasting dominates memory usage
    # We need to be much more conservative
    if n_features > 100:
        print(f"⚠️ Large feature count ({n_features}) detected - using conservative chunking")
        # For large features, memory usage scales roughly with features^2 due to broadcasting
        # Use a much smaller base and scale down aggressively
        feature_penalty = (n_features / 100) ** 1.5  # Exponential penalty
        conservative_gpu_memory = gpu_memory_gb / feature_penalty
        # NOTE(review): dividing by (feature_penalty * 0.5) makes the RAM budget
        # *larger* than ram_memory_gb whenever feature_penalty < 2 — confirm
        # this is the intended "RAM less affected" behavior.
        conservative_ram_memory = ram_memory_gb / (feature_penalty * 0.5)  # RAM less affected

        gpu_max_chunk = calculate_optimal_chunk_size(total_pairs, n_subjects_max, n_features, conservative_gpu_memory)
        ram_max_chunk = calculate_ram_max_chunk_size(n_subjects_max, n_features, conservative_ram_memory)
    else:
        # Normal calculation for reasonable feature counts
        gpu_max_chunk = calculate_optimal_chunk_size(total_pairs, n_subjects_max, n_features, gpu_memory_gb)
        ram_max_chunk = calculate_ram_max_chunk_size(n_subjects_max, n_features, ram_memory_gb)

    # Use the most restrictive constraint
    optimal_chunk = min(gpu_max_chunk, ram_max_chunk, total_pairs)

    print(f"🔧 Dual-constraint analysis:")
    print(f"   GPU memory: {gpu_memory_gb:.2f} GB → max {gpu_max_chunk} pairs")
    print(f"   RAM memory: {ram_memory_gb:.2f} GB → max {ram_max_chunk} pairs")
    print(f"   Using most restrictive: {optimal_chunk} pairs per chunk")

    # For very large feature counts, ensure we don't go too high
    if n_features > 100:
        # Cap at a reasonable maximum for large feature counts
        max_safe_chunk = max(1, int(50000 / n_features))  # Rough heuristic
        optimal_chunk = min(optimal_chunk, max_safe_chunk)
        if optimal_chunk == max_safe_chunk:
            print(f"   🔒 Capped at {optimal_chunk} pairs due to large feature count")

    return max(1, optimal_chunk)
643
+
644
+
645
def compute_overlap_batch_gpu_mega_chunked(all_means1_batch, all_vars1_batch, all_means2_batch, all_vars2_batch,
                                           valid_mask, tol=1e-12, biomechanical_filter=False, progress_callback=None):
    """
    Chunked mega-batch GPU computation: process task pairs in optimal chunks.

    Automatically determines chunk size from available GPU memory and system
    RAM, then processes task pairs chunk by chunk while keeping all subjects
    within each chunk for maximum efficiency.  Out-of-memory failures trigger
    progressively smaller chunks and finally a CPU fallback.

    Parameters:
        all_means1_batch: np.ndarray (n_task_pairs, n_subjects_max, 150, n_features)
        all_vars1_batch:  np.ndarray (n_task_pairs, n_subjects_max, 150, n_features)
        all_means2_batch: np.ndarray (n_task_pairs, n_subjects_max, 150, n_features)
        all_vars2_batch:  np.ndarray (n_task_pairs, n_subjects_max, 150, n_features)
        valid_mask: np.ndarray (n_task_pairs, n_subjects_max) bool mask of valid subjects
        tol: float, tolerance for variance validity
        biomechanical_filter: bool, apply biomechanical filtering
        progress_callback: callable taking a float in [0, 1], optional

    Returns:
        np.ndarray (n_task_pairs, n_subjects_max, 150, 150) of overlap values

    Raises:
        RuntimeError: if CuPy/CUDA is not available.
    """
    if not GPU_AVAILABLE:
        raise RuntimeError("CuPy not available for chunked mega-batch GPU computation")

    n_task_pairs, n_subjects_max, n_phases, n_features = all_means1_batch.shape

    # Calculate optimal chunk size using dual constraints (GPU + RAM)
    chunk_size = calculate_optimal_chunk_size_dual_constraint(n_task_pairs, n_subjects_max, n_features)

    print(f"🔧 Chunking Strategy:")
    print(f"   Total task pairs: {n_task_pairs:,}")
    print(f"   Optimal chunk size: {chunk_size:,} task pairs")
    print(f"   Number of chunks: {(n_task_pairs + chunk_size - 1) // chunk_size}")

    # Try single batch first, but catch out-of-memory errors
    if chunk_size >= n_task_pairs:
        print("🚀 Attempting single mega-batch processing...")
        try:
            return compute_overlap_batch_gpu_mega(
                all_means1_batch, all_vars1_batch, all_means2_batch, all_vars2_batch,
                valid_mask, tol, biomechanical_filter
            )
        except Exception as e:
            if "OutOfMemoryError" in str(type(e)) or "out of memory" in str(e).lower():
                print(f"⚠️ Single batch failed with memory error, forcing chunking...")
                # BUG FIX: `available_memory` was referenced here without ever
                # being defined, so this recovery path raised NameError instead
                # of retrying.  Query the GPU memory explicitly.
                available_memory = get_available_gpu_memory_gb()
                # Use at most 3 GB or 30% of available memory, whichever is smaller.
                conservative_memory = min(available_memory * 0.3, 3.0)
                chunk_size = calculate_optimal_chunk_size(n_task_pairs, n_subjects_max, n_features, conservative_memory)
                chunk_size = max(1, chunk_size // 2)  # Further reduce chunk size
                print(f"🔧 Fallback chunk size: {chunk_size} pairs (conservative estimate)")
            else:
                raise  # preserve the original traceback

    # Process in chunks
    print(f"🔄 Processing {n_task_pairs:,} task pairs in chunks of {chunk_size:,}...")
    results = []

    import time  # hoisted: was re-imported on every loop iteration

    for chunk_start in range(0, n_task_pairs, chunk_size):
        chunk_end = min(chunk_start + chunk_size, n_task_pairs)
        chunk_num = len(results) + 1
        total_chunks = (n_task_pairs + chunk_size - 1) // chunk_size

        print(f"🚀 Processing chunk {chunk_num}/{total_chunks} (task pairs {chunk_start}:{chunk_end})...")

        # Extract chunk data (numpy slicing yields views, so this is cheap)
        chunk_means1 = all_means1_batch[chunk_start:chunk_end]
        chunk_vars1 = all_vars1_batch[chunk_start:chunk_end]
        chunk_means2 = all_means2_batch[chunk_start:chunk_end]
        chunk_vars2 = all_vars2_batch[chunk_start:chunk_end]
        chunk_valid = valid_mask[chunk_start:chunk_end]

        # Process chunk, falling back to progressive reduction on GPU OOM
        start_time = time.time()
        try:
            chunk_result = compute_overlap_batch_gpu_mega(
                chunk_means1, chunk_vars1, chunk_means2, chunk_vars2,
                chunk_valid, tol, biomechanical_filter
            )
            chunk_time = time.time() - start_time
        except Exception as e:
            if "OutOfMemoryError" in str(type(e)) or "out of memory" in str(e).lower():
                print(f"   ⚠️ Chunk {chunk_num} still too large, attempting progressive reduction...")
                chunk_result = _process_chunk_with_progressive_reduction(
                    chunk_means1, chunk_vars1, chunk_means2, chunk_vars2,
                    chunk_valid, tol, biomechanical_filter, chunk_num
                )
                chunk_time = time.time() - start_time
            else:
                raise

        results.append(chunk_result)

        # Progress reporting (reserve the final 10% for aggregation below)
        progress = chunk_end / n_task_pairs
        if progress_callback:
            progress_callback(progress * 0.9)

        # Performance metrics: one computation per valid subject per phase pair
        valid_computations = np.sum(chunk_valid) * 150 * 150
        throughput = valid_computations / chunk_time if chunk_time > 0 else 0

        print(f"   ✅ Chunk {chunk_num} complete: {chunk_time:.2f}s, {throughput:,.0f} computations/sec")

        # Memory cleanup so the CuPy pool does not grow across chunks
        if GPU_AVAILABLE:
            cp.get_default_memory_pool().free_all_blocks()

    print("🔧 Combining chunk results...")
    final_result = np.concatenate(results, axis=0)

    if progress_callback:
        progress_callback(1.0)

    print(f"✅ Chunked mega-batch processing complete!")
    print(f"📊 Final result shape: {final_result.shape}")

    return final_result
765
+
766
+
767
def _process_chunk_with_progressive_reduction(chunk_means1, chunk_vars1, chunk_means2, chunk_vars2,
                                              chunk_valid, tol, biomechanical_filter, chunk_num):
    """
    Process a chunk with progressive size reduction if out-of-memory errors occur.

    Tries progressively smaller GPU sub-chunks (50%, 25%, 12.5%, 6.25% of the
    incoming size).  If every GPU attempt fails with OOM, falls back to CPU
    (Numba) processing; as a last resort returns an all-zero result.

    Parameters:
        chunk_means1/chunk_vars1/chunk_means2/chunk_vars2:
            np.ndarray (chunk_size, n_subjects_max, 150, n_features)
        chunk_valid: np.ndarray (chunk_size, n_subjects_max) bool subject mask
        tol: float, variance-validity tolerance (forwarded to the GPU kernel)
        biomechanical_filter: bool, forwarded to the GPU kernel
        chunk_num: int, chunk index used only in log messages

    Returns:
        np.ndarray (chunk_size, n_subjects_max, 150, 150) of overlap values
    """
    chunk_size = chunk_means1.shape[0]

    # Try progressively smaller sub-chunks: 50%, 25%, 12.5%, 6.25%
    reduction_factors = [0.5, 0.25, 0.125, 0.0625]

    for factor in reduction_factors:
        sub_chunk_size = max(1, int(chunk_size * factor))
        print(f"   🔄 Trying sub-chunk size: {sub_chunk_size} pairs ({factor*100:.1f}% of original)")

        try:
            # Process the chunk in sub-chunks
            sub_results = []
            for start_idx in range(0, chunk_size, sub_chunk_size):
                end_idx = min(start_idx + sub_chunk_size, chunk_size)

                sub_result = compute_overlap_batch_gpu_mega(
                    chunk_means1[start_idx:end_idx],
                    chunk_vars1[start_idx:end_idx],
                    chunk_means2[start_idx:end_idx],
                    chunk_vars2[start_idx:end_idx],
                    chunk_valid[start_idx:end_idx],
                    tol, biomechanical_filter
                )
                sub_results.append(sub_result)

                # Clear GPU memory between sub-chunks
                if GPU_AVAILABLE:
                    cp.get_default_memory_pool().free_all_blocks()

            final_result = np.concatenate(sub_results, axis=0)
            print(f"   ✅ Progressive reduction successful with {sub_chunk_size}-pair sub-chunks")
            return final_result

        except Exception as e:
            if "OutOfMemoryError" in str(type(e)) or "out of memory" in str(e).lower():
                print(f"   ❌ Sub-chunk size {sub_chunk_size} still too large")
                continue
            else:
                raise  # non-memory errors are real bugs; preserve traceback

    # All GPU reduction attempts failed; processing one pair at a time on the
    # GPU is slower than CPU, so switch to the Numba implementation instead.
    print(f"   ❌ All reduction attempts failed - chunk too large for GPU mega-batch")
    print(f"   💡 Recommendation: Use smaller time windows or switch to sequential processing")
    print(f"   🔄 Falling back to CPU-based processing for this chunk...")

    try:
        # BUG FIX: this module is used both as a flat script directory and as a
        # package.  The original relative-only import raised ImportError in the
        # flat layout and silently produced all-zero results.  Try the absolute
        # import first, then the package-relative form.
        try:
            from numba_overlap import compute_overlap_batch_numba_ultra_fast
        except ImportError:
            from .numba_overlap import compute_overlap_batch_numba_ultra_fast

        # Process on CPU using Numba (much faster than single GPU pairs)
        cpu_results = []
        for i in range(chunk_size):
            means1_i = chunk_means1[i]  # (n_subjects, 150, n_features)
            vars1_i = chunk_vars1[i]
            means2_i = chunk_means2[i]
            vars2_i = chunk_vars2[i]
            valid_i = chunk_valid[i]    # (n_subjects,)

            # Start from zeros so invalid subjects contribute no overlap.
            full_result = np.zeros((1, chunk_valid.shape[1], 150, 150), dtype=np.float32)
            valid_indices = np.where(valid_i)[0]
            if len(valid_indices) > 0:
                cpu_result = compute_overlap_batch_numba_ultra_fast(
                    means1_i[valid_indices], vars1_i[valid_indices],
                    means2_i[valid_indices], vars2_i[valid_indices]
                )
                full_result[0, valid_indices] = cpu_result
            cpu_results.append(full_result)

        final_result = np.concatenate(cpu_results, axis=0)
        print(f"   ✅ CPU fallback processing completed")
        return final_result

    except ImportError:
        print(f"   ❌ CPU fallback not available - creating zero results")
        # Last resort: return zeros so the caller can still concatenate.
        final_result = np.zeros((chunk_size, chunk_valid.shape[1], 150, 150), dtype=np.float32)
        return final_result
860
+
861
+
862
if __name__ == "__main__":
    # Smoke test for the GPU overlap path; requires CuPy + CUDA.
    print("🧪 Testing GPU overlap calculation...")

    if GPU_AVAILABLE:
        benchmark_gpu_vs_cpu()

        # Test mega-batch functionality
        print("\n🚀 Testing mega-batch functionality...")

        # Create small synthetic test data for multiple task pairs
        n_task_pairs = 5
        n_subjects_max = 3
        n_features = 4

        # Variances are shifted by +0.1 to keep them strictly positive.
        all_means1 = np.random.randn(n_task_pairs, n_subjects_max, 150, n_features).astype(np.float32)
        all_vars1 = np.abs(np.random.randn(n_task_pairs, n_subjects_max, 150, n_features)).astype(np.float32) + 0.1
        all_means2 = np.random.randn(n_task_pairs, n_subjects_max, 150, n_features).astype(np.float32)
        all_vars2 = np.abs(np.random.randn(n_task_pairs, n_subjects_max, 150, n_features)).astype(np.float32) + 0.1
        all_means1 = all_means1  # shapes: (pairs, subjects, 150 phases, features)
        valid_mask = np.ones((n_task_pairs, n_subjects_max), dtype=bool)

        import time
        start = time.time()
        result = compute_overlap_batch_gpu_mega(all_means1, all_vars1, all_means2, all_vars2, valid_mask)
        end = time.time()

        print(f"✅ Mega-batch result shape: {result.shape}")
        print(f"⏱️ Mega-batch time: {end - start:.4f}s")
        print(f"📊 Throughput: {n_task_pairs * n_subjects_max * 150 * 150 / (end - start):,.0f} computations/sec")

    else:
        print("❌ GPU testing requires CuPy and CUDA")
multivariate_gaussian_overlap.py CHANGED
The diff for this file is too large to render. See raw diff
 
numba_overlap.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Ultra-fast Numba JIT-compiled implementation of multivariate Gaussian overlap calculation.
This eliminates all Python overhead and runs at near-C speed.
"""

import numpy as np
try:
    import numba
    from numba import jit, prange
    NUMBA_AVAILABLE = True
except ImportError:
    # Numba is optional: record its absence so compute_overlap_batch() can
    # route to the pure-NumPy fallback implementation defined below.
    NUMBA_AVAILABLE = False
    print("Warning: Numba not installed. Install with: pip install numba")
16
+ if NUMBA_AVAILABLE:
17
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba(means1_batch, vars1_batch, means2_batch, vars2_batch, tol=1e-12):
    """
    Validated Numba kernel: overlap = exp(-0.5 * sum_f (μ1-μ2)² / (σ1²+σ2²)).

    NaN-containing phases are skipped (their overlap stays 0), summed variances
    at or below `tol` invalidate a pair, and quadratic forms with
    0.5*quad > 20 are left at 0 (exp(-20) ≈ 2e-9, effectively underflow).

    Parameters:
        means1_batch: (n_subjects, 150, n_features) array of means for task 1
        vars1_batch:  (n_subjects, 150, n_features) array of variances for task 1
        means2_batch: (n_subjects, 150, n_features) array of means for task 2
        vars2_batch:  (n_subjects, 150, n_features) array of variances for task 2
        tol: Tolerance for variance validity

    Returns:
        overlap_batch: (n_subjects, 150, 150) float64 array of overlap values

    NOTE(review): the phase dimension is hard-coded to 150 below even though
    n_phases is unpacked from the input shape — inputs are assumed to always
    have exactly 150 phases; confirm against callers.
    """
    n_subjects, n_phases, n_features = means1_batch.shape

    # Pre-allocate output
    overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64)

    # CRITICAL OPTIMIZATION: Pre-compute ALL validation outside the main loops.
    # This eliminates millions of redundant NaN checks (each phase is otherwise
    # re-checked 150 times in the pair loop).
    valid_phases1 = np.zeros((n_subjects, 150), dtype=numba.boolean)
    valid_phases2 = np.zeros((n_subjects, 150), dtype=numba.boolean)

    # Pre-compute phase validity for all subjects at once
    for s in prange(n_subjects):
        for i in range(150):
            # Check phase validity once per phase
            valid1 = True
            valid2 = True

            for f in range(n_features):
                if np.isnan(means1_batch[s, i, f]) or np.isnan(vars1_batch[s, i, f]):
                    valid1 = False
                if np.isnan(means2_batch[s, i, f]) or np.isnan(vars2_batch[s, i, f]):
                    valid2 = False

            valid_phases1[s, i] = valid1
            valid_phases2[s, i] = valid2

    # MAIN COMPUTATION: Process only valid phase pairs
    for s in prange(n_subjects):
        for i in range(150):
            if not valid_phases1[s, i]:
                continue

            # Extract data for phase i once (avoid repeated indexing)
            means1_i = means1_batch[s, i]
            vars1_i = vars1_batch[s, i]

            for j in range(150):
                if not valid_phases2[s, j]:
                    continue

                # Extract data for phase j once
                means2_j = means2_batch[s, j]
                vars2_j = vars2_batch[s, j]

                # VECTORIZED CORE COMPUTATION:
                # overlap = exp(-0.5 * sum((μ1 - μ2)² / (σ1² + σ2²)))

                # Step 1: Vector operations (no explicit feature loop)
                diff = means1_i - means2_j  # Vector subtraction
                var_sum = vars1_i + vars2_j  # Vector addition

                # Step 2: Check variance validity; break on first bad feature
                valid_variances = True
                for f in range(n_features):
                    if var_sum[f] <= tol:
                        valid_variances = False
                        break

                if valid_variances:
                    # Step 3: Quadratic form, accumulated with an explicit loop
                    # (Numba compiles this to a tight scalar reduction)
                    quad_terms = diff * diff / var_sum
                    quad_sum = 0.0
                    for f in range(n_features):
                        quad_sum += quad_terms[f]

                    # Step 4: Exponential with underflow protection
                    half_quad = 0.5 * quad_sum
                    if half_quad <= 20.0:
                        overlap_batch[s, i, j] = np.exp(-half_quad)

    return overlap_batch
106
+
107
+
108
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba_ultra_fast(means1_batch, vars1_batch, means2_batch, vars2_batch):
    """
    ULTRA-FAST MODE: Skip ALL validation for clean data.

    Computes overlap = exp(-0.5 * sum_f (μ1 - μ2)² / (σ1² + σ2²)) for every
    phase pair with no NaN checks, no variance-tolerance checks, and no
    underflow guard.  Use ONLY when the data is known to contain no NaNs and
    strictly positive variances — NaNs or zero variances here propagate into
    the output (division by zero / exp of NaN).

    Parameters:
        means1_batch / vars1_batch / means2_batch / vars2_batch:
            (n_subjects, 150, n_features) float arrays

    Returns:
        (n_subjects, 150, 150) float64 array of overlap values
    """
    n_subjects, n_phases, n_features = means1_batch.shape
    overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64)

    # Subjects are independent, so the outer loop parallelizes with prange.
    for s in prange(n_subjects):
        for i in range(150):
            means1_i = means1_batch[s, i]
            vars1_i = vars1_batch[s, i]

            for j in range(150):
                means2_j = means2_batch[s, j]
                vars2_j = vars2_batch[s, j]

                # Pure mathematical computation - no checks, no validation
                diff = means1_i - means2_j
                var_sum = vars1_i + vars2_j
                quad_terms = diff * diff / var_sum

                # Explicit accumulation compiles to a tight scalar loop
                quad_sum = 0.0
                for f in range(n_features):
                    quad_sum += quad_terms[f]

                overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)

    return overlap_batch
142
+
143
+
144
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba_vectorized(means1_batch, vars1_batch, means2_batch, vars2_batch):
    """
    VECTORIZED MODE: Numba kernel using array expressions over the feature axis.

    Numerically identical to the ultra-fast kernel (no validation, no underflow
    guard); the difference is stylistic — per-pair feature math is written as
    small array expressions that Numba lowers to vectorized code.  Intended for
    medium feature counts (the 'auto' dispatcher picks it for 4–9 features).

    Parameters:
        means1_batch / vars1_batch / means2_batch / vars2_batch:
            (n_subjects, 150, n_features) float arrays

    Returns:
        (n_subjects, 150, 150) float64 array of overlap values
    """
    n_subjects, n_phases, n_features = means1_batch.shape
    overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64)

    for s in prange(n_subjects):
        # Process each phase row in turn for better cache behavior
        for i in range(150):
            means1_i = means1_batch[s, i]  # Shape: (n_features,)
            vars1_i = vars1_batch[s, i]

            for j in range(150):
                means2_j = means2_batch[s, j]
                vars2_j = vars2_batch[s, j]

                # Vectorized operations over features
                diff = means1_i - means2_j
                var_sum = vars1_i + vars2_j
                quad_terms = diff * diff / var_sum

                # Fast scalar accumulation over features
                quad_sum = 0.0
                for f in range(n_features):
                    quad_sum += quad_terms[f]

                overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)

    return overlap_batch
180
+
181
+
182
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba_row_vectorized(means1_batch, vars1_batch, means2_batch, vars2_batch):
    """
    ROW-VECTORIZED MODE: fully scalar inner loop over features.

    Numerically identical to the other unvalidated kernels, but avoids the
    per-pair temporary arrays entirely: differences and variance sums are
    computed feature-by-feature as scalars, which keeps the inner loop free of
    allocations.  The 'auto' dispatcher selects this for larger workloads
    (>= 10 features or >= 15 subjects).

    Parameters:
        means1_batch / vars1_batch / means2_batch / vars2_batch:
            (n_subjects, 150, n_features) float arrays

    Returns:
        (n_subjects, 150, 150) float64 array of overlap values
    """
    n_subjects, n_phases, n_features = means1_batch.shape
    overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64)

    for s in prange(n_subjects):
        for i in range(150):
            means1_i = means1_batch[s, i]  # Current phase means (n_features,)
            vars1_i = vars1_batch[s, i]    # Current phase variances

            # Views over all j phases for this subject (no copies)
            means2_all = means2_batch[s]   # All phase means (150, n_features)
            vars2_all = vars2_batch[s]     # All phase variances

            for j in range(150):
                # Scalar accumulation over features: no temporaries allocated
                quad_sum = 0.0
                for f in range(n_features):
                    diff_f = means1_i[f] - means2_all[j, f]
                    var_sum_f = vars1_i[f] + vars2_all[j, f]
                    quad_sum += diff_f * diff_f / var_sum_f

                overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)

    return overlap_batch
214
+
215
+
216
@jit(nopython=True, cache=True)
def apply_biomechanical_filter_numba(overlap_batch, means1_batch, vars1_batch,
                                     means2_batch, vars2_batch, tol=1e-12):
    """
    Apply biomechanical filtering in-place using Numba.

    Modifies `overlap_batch` directly based on the FIRST feature only
    (assumed to be torque — TODO confirm against callers).  Three cases per
    phase pair, classified from the 95% CI (±1.96σ) of each mean:
      1. both "negligible" (CI within ±0.1): overlap forced to 1.0
      2. one negligible, the other "ampable" (|mean| > 0.2): overlap unchanged
      3. otherwise: overlap pulled toward 1 by the probability that the two
         torques share the same sign, combined with a mean-difference ramp.

    Parameters:
        overlap_batch: (n_subjects, 150, 150) array, modified in place
        means1_batch / vars1_batch / means2_batch / vars2_batch:
            (n_subjects, 150, n_features) arrays; only feature index 0 is read
        tol: lower bound applied to std deviations to avoid division by zero

    Returns:
        The same `overlap_batch` array (returned for call-chaining).
    """
    n_subjects = overlap_batch.shape[0]
    negligible_threshold = 0.1   # CI must lie within ±this to count as "no torque"
    ampable_threshold = 0.2      # |mean| above this counts as meaningful torque
    ci_factor = 1.96             # 95% confidence interval half-width in stds

    for s in range(n_subjects):
        # Only the first feature (torque) drives the biomechanical filtering
        for i in range(150):
            mean1 = means1_batch[s, i, 0]
            var1 = vars1_batch[s, i, 0]

            if np.isnan(mean1) or np.isnan(var1):
                continue

            std1 = np.sqrt(var1)
            ci_lo1 = mean1 - ci_factor * std1
            ci_hi1 = mean1 + ci_factor * std1

            negligible1 = (ci_lo1 >= -negligible_threshold) and (ci_hi1 <= negligible_threshold)
            ampable1 = np.abs(mean1) > ampable_threshold

            for j in range(150):
                mean2 = means2_batch[s, j, 0]
                var2 = vars2_batch[s, j, 0]

                if np.isnan(mean2) or np.isnan(var2):
                    continue

                std2 = np.sqrt(var2)
                ci_lo2 = mean2 - ci_factor * std2
                ci_hi2 = mean2 + ci_factor * std2

                negligible2 = (ci_lo2 >= -negligible_threshold) and (ci_hi2 <= negligible_threshold)
                ampable2 = np.abs(mean2) > ampable_threshold

                # Three-level filtering
                if negligible1 and negligible2:
                    # Both negligible - treat as fully similar
                    overlap_batch[s, i, j] = 1.0
                elif (negligible1 and ampable2) or (negligible2 and ampable1):
                    # Amplitude conflict - keep original overlap
                    pass
                else:
                    # Sign-reversal case - apply probability-based filtering
                    std1_safe = max(std1, tol)
                    std2_safe = max(std2, tol)

                    # z-scores of the means (distance of the mean from zero)
                    z1 = mean1 / std1_safe
                    z2 = mean2 / std2_safe

                    # Abramowitz & Stegun 26.2.17 polynomial approximation of
                    # the normal CDF (scipy.stats.norm is unavailable in
                    # nopython mode).  Symmetric via |x| with 1-p for x > 0.
                    # NOTE(review): this closure is defined inside the j-loop;
                    # hoisting it to module level (its own @jit function) would
                    # avoid re-creating it 22 500 times per subject — confirm
                    # Numba does not already optimize this away.
                    def norm_cdf_approx(x):
                        t = 1.0 / (1.0 + 0.2316419 * np.abs(x))
                        d = 0.3989423 * np.exp(-x * x / 2.0)
                        prob = d * t * (0.3193815 + t * (-0.3565638 + t * (1.781478 + t * (-1.821256 + t * 1.330274))))
                        if x > 0:
                            return 1.0 - prob
                        else:
                            return prob

                    Ppos1 = norm_cdf_approx(z1)
                    Ppos2 = norm_cdf_approx(z2)

                    # Probability the two torques have opposite signs
                    Pdiff_sign = Ppos1 * (1.0 - Ppos2) + (1.0 - Ppos1) * Ppos2

                    # Mean-difference penalty: 0 below 0.2, 1 above 0.5,
                    # linear ramp in between (simplified from a sigmoid)
                    mean_diff = np.abs(mean1 - mean2)
                    s_thresh = 0.2
                    e_thresh = 0.5

                    if mean_diff <= s_thresh:
                        penalty = 0.0
                    elif mean_diff >= e_thresh:
                        penalty = 1.0
                    else:
                        penalty = (mean_diff - s_thresh) / (e_thresh - s_thresh)

                    # Apply the stronger of the two penalties to the overlap
                    Pdiff = max(Pdiff_sign, penalty)
                    output_diff = 1.0 - overlap_batch[s, i, j]
                    overlap_batch[s, i, j] = 1.0 - output_diff * Pdiff

    return overlap_batch
313
+
314
+
315
def compute_overlap_batch_fallback(means1_batch, vars1_batch, means2_batch, vars2_batch, tol=1e-12):
    """
    Fallback implementation when Numba is not available.

    For every subject and phase pair (i, j) computes
        overlap = exp(-0.5 * sum_f (μ1[i,f] - μ2[j,f])² / (σ1²[i,f] + σ2²[j,f]))
    Phase pairs involving NaNs, summed variances at or below `tol`, or
    quadratic forms large enough to underflow (0.5*quad > 20, exp ≈ 2e-9)
    are left at 0.

    Parameters:
        means1_batch / vars1_batch: np.ndarray (n_subjects, n_phases, n_features)
        means2_batch / vars2_batch: np.ndarray (n_subjects, n_phases, n_features)
        tol: float, minimum summed variance considered valid

    Returns:
        np.ndarray (n_subjects, n_phases, n_phases) float64 of overlaps in [0, 1]
    """
    n_subjects, n_phases, n_features = means1_batch.shape
    # Generalized: use the actual phase count rather than a hard-coded 150
    # (identical behavior for the standard 150-phase inputs; previously any
    # other phase count raised IndexError or silently produced zeros).
    overlap_batch = np.zeros((n_subjects, n_phases, n_phases), dtype=np.float64)

    for s in range(n_subjects):
        # Pre-compute NaN masks for this subject (one pass per phase axis)
        has_nan1 = np.any(np.isnan(means1_batch[s]) | np.isnan(vars1_batch[s]), axis=1)
        has_nan2 = np.any(np.isnan(means2_batch[s]) | np.isnan(vars2_batch[s]), axis=1)

        for i in range(n_phases):
            if has_nan1[i]:
                continue

            for j in range(n_phases):
                if has_nan2[j]:
                    continue

                # Direct computation of the quadratic form
                diff = means1_batch[s, i] - means2_batch[s, j]
                sum_var = vars1_batch[s, i] + vars2_batch[s, j]

                # Check validity of every summed variance
                if np.all(sum_var > tol):
                    half_quad = 0.5 * np.sum(diff**2 / sum_var)

                    # Underflow guard: exp(-20) is negligibly zero
                    if half_quad <= 20.0:
                        overlap_batch[s, i, j] = np.exp(-half_quad)

    return overlap_batch
349
+
350
+
351
# Main interface function
def compute_overlap_batch(means1_batch, vars1_batch, means2_batch, vars2_batch,
                          tol=1e-12, biomechanical_filter=False, ultra_fast=True,
                          vectorized_mode='auto'):
    """
    Main interface for computing batch overlap with multiple vectorization modes.

    Dispatches to the fastest available Numba kernel (falling back to the
    validated kernel on failure, and to pure NumPy when Numba is missing),
    optionally applies the biomechanical filter, and clips results to [0, 1].

    Parameters:
        means1_batch / vars1_batch / means2_batch / vars2_batch:
            (n_subjects, n_phases, n_features) arrays (n_phases assumed 150
            by the Numba kernels)
        tol: float, variance-validity tolerance for the validated paths
        biomechanical_filter: bool, apply torque-based filtering (Numba path only)
        ultra_fast: bool, retained for interface compatibility (the kernel is
            actually selected via `vectorized_mode`)
        vectorized_mode: str, one of 'auto', 'ultra_fast', 'vectorized',
            'row_vectorized'

    Returns:
        (n_subjects, n_phases, n_phases) array of overlap values in [0, 1]
    """
    if NUMBA_AVAILABLE:
        # Select best vectorization strategy
        if vectorized_mode == 'auto':
            # Auto-select based on data size
            n_subjects, _, n_features = means1_batch.shape
            if n_features >= 10 or n_subjects >= 15:
                mode = 'row_vectorized'  # Best for larger feature sets
            elif n_features >= 4:
                mode = 'vectorized'      # Good for medium feature sets
            else:
                mode = 'ultra_fast'      # Simple and fast for small feature sets
        else:
            mode = vectorized_mode

        try:
            if mode == 'row_vectorized':
                overlap_batch = compute_overlap_batch_numba_row_vectorized(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
            elif mode == 'vectorized':
                overlap_batch = compute_overlap_batch_numba_vectorized(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
            else:  # ultra_fast
                overlap_batch = compute_overlap_batch_numba_ultra_fast(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.  Fall back to the validated
            # (NaN/tolerance-checked) kernel on any runtime failure.
            overlap_batch = compute_overlap_batch_numba(means1_batch, vars1_batch,
                                                        means2_batch, vars2_batch, tol)

        if biomechanical_filter:
            overlap_batch = apply_biomechanical_filter_numba(overlap_batch, means1_batch, vars1_batch,
                                                             means2_batch, vars2_batch, tol)
    else:
        overlap_batch = compute_overlap_batch_fallback(means1_batch, vars1_batch,
                                                       means2_batch, vars2_batch, tol)

    # Final clipping to the valid probability-like range
    np.clip(overlap_batch, 0.0, 1.0, out=overlap_batch)
    return overlap_batch
401
+
402
+
403
if __name__ == "__main__":
    # Smoke test: verifies the dispatcher runs and reports timing.  The first
    # Numba call includes JIT compilation; the second shows steady-state speed.
    print("Testing Numba overlap calculation...")

    # Create test data
    n_subjects = 10
    n_features = 20  # e.g., 10 time windows × 2 sensors

    # Variances are shifted by +0.1 to keep them strictly positive.
    means1 = np.random.randn(n_subjects, 150, n_features)
    vars1 = np.abs(np.random.randn(n_subjects, 150, n_features)) + 0.1
    means2 = np.random.randn(n_subjects, 150, n_features)
    vars2 = np.abs(np.random.randn(n_subjects, 150, n_features)) + 0.1

    # Time the calculation
    import time

    print(f"Numba available: {NUMBA_AVAILABLE}")
    print(f"Computing overlap for {n_subjects} subjects, {n_features} features...")

    start = time.time()
    result = compute_overlap_batch(means1, vars1, means2, vars2)
    end = time.time()

    print(f"Result shape: {result.shape}")
    print(f"Execution time: {end - start:.3f} seconds")
    print(f"Non-zero elements: {np.count_nonzero(result)}")
    print(f"Max value: {np.max(result):.4f}")
    print(f"Min value: {np.min(result):.4f}")

    if NUMBA_AVAILABLE:
        print("\n✅ Numba JIT compilation successful!")
        print("The first run compiles the function, subsequent runs will be much faster.")

        # Run again to show compiled performance
        start = time.time()
        result = compute_overlap_batch(means1, vars1, means2, vars2)
        end = time.time()
        print(f"Compiled execution time: {end - start:.3f} seconds")
    else:
        print("\n⚠️ Numba not available, using fallback implementation.")
        print("Install Numba for 10-100x speedup: pip install numba")
pages/02_Tool.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
7
  import matplotlib.pyplot as plt
8
  import seaborn as sns
9
  from mpl_toolkits.axes_grid1 import make_axes_locatable
10
- from multivariate_gaussian_overlap import calculate_similarity_portrait_abstraction
11
  from plot_similarity import plot_similarity_measure
12
  from sensor_illustration import LegIllustration
13
  from shared_styling import set_plot_style, apply_cream_theme_to_figure
@@ -376,19 +376,32 @@ def calculate_overlap_measures(task1, task2, sensors, abstraction_level,
376
 
377
 
378
  # Get the phase windows from session state or use default
379
- if time_windows is None:
380
  time_windows = [1]
381
-
382
- similarity = calculate_similarity_portrait_abstraction(
383
- sensors=list(sensors),
384
- time_window=time_windows,
385
- abstraction_level=abstraction_level.lower(),
386
- task1_name=task1_spec,
387
- task2_name=task2_spec,
388
- output_difference=use_output_data,
389
- biomechanical_difference=biomechanical_filtering,
390
- progress_callback=lambda x: progress_bar.progress(x)
391
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
  # Clear the progress indicators
394
  progress_placeholder.empty()
@@ -495,7 +508,7 @@ if st.session_state.authenticated:
495
  task2_name_backend = (task2, task2_incline, task2_speed)
496
 
497
  # This will be passed to calculate_overlap_measures, which then passes its .lower() version
498
- # to calculate_similarity_portrait_abstraction.
499
  abstraction_level_backend = analysis_detail
500
 
501
  progress_placeholder = st.empty()
@@ -682,4 +695,4 @@ if st.session_state.authenticated:
682
  st.markdown("--- Marginal Distributions ---")
683
  # Placeholder: Add logic to plot marginal distributions if required.
684
  # This was previously outside the main conditional blocks.
685
- st.write("(Marginal distribution plotting not implemented in this layout yet)")
 
7
  import matplotlib.pyplot as plt
8
  import seaborn as sns
9
  from mpl_toolkits.axes_grid1 import make_axes_locatable
10
+ from multivariate_gaussian_overlap import calculate_portrait
11
  from plot_similarity import plot_similarity_measure
12
  from sensor_illustration import LegIllustration
13
  from shared_styling import set_plot_style, apply_cream_theme_to_figure
 
376
 
377
 
378
  # Get the phase windows from session state or use default
379
+ if not time_windows:
380
  time_windows = [1]
381
+
382
+ # Configure inputs for new calculate_portrait API
383
+ portrait_kwargs = {
384
+ 'task1': task1_spec if any(task1_spec) else None,
385
+ 'task2': task2_spec if any(task2_spec) else None,
386
+ 'time_window': time_windows,
387
+ 'progress_callback': lambda x: progress_bar.progress(x),
388
+ 'method': 'auto'
389
+ }
390
+
391
+ if use_output_data:
392
+ portrait_kwargs.update({
393
+ 'sensors': None,
394
+ 'output_sensors': list(sensors),
395
+ 'biomechanical_filter': biomechanical_filtering
396
+ })
397
+ else:
398
+ portrait_kwargs.update({
399
+ 'sensors': list(sensors),
400
+ 'output_sensors': None,
401
+ 'biomechanical_filter': False
402
+ })
403
+
404
+ similarity = calculate_portrait(**portrait_kwargs)
405
 
406
  # Clear the progress indicators
407
  progress_placeholder.empty()
 
508
  task2_name_backend = (task2, task2_incline, task2_speed)
509
 
510
  # This will be passed to calculate_overlap_measures, which then passes its .lower() version
511
+ # to calculate_portrait.
512
  abstraction_level_backend = analysis_detail
513
 
514
  progress_placeholder = st.empty()
 
695
  st.markdown("--- Marginal Distributions ---")
696
  # Placeholder: Add logic to plot marginal distributions if required.
697
  # This was previously outside the main conditional blocks.
698
+ st.write("(Marginal distribution plotting not implemented in this layout yet)")
plot_similarity.py CHANGED
@@ -17,7 +17,6 @@ def plot_similarity_measure(measure_data: np.ndarray, ax: plt.Axes = None,
17
  fontsize: int = 16,
18
  y_label_pad: int = 20,
19
  cbar_labels: bool = True,
20
- cutoff_treshold:float=None,
21
  high_level_plot: bool = False):
22
  """Plot similarity measure with consistent styling.
23
 
@@ -110,36 +109,13 @@ def plot_similarity_measure(measure_data: np.ndarray, ax: plt.Axes = None,
110
  else:
111
  cbar_obj.set_ticks([])
112
 
113
- # Implement cutoff threshold annotation
114
- if cutoff_treshold is not None:
115
-
116
- # The actual cutoff will depend on the plot type
117
- if plot_type == 'input':
118
- true_cutoff = cutoff_treshold
119
- elif plot_type == 'output':
120
- true_cutoff = 1 - cutoff_treshold
121
- elif plot_type == 'output_biomechanical':
122
- true_cutoff = 1 - cutoff_treshold
123
- else: # conflict
124
- true_cutoff = cutoff_treshold * (1 - cutoff_treshold)
125
-
126
-
127
- # Count percent of values above threshold
128
- if high_level_plot:
129
- # Subtract diagonal from thresholded values
130
- mask = np.ones_like(measure_data, dtype=bool)
131
- np.fill_diagonal(mask, False)
132
- total = np.sum(mask)
133
- above = np.sum((measure_data > true_cutoff) & mask)
134
- else:
135
- total = measure_data.size
136
- above = np.sum(measure_data > true_cutoff)
137
- percent = 100.0 * above / total if total > 0 else 0.0
138
  # Format as e.g. "12.3% > 0.8"
139
- annotation = f"{percent:.1f}% > {true_cutoff:g}"
140
  # Place annotation to the right of the colorbar
141
  cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
142
- fontsize=fontsize, rotation=90, transform=cbar_obj.ax.transAxes)
143
 
144
  # Set aspect ratio to equal
145
  ax.set_aspect('equal')
 
17
  fontsize: int = 16,
18
  y_label_pad: int = 20,
19
  cbar_labels: bool = True,
 
20
  high_level_plot: bool = False):
21
  """Plot similarity measure with consistent styling.
22
 
 
109
  else:
110
  cbar_obj.set_ticks([])
111
 
112
+ if plot_type != 'output':
113
+ percent = np.mean(measure_data) * 100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  # Format as e.g. "12.3% > 0.8"
115
+ annotation = r"$\tilde C_{total} = $" + f"{percent:.1f}%"
116
  # Place annotation to the right of the colorbar
117
  cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
118
+ fontsize=fontsize, rotation=90, transform=cbar_obj.ax.transAxes)
119
 
120
  # Set aspect ratio to equal
121
  ax.set_aspect('equal')
plot_styling.py CHANGED
@@ -1,13 +1,59 @@
1
  """
2
  Unified styling module for both Streamlit UI and matplotlib plots.
3
  Contains all styling definitions to ensure consistency across the application.
 
 
 
 
 
4
  """
5
 
6
- import streamlit as st
 
7
  import matplotlib.pyplot as plt
8
  import seaborn as sns
9
  import matplotlib.font_manager as fm
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # ==========================
12
  # Shared Color Themes
13
  # ==========================
@@ -46,6 +92,40 @@ LIGHT_COLORS = {
46
  'spine_color': '#E0E0E0',
47
  }
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Dark theme colors - consistent across UI and plots
50
  DARK_COLORS = {
51
  'background': '#1E1E1E',
@@ -80,8 +160,12 @@ DARK_COLORS = {
80
  'spine_color': '#505050',
81
  }
82
 
 
83
  def get_current_colors():
84
  """Get the current color scheme based on session state."""
 
 
 
85
  if 'dark_theme' not in st.session_state:
86
  st.session_state.dark_theme = False
87
 
@@ -89,12 +173,18 @@ def get_current_colors():
89
 
90
  def toggle_theme():
91
  """Toggle between light and dark themes."""
 
 
 
92
  if 'dark_theme' not in st.session_state:
93
  st.session_state.dark_theme = False
94
  st.session_state.dark_theme = not st.session_state.dark_theme
95
 
96
  def add_theme_toggle():
97
  """Add a theme toggle button to the sidebar."""
 
 
 
98
  with st.sidebar:
99
  st.markdown("---")
100
  current_theme = "🌙 Dark" if st.session_state.get('dark_theme', False) else "☀️ Light"
@@ -141,9 +231,18 @@ purple_helix = sns.cubehelix_palette(start=.2, rot=-.4, dark=0, light=0.85,
141
  my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
142
  reverse=True, as_cmap=True)
143
 
144
- def get_plot_style():
145
- """Get plot style with current theme colors."""
146
- theme_colors = get_current_colors()
 
 
 
 
 
 
 
 
 
147
 
148
  return {
149
  'font_family': PLOT_STYLE_FONT_FAMILY,
@@ -168,9 +267,13 @@ def get_plot_style():
168
  'spine_color': theme_colors['spine_color'],
169
  }
170
 
171
- def set_plot_style():
172
- """Set consistent plot styling across all figures"""
173
- plot_style = get_plot_style()
 
 
 
 
174
 
175
  plt.rcParams['font.family'] = plot_style['font_family']
176
  plt.rcParams['font.size'] = plot_style['font_size']
@@ -233,6 +336,10 @@ def apply_theme_to_figure(fig, ax=None):
233
 
234
  return fig, ax
235
 
 
 
 
 
236
  # Legacy function name for backward compatibility
237
  def apply_cream_theme_to_figure(fig, ax=None):
238
  """Apply current theme to an existing figure and axes (legacy function name)"""
@@ -461,19 +568,27 @@ def get_tool_page_css():
461
 
462
  def apply_base_styling():
463
  """Apply the base styling to the current Streamlit page."""
 
 
464
  st.markdown(get_base_css(), unsafe_allow_html=True)
465
 
466
  def apply_home_page_styling():
467
  """Apply styling specific to the home page."""
 
 
468
  st.markdown(get_base_css(), unsafe_allow_html=True)
469
  st.markdown(get_home_page_css(), unsafe_allow_html=True)
470
 
471
  def apply_documentation_page_styling():
472
  """Apply styling specific to the documentation page."""
 
 
473
  st.markdown(get_base_css(), unsafe_allow_html=True)
474
  st.markdown(get_documentation_page_css(), unsafe_allow_html=True)
475
 
476
  def apply_tool_page_styling():
477
  """Apply styling specific to the analysis tool page."""
 
 
478
  st.markdown(get_base_css(), unsafe_allow_html=True)
479
  st.markdown(get_tool_page_css(), unsafe_allow_html=True)
 
1
  """
2
  Unified styling module for both Streamlit UI and matplotlib plots.
3
  Contains all styling definitions to ensure consistency across the application.
4
+
5
+ Note: When used outside of Streamlit environment (e.g., in Jupyter notebooks),
6
+ you may see warnings about missing ScriptRunContext or Session state. These
7
+ warnings are harmless and can be safely ignored - the core plotting functions
8
+ (get_plot_style, set_plot_style, PLOT_COLORS) work correctly regardless.
9
  """
10
 
11
+ import warnings
12
+ import logging
13
  import matplotlib.pyplot as plt
14
  import seaborn as sns
15
  import matplotlib.font_manager as fm
16
 
17
+ # Suppress Streamlit warnings when running outside streamlit environment
18
+ warnings.filterwarnings('ignore', category=UserWarning, module='streamlit')
19
+ warnings.filterwarnings('ignore', message='.*ScriptRunContext.*')
20
+ warnings.filterwarnings('ignore', message='.*Session state.*')
21
+ warnings.filterwarnings('ignore', message='.*missing ScriptRunContext.*')
22
+ warnings.filterwarnings('ignore', message='.*does not function when running.*')
23
+ warnings.filterwarnings('ignore', module='streamlit.runtime.*')
24
+ warnings.filterwarnings('ignore', module='streamlit.runtime.scriptrunner_utils.*')
25
+ warnings.filterwarnings('ignore', module='streamlit.runtime.state.*')
26
+
27
+ # Suppress Streamlit loggers that generate warnings outside streamlit environment
28
+ logging.getLogger('streamlit.runtime.scriptrunner_utils.script_run_context').setLevel(logging.ERROR)
29
+ logging.getLogger('streamlit.runtime.state.session_state_proxy').setLevel(logging.ERROR)
30
+ logging.getLogger('streamlit').setLevel(logging.ERROR)
31
+
32
+ try:
33
+ # Set logging level before importing to suppress initial warnings
34
+ for logger_name in ['streamlit', 'streamlit.runtime', 'streamlit.runtime.scriptrunner_utils',
35
+ 'streamlit.runtime.state', 'streamlit.runtime.scriptrunner_utils.script_run_context',
36
+ 'streamlit.runtime.state.session_state_proxy']:
37
+ logging.getLogger(logger_name).setLevel(logging.ERROR)
38
+
39
+ import streamlit as st
40
+ _STREAMLIT_AVAILABLE = True
41
+ except ImportError:
42
+ _STREAMLIT_AVAILABLE = False
43
+ # Create a mock streamlit module for non-streamlit environments
44
+ class MockStreamlit:
45
+ class session_state:
46
+ dark_theme = False
47
+ st = MockStreamlit()
48
+
49
+ def _suppress_streamlit_warnings(func):
50
+ """Decorator to suppress streamlit warnings in functions."""
51
+ def wrapper(*args, **kwargs):
52
+ with warnings.catch_warnings():
53
+ warnings.simplefilter('ignore')
54
+ return func(*args, **kwargs)
55
+ return wrapper
56
+
57
  # ==========================
58
  # Shared Color Themes
59
  # ==========================
 
92
  'spine_color': '#E0E0E0',
93
  }
94
 
95
+ # Paper theme colors - pure white backgrounds for publication
96
+ PAPER_COLORS = {
97
+ 'background': '#FFFFFF',
98
+ 'figure_background': '#FFFFFF',
99
+ 'sidebar_bg_start': '#FAFAFA',
100
+ 'sidebar_bg_end': '#F5F5F5',
101
+ 'border_light': '#F0F0F0',
102
+ 'border_medium': '#E0E0E0',
103
+ 'text_primary': '#2C3E50',
104
+ 'text_secondary': '#5D6D7E',
105
+ 'text_tertiary': '#85929E',
106
+ 'text_light': '#A6ACAF',
107
+ 'button_bg_start': '#5D6D7E',
108
+ 'button_bg_end': '#85929E',
109
+ 'button_hover_start': '#85929E',
110
+ 'button_hover_end': '#A6ACAF',
111
+ 'alert_error_bg': '#FFE6E6',
112
+ 'alert_error_border': '#FFAAAA',
113
+ 'alert_error_text': '#CC0000',
114
+ 'alert_info_bg': '#E6F3FF',
115
+ 'alert_info_border': '#99D6FF',
116
+ 'alert_info_text': '#0066CC',
117
+ 'warning_bg': '#FFF8E1',
118
+ 'warning_border': '#FF9800',
119
+ 'success_bg': '#E8F5E8',
120
+ 'success_border': '#4CAF50',
121
+ 'generate_button_bg': '#228B22',
122
+ 'generate_button_hover': '#32CD32',
123
+ # Plot-specific colors - pure white for papers
124
+ 'axes_background': '#FFFFFF',
125
+ 'grid_color': '#F0F0F0',
126
+ 'spine_color': '#E0E0E0',
127
+ }
128
+
129
  # Dark theme colors - consistent across UI and plots
130
  DARK_COLORS = {
131
  'background': '#1E1E1E',
 
160
  'spine_color': '#505050',
161
  }
162
 
163
+ @_suppress_streamlit_warnings
164
  def get_current_colors():
165
  """Get the current color scheme based on session state."""
166
+ if not _STREAMLIT_AVAILABLE:
167
+ return LIGHT_COLORS
168
+
169
  if 'dark_theme' not in st.session_state:
170
  st.session_state.dark_theme = False
171
 
 
173
 
174
  def toggle_theme():
175
  """Toggle between light and dark themes."""
176
+ if not _STREAMLIT_AVAILABLE:
177
+ return
178
+
179
  if 'dark_theme' not in st.session_state:
180
  st.session_state.dark_theme = False
181
  st.session_state.dark_theme = not st.session_state.dark_theme
182
 
183
  def add_theme_toggle():
184
  """Add a theme toggle button to the sidebar."""
185
+ if not _STREAMLIT_AVAILABLE:
186
+ return
187
+
188
  with st.sidebar:
189
  st.markdown("---")
190
  current_theme = "🌙 Dark" if st.session_state.get('dark_theme', False) else "☀️ Light"
 
231
  my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
232
  reverse=True, as_cmap=True)
233
 
234
+ def get_plot_style(style='default'):
235
+ """Get plot style with specified color theme.
236
+
237
+ Args:
238
+ style: 'default' for cream theme, 'paper' for pure white backgrounds, 'dark' for dark theme
239
+ """
240
+ if style == 'paper':
241
+ theme_colors = PAPER_COLORS
242
+ elif style == 'dark':
243
+ theme_colors = DARK_COLORS
244
+ else: # default
245
+ theme_colors = get_current_colors()
246
 
247
  return {
248
  'font_family': PLOT_STYLE_FONT_FAMILY,
 
267
  'spine_color': theme_colors['spine_color'],
268
  }
269
 
270
+ def set_plot_style(style='default'):
271
+ """Set consistent plot styling across all figures.
272
+
273
+ Args:
274
+ style: 'default' for cream theme, 'paper' for pure white backgrounds, 'dark' for dark theme
275
+ """
276
+ plot_style = get_plot_style(style=style)
277
 
278
  plt.rcParams['font.family'] = plot_style['font_family']
279
  plt.rcParams['font.size'] = plot_style['font_size']
 
336
 
337
  return fig, ax
338
 
339
+ def set_paper_plot_style():
340
+ """Convenience function to set pure white backgrounds for paper publication."""
341
+ set_plot_style(style='paper')
342
+
343
  # Legacy function name for backward compatibility
344
  def apply_cream_theme_to_figure(fig, ax=None):
345
  """Apply current theme to an existing figure and axes (legacy function name)"""
 
568
 
569
  def apply_base_styling():
570
  """Apply the base styling to the current Streamlit page."""
571
+ if not _STREAMLIT_AVAILABLE:
572
+ return
573
  st.markdown(get_base_css(), unsafe_allow_html=True)
574
 
575
  def apply_home_page_styling():
576
  """Apply styling specific to the home page."""
577
+ if not _STREAMLIT_AVAILABLE:
578
+ return
579
  st.markdown(get_base_css(), unsafe_allow_html=True)
580
  st.markdown(get_home_page_css(), unsafe_allow_html=True)
581
 
582
  def apply_documentation_page_styling():
583
  """Apply styling specific to the documentation page."""
584
+ if not _STREAMLIT_AVAILABLE:
585
+ return
586
  st.markdown(get_base_css(), unsafe_allow_html=True)
587
  st.markdown(get_documentation_page_css(), unsafe_allow_html=True)
588
 
589
  def apply_tool_page_styling():
590
  """Apply styling specific to the analysis tool page."""
591
+ if not _STREAMLIT_AVAILABLE:
592
+ return
593
  st.markdown(get_base_css(), unsafe_allow_html=True)
594
  st.markdown(get_tool_page_css(), unsafe_allow_html=True)
requirements.txt CHANGED
@@ -3,4 +3,5 @@ numpy>=1.21.0
3
  matplotlib>=3.4.0
4
  seaborn>=0.11.0
5
  pandas>=1.3.0
6
- scipy>=1.7.0
 
 
3
  matplotlib>=3.4.0
4
  seaborn>=0.11.0
5
  pandas>=1.3.0
6
+ scipy>=1.7.0
7
+ tqdm>=4.64.0
shared_styling.py CHANGED
@@ -1,13 +1,59 @@
1
  """
2
  Unified styling module for both Streamlit UI and matplotlib plots.
3
  Contains all styling definitions to ensure consistency across the application.
 
 
 
 
 
4
  """
5
 
6
- import streamlit as st
 
7
  import matplotlib.pyplot as plt
8
  import seaborn as sns
9
  import matplotlib.font_manager as fm
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # ==========================
12
  # Shared Color Themes
13
  # ==========================
@@ -46,6 +92,40 @@ LIGHT_COLORS = {
46
  'spine_color': '#E0E0E0',
47
  }
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Dark theme colors - consistent across UI and plots
50
  DARK_COLORS = {
51
  'background': '#1E1E1E',
@@ -80,8 +160,12 @@ DARK_COLORS = {
80
  'spine_color': '#505050',
81
  }
82
 
 
83
  def get_current_colors():
84
  """Get the current color scheme based on session state."""
 
 
 
85
  if 'dark_theme' not in st.session_state:
86
  st.session_state.dark_theme = False
87
 
@@ -89,12 +173,18 @@ def get_current_colors():
89
 
90
  def toggle_theme():
91
  """Toggle between light and dark themes."""
 
 
 
92
  if 'dark_theme' not in st.session_state:
93
  st.session_state.dark_theme = False
94
  st.session_state.dark_theme = not st.session_state.dark_theme
95
 
96
  def add_theme_toggle():
97
  """Add a theme toggle button to the sidebar."""
 
 
 
98
  with st.sidebar:
99
  st.markdown("---")
100
  current_theme = "🌙 Dark" if st.session_state.get('dark_theme', False) else "☀️ Light"
@@ -141,9 +231,18 @@ purple_helix = sns.cubehelix_palette(start=.2, rot=-.4, dark=0, light=0.85,
141
  my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
142
  reverse=True, as_cmap=True)
143
 
144
- def get_plot_style():
145
- """Get plot style with current theme colors."""
146
- theme_colors = get_current_colors()
 
 
 
 
 
 
 
 
 
147
 
148
  return {
149
  'font_family': PLOT_STYLE_FONT_FAMILY,
@@ -168,9 +267,13 @@ def get_plot_style():
168
  'spine_color': theme_colors['spine_color'],
169
  }
170
 
171
- def set_plot_style():
172
- """Set consistent plot styling across all figures"""
173
- plot_style = get_plot_style()
 
 
 
 
174
 
175
  plt.rcParams['font.family'] = plot_style['font_family']
176
  plt.rcParams['font.size'] = plot_style['font_size']
@@ -233,6 +336,10 @@ def apply_theme_to_figure(fig, ax=None):
233
 
234
  return fig, ax
235
 
 
 
 
 
236
  # Legacy function name for backward compatibility
237
  def apply_cream_theme_to_figure(fig, ax=None):
238
  """Apply current theme to an existing figure and axes (legacy function name)"""
@@ -461,19 +568,27 @@ def get_tool_page_css():
461
 
462
  def apply_base_styling():
463
  """Apply the base styling to the current Streamlit page."""
 
 
464
  st.markdown(get_base_css(), unsafe_allow_html=True)
465
 
466
  def apply_home_page_styling():
467
  """Apply styling specific to the home page."""
 
 
468
  st.markdown(get_base_css(), unsafe_allow_html=True)
469
  st.markdown(get_home_page_css(), unsafe_allow_html=True)
470
 
471
  def apply_documentation_page_styling():
472
  """Apply styling specific to the documentation page."""
 
 
473
  st.markdown(get_base_css(), unsafe_allow_html=True)
474
  st.markdown(get_documentation_page_css(), unsafe_allow_html=True)
475
 
476
  def apply_tool_page_styling():
477
  """Apply styling specific to the analysis tool page."""
 
 
478
  st.markdown(get_base_css(), unsafe_allow_html=True)
479
  st.markdown(get_tool_page_css(), unsafe_allow_html=True)
 
1
  """
2
  Unified styling module for both Streamlit UI and matplotlib plots.
3
  Contains all styling definitions to ensure consistency across the application.
4
+
5
+ Note: When used outside of Streamlit environment (e.g., in Jupyter notebooks),
6
+ you may see warnings about missing ScriptRunContext or Session state. These
7
+ warnings are harmless and can be safely ignored - the core plotting functions
8
+ (get_plot_style, set_plot_style, PLOT_COLORS) work correctly regardless.
9
  """
10
 
11
+ import warnings
12
+ import logging
13
  import matplotlib.pyplot as plt
14
  import seaborn as sns
15
  import matplotlib.font_manager as fm
16
 
17
+ # Suppress Streamlit warnings when running outside streamlit environment
18
+ warnings.filterwarnings('ignore', category=UserWarning, module='streamlit')
19
+ warnings.filterwarnings('ignore', message='.*ScriptRunContext.*')
20
+ warnings.filterwarnings('ignore', message='.*Session state.*')
21
+ warnings.filterwarnings('ignore', message='.*missing ScriptRunContext.*')
22
+ warnings.filterwarnings('ignore', message='.*does not function when running.*')
23
+ warnings.filterwarnings('ignore', module='streamlit.runtime.*')
24
+ warnings.filterwarnings('ignore', module='streamlit.runtime.scriptrunner_utils.*')
25
+ warnings.filterwarnings('ignore', module='streamlit.runtime.state.*')
26
+
27
+ # Suppress Streamlit loggers that generate warnings outside streamlit environment
28
+ logging.getLogger('streamlit.runtime.scriptrunner_utils.script_run_context').setLevel(logging.ERROR)
29
+ logging.getLogger('streamlit.runtime.state.session_state_proxy').setLevel(logging.ERROR)
30
+ logging.getLogger('streamlit').setLevel(logging.ERROR)
31
+
32
+ try:
33
+ # Set logging level before importing to suppress initial warnings
34
+ for logger_name in ['streamlit', 'streamlit.runtime', 'streamlit.runtime.scriptrunner_utils',
35
+ 'streamlit.runtime.state', 'streamlit.runtime.scriptrunner_utils.script_run_context',
36
+ 'streamlit.runtime.state.session_state_proxy']:
37
+ logging.getLogger(logger_name).setLevel(logging.ERROR)
38
+
39
+ import streamlit as st
40
+ _STREAMLIT_AVAILABLE = True
41
+ except ImportError:
42
+ _STREAMLIT_AVAILABLE = False
43
+ # Create a mock streamlit module for non-streamlit environments
44
+ class MockStreamlit:
45
+ class session_state:
46
+ dark_theme = False
47
+ st = MockStreamlit()
48
+
49
+ def _suppress_streamlit_warnings(func):
50
+ """Decorator to suppress streamlit warnings in functions."""
51
+ def wrapper(*args, **kwargs):
52
+ with warnings.catch_warnings():
53
+ warnings.simplefilter('ignore')
54
+ return func(*args, **kwargs)
55
+ return wrapper
56
+
57
  # ==========================
58
  # Shared Color Themes
59
  # ==========================
 
92
  'spine_color': '#E0E0E0',
93
  }
94
 
95
+ # Paper theme colors - pure white backgrounds for publication
96
+ PAPER_COLORS = {
97
+ 'background': '#FFFFFF',
98
+ 'figure_background': '#FFFFFF',
99
+ 'sidebar_bg_start': '#FAFAFA',
100
+ 'sidebar_bg_end': '#F5F5F5',
101
+ 'border_light': '#F0F0F0',
102
+ 'border_medium': '#E0E0E0',
103
+ 'text_primary': '#2C3E50',
104
+ 'text_secondary': '#5D6D7E',
105
+ 'text_tertiary': '#85929E',
106
+ 'text_light': '#A6ACAF',
107
+ 'button_bg_start': '#5D6D7E',
108
+ 'button_bg_end': '#85929E',
109
+ 'button_hover_start': '#85929E',
110
+ 'button_hover_end': '#A6ACAF',
111
+ 'alert_error_bg': '#FFE6E6',
112
+ 'alert_error_border': '#FFAAAA',
113
+ 'alert_error_text': '#CC0000',
114
+ 'alert_info_bg': '#E6F3FF',
115
+ 'alert_info_border': '#99D6FF',
116
+ 'alert_info_text': '#0066CC',
117
+ 'warning_bg': '#FFF8E1',
118
+ 'warning_border': '#FF9800',
119
+ 'success_bg': '#E8F5E8',
120
+ 'success_border': '#4CAF50',
121
+ 'generate_button_bg': '#228B22',
122
+ 'generate_button_hover': '#32CD32',
123
+ # Plot-specific colors - pure white for papers
124
+ 'axes_background': '#FFFFFF',
125
+ 'grid_color': '#F0F0F0',
126
+ 'spine_color': '#E0E0E0',
127
+ }
128
+
129
  # Dark theme colors - consistent across UI and plots
130
  DARK_COLORS = {
131
  'background': '#1E1E1E',
 
160
  'spine_color': '#505050',
161
  }
162
 
163
+ @_suppress_streamlit_warnings
164
  def get_current_colors():
165
  """Get the current color scheme based on session state."""
166
+ if not _STREAMLIT_AVAILABLE:
167
+ return LIGHT_COLORS
168
+
169
  if 'dark_theme' not in st.session_state:
170
  st.session_state.dark_theme = False
171
 
 
173
 
174
  def toggle_theme():
175
  """Toggle between light and dark themes."""
176
+ if not _STREAMLIT_AVAILABLE:
177
+ return
178
+
179
  if 'dark_theme' not in st.session_state:
180
  st.session_state.dark_theme = False
181
  st.session_state.dark_theme = not st.session_state.dark_theme
182
 
183
  def add_theme_toggle():
184
  """Add a theme toggle button to the sidebar."""
185
+ if not _STREAMLIT_AVAILABLE:
186
+ return
187
+
188
  with st.sidebar:
189
  st.markdown("---")
190
  current_theme = "🌙 Dark" if st.session_state.get('dark_theme', False) else "☀️ Light"
 
231
  my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
232
  reverse=True, as_cmap=True)
233
 
234
+ def get_plot_style(style='default'):
235
+ """Get plot style with specified color theme.
236
+
237
+ Args:
238
+ style: 'default' for cream theme, 'paper' for pure white backgrounds, 'dark' for dark theme
239
+ """
240
+ if style == 'paper':
241
+ theme_colors = PAPER_COLORS
242
+ elif style == 'dark':
243
+ theme_colors = DARK_COLORS
244
+ else: # default
245
+ theme_colors = get_current_colors()
246
 
247
  return {
248
  'font_family': PLOT_STYLE_FONT_FAMILY,
 
267
  'spine_color': theme_colors['spine_color'],
268
  }
269
 
270
+ def set_plot_style(style='default'):
271
+ """Set consistent plot styling across all figures.
272
+
273
+ Args:
274
+ style: 'default' for cream theme, 'paper' for pure white backgrounds, 'dark' for dark theme
275
+ """
276
+ plot_style = get_plot_style(style=style)
277
 
278
  plt.rcParams['font.family'] = plot_style['font_family']
279
  plt.rcParams['font.size'] = plot_style['font_size']
 
336
 
337
  return fig, ax
338
 
339
+ def set_paper_plot_style():
340
+ """Convenience function to set pure white backgrounds for paper publication."""
341
+ set_plot_style(style='paper')
342
+
343
  # Legacy function name for backward compatibility
344
  def apply_cream_theme_to_figure(fig, ax=None):
345
  """Apply current theme to an existing figure and axes (legacy function name)"""
 
568
 
569
  def apply_base_styling():
570
  """Apply the base styling to the current Streamlit page."""
571
+ if not _STREAMLIT_AVAILABLE:
572
+ return
573
  st.markdown(get_base_css(), unsafe_allow_html=True)
574
 
575
  def apply_home_page_styling():
576
  """Apply styling specific to the home page."""
577
+ if not _STREAMLIT_AVAILABLE:
578
+ return
579
  st.markdown(get_base_css(), unsafe_allow_html=True)
580
  st.markdown(get_home_page_css(), unsafe_allow_html=True)
581
 
582
  def apply_documentation_page_styling():
583
  """Apply styling specific to the documentation page."""
584
+ if not _STREAMLIT_AVAILABLE:
585
+ return
586
  st.markdown(get_base_css(), unsafe_allow_html=True)
587
  st.markdown(get_documentation_page_css(), unsafe_allow_html=True)
588
 
589
  def apply_tool_page_styling():
590
  """Apply styling specific to the analysis tool page."""
591
+ if not _STREAMLIT_AVAILABLE:
592
+ return
593
  st.markdown(get_base_css(), unsafe_allow_html=True)
594
  st.markdown(get_tool_page_css(), unsafe_allow_html=True)
st_logs/dashboard_access.log ADDED
File without changes