Spaces:
Runtime error
Runtime error
Updated to new data and multivaraite lib api
Browse files- __pycache__/shared_styling.cpython-312.pyc +0 -0
- cached_data/precalculated_stats.pkl.gz +2 -2
- config.py +58 -15
- gpu_overlap.py +892 -0
- multivariate_gaussian_overlap.py +0 -0
- numba_overlap.py +443 -0
- pages/02_Tool.py +28 -15
- plot_similarity.py +4 -28
- plot_styling.py +122 -7
- requirements.txt +2 -1
- shared_styling.py +122 -7
- st_logs/dashboard_access.log +0 -0
__pycache__/shared_styling.cpython-312.pyc
DELETED
|
Binary file (32.5 kB)
|
|
|
cached_data/precalculated_stats.pkl.gz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be3355088c6913cdb552802f2e3aa139abed2a3a79441895d31d08ff5f9cd23a
|
| 3 |
+
size 25443689
|
config.py
CHANGED
|
@@ -1,18 +1,22 @@
|
|
| 1 |
# Configuration file for dashboard and preprocessing scripts
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
AVAILABLE_SENSORS = [
|
|
|
|
| 4 |
'hip_angle_s_r', 'hip_angle_s_l', 'hip_vel_s_r', 'hip_vel_s_l',
|
| 5 |
'knee_angle_s_r', 'knee_angle_s_l', 'knee_vel_s_r', 'knee_vel_s_l',
|
| 6 |
'ankle_angle_s_r', 'ankle_angle_s_l', 'ankle_vel_s_r', 'ankle_vel_s_l',
|
| 7 |
'foot_angle_s_r', 'foot_angle_s_l', 'foot_vel_s_r', 'foot_vel_s_l',
|
| 8 |
-
#
|
| 9 |
-
'pelvis_angle_s_r', 'pelvis_angle_s_l',
|
| 10 |
-
|
| 11 |
-
'
|
| 12 |
-
'
|
| 13 |
-
|
| 14 |
-
'
|
| 15 |
-
'ankle_angle_t_r', 'ankle_angle_t_l', 'ankle_vel_t_r', 'ankle_vel_t_l'
|
| 16 |
]
|
| 17 |
|
| 18 |
AVAILABLE_TASKS = ['decline_walking', 'level_walking', 'incline_walking',
|
|
@@ -26,24 +30,63 @@ ANALYSIS_ABSTRACTION_LEVELS = ['High', 'Medium/Low']
|
|
| 26 |
|
| 27 |
# Task configurations for pre-calculation and analysis
|
| 28 |
LOW_LEVEL_TASKS = [
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
('
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
('level_walking', 0.0, 0.8),
|
| 33 |
('level_walking', 0.0, 1.0),
|
| 34 |
('level_walking', 0.0, 1.2),
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
('level_walking', 0.0, 0.6),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
('level_walking', 0.0, 1.8),
|
|
|
|
|
|
|
|
|
|
| 38 |
('incline_walking', 5.0, 0.8),
|
| 39 |
('incline_walking', 5.0, 1.0),
|
| 40 |
('incline_walking', 5.0, 1.2),
|
| 41 |
-
('decline_walking', -5.0, 0.8),
|
| 42 |
-
('decline_walking', -5.0, 1.0),
|
| 43 |
-
('decline_walking', -5.0, 1.2),
|
| 44 |
('incline_walking', 10.0, 0.8),
|
| 45 |
('incline_walking', 10.0, 1.0),
|
| 46 |
('incline_walking', 10.0, 1.2),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
('decline_walking', -10.0, 0.8),
|
| 48 |
('decline_walking', -10.0, 1.0),
|
| 49 |
('decline_walking', -10.0, 1.2),
|
|
|
|
| 1 |
# Configuration file for dashboard and preprocessing scripts
|
| 2 |
|
| 3 |
+
# Data source control
|
| 4 |
+
USE_FILTERED_DATA = True # If True: use *_filtered.parquet files (standard speeds only)
|
| 5 |
+
# If False: use original .parquet files (all speeds, more data)
|
| 6 |
+
|
| 7 |
AVAILABLE_SENSORS = [
|
| 8 |
+
# Core sagittal plane sensors (available in both UMich and GTech)
|
| 9 |
'hip_angle_s_r', 'hip_angle_s_l', 'hip_vel_s_r', 'hip_vel_s_l',
|
| 10 |
'knee_angle_s_r', 'knee_angle_s_l', 'knee_vel_s_r', 'knee_vel_s_l',
|
| 11 |
'ankle_angle_s_r', 'ankle_angle_s_l', 'ankle_vel_s_r', 'ankle_vel_s_l',
|
| 12 |
'foot_angle_s_r', 'foot_angle_s_l', 'foot_vel_s_r', 'foot_vel_s_l',
|
| 13 |
+
# Pelvis and trunk angles (available in both datasets)
|
| 14 |
+
'pelvis_angle_s_r', 'pelvis_angle_s_l',
|
| 15 |
+
# Calculated segment sensors (computed by loader.py)
|
| 16 |
+
'shank_angle_s_r', 'shank_angle_s_l', 'shank_vel_s_r', 'shank_vel_s_l',
|
| 17 |
+
'thigh_angle_s_r', 'thigh_angle_s_l', 'thigh_vel_s_r', 'thigh_vel_s_l',
|
| 18 |
+
# Ground reaction forces (keeping only vertical right GRF)
|
| 19 |
+
'grf_y_r',
|
|
|
|
| 20 |
]
|
| 21 |
|
| 22 |
AVAILABLE_TASKS = ['decline_walking', 'level_walking', 'incline_walking',
|
|
|
|
| 30 |
|
| 31 |
# Task configurations for pre-calculation and analysis
|
| 32 |
LOW_LEVEL_TASKS = [
|
| 33 |
+
# Stair tasks - using incline field for height in mm, speed is None (self-selected)
|
| 34 |
+
# Original GTech stairs (102mm)
|
| 35 |
+
('stair_descent', 102.0, None),
|
| 36 |
+
('stair_ascent', 102.0, None),
|
| 37 |
+
|
| 38 |
+
# Additional stair heights (new datasets)
|
| 39 |
+
('stair_descent', 127.0, None),
|
| 40 |
+
('stair_ascent', 127.0, None),
|
| 41 |
+
('stair_descent', 152.0, None),
|
| 42 |
+
('stair_ascent', 152.0, None),
|
| 43 |
+
('stair_descent', 178.0, None),
|
| 44 |
+
('stair_ascent', 178.0, None),
|
| 45 |
+
|
| 46 |
+
# Level walking - UMich steady-state speeds
|
| 47 |
('level_walking', 0.0, 0.8),
|
| 48 |
('level_walking', 0.0, 1.0),
|
| 49 |
('level_walking', 0.0, 1.2),
|
| 50 |
+
|
| 51 |
+
# Level walking - All GTech speeds
|
| 52 |
+
('level_walking', 0.0, 0.5),
|
| 53 |
+
('level_walking', 0.0, 0.55),
|
| 54 |
('level_walking', 0.0, 0.6),
|
| 55 |
+
('level_walking', 0.0, 0.65),
|
| 56 |
+
('level_walking', 0.0, 0.7),
|
| 57 |
+
('level_walking', 0.0, 0.75),
|
| 58 |
+
('level_walking', 0.0, 0.85),
|
| 59 |
+
('level_walking', 0.0, 0.9),
|
| 60 |
+
('level_walking', 0.0, 0.95),
|
| 61 |
+
('level_walking', 0.0, 1.05),
|
| 62 |
+
('level_walking', 0.0, 1.1),
|
| 63 |
+
('level_walking', 0.0, 1.15),
|
| 64 |
+
('level_walking', 0.0, 1.25),
|
| 65 |
+
('level_walking', 0.0, 1.3),
|
| 66 |
+
('level_walking', 0.0, 1.35),
|
| 67 |
+
('level_walking', 0.0, 1.4),
|
| 68 |
+
('level_walking', 0.0, 1.45),
|
| 69 |
+
('level_walking', 0.0, 1.5),
|
| 70 |
+
('level_walking', 0.0, 1.55),
|
| 71 |
+
('level_walking', 0.0, 1.6),
|
| 72 |
+
('level_walking', 0.0, 1.65),
|
| 73 |
+
('level_walking', 0.0, 1.7),
|
| 74 |
+
('level_walking', 0.0, 1.75),
|
| 75 |
('level_walking', 0.0, 1.8),
|
| 76 |
+
('level_walking', 0.0, 1.85),
|
| 77 |
+
|
| 78 |
+
# Incline walking - UMich speeds
|
| 79 |
('incline_walking', 5.0, 0.8),
|
| 80 |
('incline_walking', 5.0, 1.0),
|
| 81 |
('incline_walking', 5.0, 1.2),
|
|
|
|
|
|
|
|
|
|
| 82 |
('incline_walking', 10.0, 0.8),
|
| 83 |
('incline_walking', 10.0, 1.0),
|
| 84 |
('incline_walking', 10.0, 1.2),
|
| 85 |
+
|
| 86 |
+
# Decline walking - UMich speeds
|
| 87 |
+
('decline_walking', -5.0, 0.8),
|
| 88 |
+
('decline_walking', -5.0, 1.0),
|
| 89 |
+
('decline_walking', -5.0, 1.2),
|
| 90 |
('decline_walking', -10.0, 0.8),
|
| 91 |
('decline_walking', -10.0, 1.0),
|
| 92 |
('decline_walking', -10.0, 1.2),
|
gpu_overlap.py
ADDED
|
@@ -0,0 +1,892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GPU-accelerated implementation of multivariate Gaussian overlap calculation using CuPy.
|
| 4 |
+
This provides massive speedup for large-scale analyses by processing all task pairs simultaneously.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import warnings
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
# Try to import CuPy for GPU acceleration
|
| 12 |
+
try:
|
| 13 |
+
import cupy as cp
|
| 14 |
+
GPU_AVAILABLE = True
|
| 15 |
+
print("✅ CuPy GPU acceleration available")
|
| 16 |
+
except ImportError:
|
| 17 |
+
GPU_AVAILABLE = False
|
| 18 |
+
cp = None
|
| 19 |
+
print("⚠️ CuPy not available. Install with: pip install cupy-cuda12x")
|
| 20 |
+
|
| 21 |
+
# Check for CUDA availability
|
| 22 |
+
if GPU_AVAILABLE:
|
| 23 |
+
try:
|
| 24 |
+
# Test if CUDA is actually available
|
| 25 |
+
device = cp.cuda.Device(0)
|
| 26 |
+
GPU_READY = True
|
| 27 |
+
print(f"🚀 GPU ready: Device {device.id} (RTX detected)")
|
| 28 |
+
except:
|
| 29 |
+
GPU_READY = False
|
| 30 |
+
GPU_AVAILABLE = False
|
| 31 |
+
print("⚠️ CUDA not available, disabling GPU acceleration")
|
| 32 |
+
else:
|
| 33 |
+
GPU_READY = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def compute_overlap_batch_gpu(means1_batch, vars1_batch, means2_batch, vars2_batch,
|
| 37 |
+
tol=1e-12, biomechanical_filter=False):
|
| 38 |
+
"""
|
| 39 |
+
GPU-accelerated batch overlap computation using CuPy.
|
| 40 |
+
|
| 41 |
+
Processes all subjects simultaneously with full GPU vectorization.
|
| 42 |
+
This is the "throw everything in" approach for maximum GPU utilization.
|
| 43 |
+
|
| 44 |
+
Parameters:
|
| 45 |
+
means1_batch: np.ndarray shape (n_subjects, 150, n_features)
|
| 46 |
+
vars1_batch: np.ndarray shape (n_subjects, 150, n_features)
|
| 47 |
+
means2_batch: np.ndarray shape (n_subjects, 150, n_features)
|
| 48 |
+
vars2_batch: np.ndarray shape (n_subjects, 150, n_features)
|
| 49 |
+
tol: float, tolerance for variance validity
|
| 50 |
+
biomechanical_filter: bool, apply biomechanical filtering
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
np.ndarray shape (n_subjects, 150, 150) - overlap values
|
| 54 |
+
"""
|
| 55 |
+
if not GPU_AVAILABLE:
|
| 56 |
+
raise RuntimeError("CuPy not available for GPU computation")
|
| 57 |
+
|
| 58 |
+
n_subjects, n_phases, n_features = means1_batch.shape
|
| 59 |
+
|
| 60 |
+
# Transfer to GPU - single transfer for all data
|
| 61 |
+
means1_gpu = cp.asarray(means1_batch, dtype=cp.float32)
|
| 62 |
+
vars1_gpu = cp.asarray(vars1_batch, dtype=cp.float32)
|
| 63 |
+
means2_gpu = cp.asarray(means2_batch, dtype=cp.float32)
|
| 64 |
+
vars2_gpu = cp.asarray(vars2_batch, dtype=cp.float32)
|
| 65 |
+
|
| 66 |
+
# Pre-allocate output on GPU
|
| 67 |
+
overlap_batch_gpu = cp.zeros((n_subjects, 150, 150), dtype=cp.float32)
|
| 68 |
+
|
| 69 |
+
# CRITICAL OPTIMIZATION: Use broadcasting to compute ALL phase pairs at once
|
| 70 |
+
# Shape transformations for broadcasting:
|
| 71 |
+
# means1: (n_subjects, 150, 1, n_features) - for phase_i
|
| 72 |
+
# means2: (n_subjects, 1, 150, n_features) - for phase_j
|
| 73 |
+
# Result: (n_subjects, 150, 150, n_features) - all pairs
|
| 74 |
+
|
| 75 |
+
means1_exp = means1_gpu[:, :, cp.newaxis, :] # (n_subjects, 150, 1, n_features)
|
| 76 |
+
vars1_exp = vars1_gpu[:, :, cp.newaxis, :]
|
| 77 |
+
means2_exp = means2_gpu[:, cp.newaxis, :, :] # (n_subjects, 1, 150, n_features)
|
| 78 |
+
vars2_exp = vars2_gpu[:, cp.newaxis, :, :]
|
| 79 |
+
|
| 80 |
+
# Compute all differences and variance sums simultaneously
|
| 81 |
+
diff = means1_exp - means2_exp # Shape: (n_subjects, 150, 150, n_features)
|
| 82 |
+
var_sum = vars1_exp + vars2_exp # Shape: (n_subjects, 150, 150, n_features)
|
| 83 |
+
|
| 84 |
+
# NaN handling: Create validity mask
|
| 85 |
+
valid_mask = (~cp.isnan(diff).any(axis=3) &
|
| 86 |
+
~cp.isnan(var_sum).any(axis=3) &
|
| 87 |
+
(var_sum > tol).all(axis=3)) # Shape: (n_subjects, 150, 150)
|
| 88 |
+
|
| 89 |
+
# Compute quadratic form for valid entries only
|
| 90 |
+
# Use where to avoid division by zero
|
| 91 |
+
quad_terms = cp.where(valid_mask[:, :, :, cp.newaxis],
|
| 92 |
+
diff * diff / var_sum,
|
| 93 |
+
0.0) # Shape: (n_subjects, 150, 150, n_features)
|
| 94 |
+
|
| 95 |
+
# Sum over features
|
| 96 |
+
quad_sum = cp.sum(quad_terms, axis=3) # Shape: (n_subjects, 150, 150)
|
| 97 |
+
|
| 98 |
+
# Apply exponential with underflow protection
|
| 99 |
+
# Only compute exp for valid entries with reasonable values
|
| 100 |
+
safe_exp_mask = valid_mask & (quad_sum * 0.5 <= 20.0)
|
| 101 |
+
|
| 102 |
+
overlap_batch_gpu = cp.where(safe_exp_mask,
|
| 103 |
+
cp.exp(-0.5 * quad_sum),
|
| 104 |
+
0.0)
|
| 105 |
+
|
| 106 |
+
# Apply biomechanical filtering if requested
|
| 107 |
+
if biomechanical_filter:
|
| 108 |
+
overlap_batch_gpu = _apply_biomechanical_filter_gpu(
|
| 109 |
+
overlap_batch_gpu, means1_gpu, vars1_gpu, means2_gpu, vars2_gpu, tol
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Transfer back to CPU - single transfer
|
| 113 |
+
result = cp.asnumpy(overlap_batch_gpu).astype(np.float64)
|
| 114 |
+
|
| 115 |
+
# Final clipping on CPU
|
| 116 |
+
np.clip(result, 0.0, 1.0, out=result)
|
| 117 |
+
return result
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _apply_biomechanical_filter_gpu(overlap_batch, means1_batch, vars1_batch,
|
| 121 |
+
means2_batch, vars2_batch, tol):
|
| 122 |
+
"""Apply biomechanical filtering on GPU using vectorized operations."""
|
| 123 |
+
n_subjects = overlap_batch.shape[0]
|
| 124 |
+
negligible_threshold = 0.1
|
| 125 |
+
ampable_threshold = 0.2
|
| 126 |
+
ci_factor = 1.96
|
| 127 |
+
|
| 128 |
+
# Only process first feature (torque) for biomechanical filtering
|
| 129 |
+
means1_torque = means1_batch[:, :, 0] # Shape: (n_subjects, 150)
|
| 130 |
+
means2_torque = means2_batch[:, :, 0]
|
| 131 |
+
vars1_torque = vars1_batch[:, :, 0]
|
| 132 |
+
vars2_torque = vars2_batch[:, :, 0]
|
| 133 |
+
|
| 134 |
+
# Vectorized std and CI calculations
|
| 135 |
+
std1 = cp.sqrt(vars1_torque)
|
| 136 |
+
std2 = cp.sqrt(vars2_torque)
|
| 137 |
+
|
| 138 |
+
ci_lo1 = means1_torque - ci_factor * std1
|
| 139 |
+
ci_hi1 = means1_torque + ci_factor * std1
|
| 140 |
+
ci_lo2 = means2_torque - ci_factor * std2
|
| 141 |
+
ci_hi2 = means2_torque + ci_factor * std2
|
| 142 |
+
|
| 143 |
+
# Vectorized mask computation
|
| 144 |
+
negligible1 = ((ci_lo1 >= -negligible_threshold) &
|
| 145 |
+
(ci_hi1 <= negligible_threshold)) # Shape: (n_subjects, 150)
|
| 146 |
+
negligible2 = ((ci_lo2 >= -negligible_threshold) &
|
| 147 |
+
(ci_hi2 <= negligible_threshold))
|
| 148 |
+
ampable1 = cp.abs(means1_torque) > ampable_threshold
|
| 149 |
+
ampable2 = cp.abs(means2_torque) > ampable_threshold
|
| 150 |
+
|
| 151 |
+
# Broadcast to phase pair dimensions using newaxis
|
| 152 |
+
neg1_exp = negligible1[:, :, cp.newaxis] # (n_subjects, 150, 1)
|
| 153 |
+
neg2_exp = negligible2[:, cp.newaxis, :] # (n_subjects, 1, 150)
|
| 154 |
+
amp1_exp = ampable1[:, :, cp.newaxis]
|
| 155 |
+
amp2_exp = ampable2[:, cp.newaxis, :]
|
| 156 |
+
|
| 157 |
+
# Three-level filtering masks
|
| 158 |
+
# Negligible-negligible: Both torques are negligible
|
| 159 |
+
m0 = neg1_exp & neg2_exp # Shape: (n_subjects, 150, 150)
|
| 160 |
+
|
| 161 |
+
# Amplitude conflicts: One negligible, other ampable
|
| 162 |
+
m1 = ((neg1_exp & amp2_exp) | (neg2_exp & amp1_exp))
|
| 163 |
+
|
| 164 |
+
# Sign reversal cases: Neither negligible-negligible nor amplitude conflict
|
| 165 |
+
m2 = ~(m0 | m1)
|
| 166 |
+
|
| 167 |
+
# Apply negligible-negligible rule (set to 1.0)
|
| 168 |
+
overlap_batch = cp.where(m0, 1.0, overlap_batch)
|
| 169 |
+
|
| 170 |
+
# Apply sign reversal filtering for m2 cases
|
| 171 |
+
if cp.any(m2):
|
| 172 |
+
# Get indices where filtering is needed
|
| 173 |
+
s_idx, i_idx, j_idx = cp.where(m2)
|
| 174 |
+
|
| 175 |
+
if len(s_idx) > 0:
|
| 176 |
+
# Vectorized probability calculations
|
| 177 |
+
std1_safe = cp.maximum(std1, tol)
|
| 178 |
+
std2_safe = cp.maximum(std2, tol)
|
| 179 |
+
z1 = means1_torque / std1_safe
|
| 180 |
+
z2 = means2_torque / std2_safe
|
| 181 |
+
|
| 182 |
+
# Normal CDF approximation (vectorized)
|
| 183 |
+
def norm_cdf_gpu(x):
|
| 184 |
+
# Abramowitz and Stegun approximation
|
| 185 |
+
t = 1.0 / (1.0 + 0.2316419 * cp.abs(x))
|
| 186 |
+
d = 0.3989423 * cp.exp(-x * x / 2.0)
|
| 187 |
+
prob = d * t * (0.3193815 + t * (-0.3565638 +
|
| 188 |
+
t * (1.781478 + t * (-1.821256 + t * 1.330274))))
|
| 189 |
+
return cp.where(x > 0, 1.0 - prob, prob)
|
| 190 |
+
|
| 191 |
+
Ppos1 = norm_cdf_gpu(z1)
|
| 192 |
+
Ppos2 = norm_cdf_gpu(z2)
|
| 193 |
+
|
| 194 |
+
# Sign-mismatch probability for selected indices
|
| 195 |
+
Pdiff_sign = (Ppos1[s_idx, i_idx] * (1.0 - Ppos2[s_idx, j_idx]) +
|
| 196 |
+
(1.0 - Ppos1[s_idx, i_idx]) * Ppos2[s_idx, j_idx])
|
| 197 |
+
|
| 198 |
+
# Mean-difference penalty (vectorized ramp function)
|
| 199 |
+
mean_diff = cp.abs(means1_torque[s_idx, i_idx] - means2_torque[s_idx, j_idx])
|
| 200 |
+
s_thresh, e_thresh = 0.2, 0.5
|
| 201 |
+
|
| 202 |
+
# Linear ramp penalty
|
| 203 |
+
penalty = cp.clip((mean_diff - s_thresh) / (e_thresh - s_thresh), 0.0, 1.0)
|
| 204 |
+
|
| 205 |
+
# Combine penalties
|
| 206 |
+
Pdiff = cp.maximum(Pdiff_sign, penalty)
|
| 207 |
+
|
| 208 |
+
# Apply penalty to overlaps
|
| 209 |
+
current_overlaps = overlap_batch[s_idx, i_idx, j_idx]
|
| 210 |
+
output_diff = 1.0 - current_overlaps
|
| 211 |
+
scaled_output_diff = output_diff * Pdiff
|
| 212 |
+
overlap_batch[s_idx, i_idx, j_idx] = 1.0 - scaled_output_diff
|
| 213 |
+
|
| 214 |
+
return overlap_batch
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def compute_overlap_batch_gpu_chunked(means1_batch, vars1_batch, means2_batch, vars2_batch,
|
| 218 |
+
chunk_size=None, **kwargs):
|
| 219 |
+
"""
|
| 220 |
+
Chunked GPU processing for very large datasets that don't fit in GPU memory.
|
| 221 |
+
|
| 222 |
+
Automatically determines optimal chunk size based on available GPU memory.
|
| 223 |
+
"""
|
| 224 |
+
if not GPU_AVAILABLE:
|
| 225 |
+
raise RuntimeError("CuPy not available for GPU computation")
|
| 226 |
+
|
| 227 |
+
n_subjects = means1_batch.shape[0]
|
| 228 |
+
|
| 229 |
+
if chunk_size is None:
|
| 230 |
+
# Estimate chunk size based on GPU memory
|
| 231 |
+
mempool = cp.get_default_memory_pool()
|
| 232 |
+
available_memory = mempool.free_bytes()
|
| 233 |
+
|
| 234 |
+
# Rough estimate: each subject needs ~150*150*4 bytes for overlap + input arrays
|
| 235 |
+
bytes_per_subject = 150 * 150 * 4 * 6 # 6 arrays (means1, vars1, means2, vars2, overlap, temp)
|
| 236 |
+
estimated_chunk_size = max(1, int(available_memory * 0.8 // bytes_per_subject))
|
| 237 |
+
chunk_size = min(estimated_chunk_size, n_subjects)
|
| 238 |
+
|
| 239 |
+
print(f"🔧 Auto-determined GPU chunk size: {chunk_size} subjects")
|
| 240 |
+
|
| 241 |
+
if chunk_size >= n_subjects:
|
| 242 |
+
# Process all at once
|
| 243 |
+
return compute_overlap_batch_gpu(means1_batch, vars1_batch,
|
| 244 |
+
means2_batch, vars2_batch, **kwargs)
|
| 245 |
+
|
| 246 |
+
# Process in chunks
|
| 247 |
+
results = []
|
| 248 |
+
for i in range(0, n_subjects, chunk_size):
|
| 249 |
+
end_idx = min(i + chunk_size, n_subjects)
|
| 250 |
+
|
| 251 |
+
chunk_result = compute_overlap_batch_gpu(
|
| 252 |
+
means1_batch[i:end_idx],
|
| 253 |
+
vars1_batch[i:end_idx],
|
| 254 |
+
means2_batch[i:end_idx],
|
| 255 |
+
vars2_batch[i:end_idx],
|
| 256 |
+
**kwargs
|
| 257 |
+
)
|
| 258 |
+
results.append(chunk_result)
|
| 259 |
+
|
| 260 |
+
return np.concatenate(results, axis=0)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def benchmark_gpu_vs_cpu():
|
| 264 |
+
"""Benchmark GPU vs CPU performance on sample data."""
|
| 265 |
+
if not GPU_AVAILABLE:
|
| 266 |
+
print("GPU not available for benchmarking")
|
| 267 |
+
return
|
| 268 |
+
|
| 269 |
+
import time
|
| 270 |
+
|
| 271 |
+
# Create test data
|
| 272 |
+
n_subjects = 10
|
| 273 |
+
n_features = 4
|
| 274 |
+
|
| 275 |
+
print(f"🔧 Benchmarking with {n_subjects} subjects, {n_features} features...")
|
| 276 |
+
|
| 277 |
+
means1 = np.random.randn(n_subjects, 150, n_features).astype(np.float32)
|
| 278 |
+
vars1 = np.abs(np.random.randn(n_subjects, 150, n_features)).astype(np.float32) + 0.1
|
| 279 |
+
means2 = np.random.randn(n_subjects, 150, n_features).astype(np.float32)
|
| 280 |
+
vars2 = np.abs(np.random.randn(n_subjects, 150, n_features)).astype(np.float32) + 0.1
|
| 281 |
+
|
| 282 |
+
# Warm up GPU
|
| 283 |
+
if GPU_AVAILABLE:
|
| 284 |
+
_ = compute_overlap_batch_gpu(means1[:2], vars1[:2], means2[:2], vars2[:2])
|
| 285 |
+
|
| 286 |
+
# Benchmark GPU
|
| 287 |
+
if GPU_AVAILABLE:
|
| 288 |
+
start = time.time()
|
| 289 |
+
result_gpu = compute_overlap_batch_gpu(means1, vars1, means2, vars2)
|
| 290 |
+
gpu_time = time.time() - start
|
| 291 |
+
print(f"🚀 GPU time: {gpu_time:.4f} seconds")
|
| 292 |
+
else:
|
| 293 |
+
result_gpu = None
|
| 294 |
+
gpu_time = float('inf')
|
| 295 |
+
|
| 296 |
+
# Benchmark CPU (Numba fallback)
|
| 297 |
+
try:
|
| 298 |
+
from .numba_overlap import compute_overlap_batch
|
| 299 |
+
start = time.time()
|
| 300 |
+
result_cpu = compute_overlap_batch(means1, vars1, means2, vars2)
|
| 301 |
+
cpu_time = time.time() - start
|
| 302 |
+
print(f"🔧 CPU time: {cpu_time:.4f} seconds")
|
| 303 |
+
|
| 304 |
+
if GPU_AVAILABLE and result_gpu is not None:
|
| 305 |
+
speedup = cpu_time / gpu_time
|
| 306 |
+
print(f"📈 GPU Speedup: {speedup:.1f}x")
|
| 307 |
+
|
| 308 |
+
# Check accuracy
|
| 309 |
+
max_diff = np.max(np.abs(result_gpu.astype(np.float64) - result_cpu))
|
| 310 |
+
print(f"🎯 Max difference: {max_diff:.2e}")
|
| 311 |
+
|
| 312 |
+
except ImportError:
|
| 313 |
+
print("❌ Numba not available for CPU comparison")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def compute_overlap_batch_gpu_mega(all_means1_batch, all_vars1_batch, all_means2_batch, all_vars2_batch,
|
| 317 |
+
valid_mask, tol=1e-12, biomechanical_filter=False):
|
| 318 |
+
"""
|
| 319 |
+
MEGA-BATCH GPU computation: Process ALL task pairs simultaneously.
|
| 320 |
+
|
| 321 |
+
This is the ultimate "throw everything in" approach for maximum GPU utilization.
|
| 322 |
+
Processes hundreds of task pairs × subjects × phase pairs in a single GPU call.
|
| 323 |
+
|
| 324 |
+
Parameters:
|
| 325 |
+
all_means1_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
|
| 326 |
+
all_vars1_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
|
| 327 |
+
all_means2_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
|
| 328 |
+
all_vars2_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
|
| 329 |
+
valid_mask: np.ndarray shape (n_task_pairs, n_subjects_max) - bool mask for valid subjects
|
| 330 |
+
tol: float, tolerance for variance validity
|
| 331 |
+
biomechanical_filter: bool, apply biomechanical filtering
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
np.ndarray shape (n_task_pairs, n_subjects_max, 150, 150) - overlap values
|
| 335 |
+
"""
|
| 336 |
+
if not GPU_AVAILABLE:
|
| 337 |
+
raise RuntimeError("CuPy not available for mega-batch GPU computation")
|
| 338 |
+
|
| 339 |
+
n_task_pairs, n_subjects_max, n_phases, n_features = all_means1_batch.shape
|
| 340 |
+
|
| 341 |
+
print(f"🚀 GPU Mega-batch: Processing {n_task_pairs} task pairs × {n_subjects_max} subjects × {150*150} phase pairs")
|
| 342 |
+
print(f"📊 Total computations: {n_task_pairs * n_subjects_max * 150 * 150:,}")
|
| 343 |
+
|
| 344 |
+
# Transfer ALL data to GPU in single transfer
|
| 345 |
+
means1_gpu = cp.asarray(all_means1_batch, dtype=cp.float32)
|
| 346 |
+
vars1_gpu = cp.asarray(all_vars1_batch, dtype=cp.float32)
|
| 347 |
+
means2_gpu = cp.asarray(all_means2_batch, dtype=cp.float32)
|
| 348 |
+
vars2_gpu = cp.asarray(all_vars2_batch, dtype=cp.float32)
|
| 349 |
+
valid_gpu = cp.asarray(valid_mask, dtype=cp.bool_)
|
| 350 |
+
|
| 351 |
+
# Pre-allocate output on GPU
|
| 352 |
+
overlap_batch_gpu = cp.zeros((n_task_pairs, n_subjects_max, 150, 150), dtype=cp.float32)
|
| 353 |
+
|
| 354 |
+
# MEGA BROADCASTING: Process ALL task pairs and subjects simultaneously
|
| 355 |
+
# Shape transformations for 5D broadcasting:
|
| 356 |
+
# (n_task_pairs, n_subjects_max, 150, 1, n_features) vs (n_task_pairs, n_subjects_max, 1, 150, n_features)
|
| 357 |
+
|
| 358 |
+
means1_exp = means1_gpu[:, :, :, cp.newaxis, :] # Add phase_j dimension
|
| 359 |
+
vars1_exp = vars1_gpu[:, :, :, cp.newaxis, :]
|
| 360 |
+
means2_exp = means2_gpu[:, :, cp.newaxis, :, :] # Add phase_i dimension
|
| 361 |
+
vars2_exp = vars2_gpu[:, :, cp.newaxis, :, :]
|
| 362 |
+
|
| 363 |
+
# Compute ALL differences and variance sums simultaneously
|
| 364 |
+
# Shape: (n_task_pairs, n_subjects_max, 150, 150, n_features)
|
| 365 |
+
diff = means1_exp - means2_exp
|
| 366 |
+
var_sum = vars1_exp + vars2_exp
|
| 367 |
+
|
| 368 |
+
# Create mega validity mask
|
| 369 |
+
# Shape: (n_task_pairs, n_subjects_max, 150, 150)
|
| 370 |
+
subject_valid = valid_gpu[:, :, cp.newaxis, cp.newaxis] # Broadcast to all phase pairs
|
| 371 |
+
|
| 372 |
+
# NaN and variance validity for ALL data simultaneously
|
| 373 |
+
nan_valid = (~cp.isnan(diff).any(axis=4) &
|
| 374 |
+
~cp.isnan(var_sum).any(axis=4) &
|
| 375 |
+
(var_sum > tol).all(axis=4))
|
| 376 |
+
|
| 377 |
+
# Combined validity mask
|
| 378 |
+
full_valid_mask = subject_valid & nan_valid
|
| 379 |
+
|
| 380 |
+
# Compute quadratic form for ALL valid entries
|
| 381 |
+
quad_terms = cp.where(full_valid_mask[:, :, :, :, cp.newaxis],
|
| 382 |
+
diff * diff / var_sum,
|
| 383 |
+
0.0)
|
| 384 |
+
|
| 385 |
+
# Sum over features for ALL task pairs simultaneously
|
| 386 |
+
quad_sum = cp.sum(quad_terms, axis=4) # Shape: (n_task_pairs, n_subjects_max, 150, 150)
|
| 387 |
+
|
| 388 |
+
# Apply exponential with underflow protection
|
| 389 |
+
safe_exp_mask = full_valid_mask & (quad_sum * 0.5 <= 20.0)
|
| 390 |
+
overlap_batch_gpu = cp.where(safe_exp_mask,
|
| 391 |
+
cp.exp(-0.5 * quad_sum),
|
| 392 |
+
0.0)
|
| 393 |
+
|
| 394 |
+
# Apply biomechanical filtering if requested
|
| 395 |
+
if biomechanical_filter:
|
| 396 |
+
overlap_batch_gpu = _apply_biomechanical_filter_gpu_mega(
|
| 397 |
+
overlap_batch_gpu, means1_gpu, vars1_gpu, means2_gpu, vars2_gpu, valid_gpu, tol
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
# Transfer back to CPU - single transfer for ALL results
|
| 401 |
+
print("📥 Transferring results from GPU...")
|
| 402 |
+
result = cp.asnumpy(overlap_batch_gpu).astype(np.float64)
|
| 403 |
+
|
| 404 |
+
# Final clipping
|
| 405 |
+
np.clip(result, 0.0, 1.0, out=result)
|
| 406 |
+
|
| 407 |
+
print(f"✅ Mega-batch GPU computation complete!")
|
| 408 |
+
return result
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def _apply_biomechanical_filter_gpu_mega(overlap_batch, means1_batch, vars1_batch,
                                         means2_batch, vars2_batch, valid_mask, tol):
    """Apply biomechanical filtering to a mega-batch of overlaps on the GPU.

    Mutates and returns ``overlap_batch`` (shape (n_task_pairs, n_subjects_max,
    150, 150)). Filtering is driven entirely by the first feature, which is
    assumed to be torque — TODO confirm feature ordering with the caller.
    ``tol`` is accepted for interface parity but not used here.
    """
    # Filtering thresholds (torque units assumed — verify against data scale).
    negligible_limit = 0.1
    amplitude_limit = 0.2
    z_95 = 1.96  # 95% confidence interval multiplier

    # Torque channel only: shape (n_task_pairs, n_subjects_max, 150).
    mu1 = means1_batch[:, :, :, 0]
    mu2 = means2_batch[:, :, :, 0]
    sd1 = cp.sqrt(vars1_batch[:, :, :, 0])
    sd2 = cp.sqrt(vars2_batch[:, :, :, 0])

    # 95% confidence bands per phase, for all task pairs at once.
    lo1 = mu1 - z_95 * sd1
    hi1 = mu1 + z_95 * sd1
    lo2 = mu2 - z_95 * sd2
    hi2 = mu2 + z_95 * sd2

    # "Negligible": the whole CI sits inside the dead band around zero.
    # "Ampable": mean magnitude exceeds the amplitude threshold.
    neg1 = (lo1 >= -negligible_limit) & (hi1 <= negligible_limit)
    neg2 = (lo2 >= -negligible_limit) & (hi2 <= negligible_limit)
    amp1 = cp.abs(mu1) > amplitude_limit
    amp2 = cp.abs(mu2) > amplitude_limit

    # Broadcast task-1 masks over columns, task-2 masks over rows.
    neg1_b = neg1[:, :, :, cp.newaxis]
    amp1_b = amp1[:, :, :, cp.newaxis]
    neg2_b = neg2[:, :, cp.newaxis, :]
    amp2_b = amp2[:, :, cp.newaxis, :]
    subj_b = valid_mask[:, :, cp.newaxis, cp.newaxis]

    # Three-level classification over valid subject entries.
    both_negligible = (neg1_b & neg2_b) & subj_b
    amplitude_conflict = ((neg1_b & amp2_b) | (neg2_b & amp1_b)) & subj_b
    sign_reversal = ~(both_negligible | amplitude_conflict) & subj_b

    # Rule: two negligible phases are fully interchangeable → overlap 1.0.
    overlap_batch = cp.where(both_negligible, 1.0, overlap_batch)

    # Sign-reversal cases: penalize with a linear ramp on the mean torque gap.
    # (The full probability model is too expensive at mega-batch scale.)
    if cp.any(sign_reversal):
        t_idx, s_idx, i_idx, j_idx = cp.where(sign_reversal)
        if len(t_idx) > 0:
            gap = cp.abs(mu1[t_idx, s_idx, i_idx] - mu2[t_idx, s_idx, j_idx])
            ramp_lo, ramp_hi = 0.2, 0.5
            weight = cp.clip((gap - ramp_lo) / (ramp_hi - ramp_lo), 0.0, 1.0)
            base = overlap_batch[t_idx, s_idx, i_idx, j_idx]
            # Scale the distance from perfect overlap by the penalty weight.
            overlap_batch[t_idx, s_idx, i_idx, j_idx] = 1.0 - (1.0 - base) * weight

    return overlap_batch
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def estimate_mega_batch_memory(n_task_pairs, n_subjects_max, n_features):
    """
    Estimate GPU memory requirements (in GB) for mega-batch processing.

    CRITICAL: accounts for the 5-D broadcasting performed during the GPU
    computation — inputs of shape (n_task_pairs, n_subjects_max, 150,
    n_features) expand to (n_task_pairs, n_subjects_max, 150, 150,
    n_features). The 150x150 expansion dominates for large feature counts.
    """
    float32_bytes = 4
    phases = 150

    # Four float32 input arrays: means1, vars1, means2, vars2.
    input_bytes = 4 * n_task_pairs * n_subjects_max * phases * n_features * float32_bytes

    # One output array of shape (pairs, subjects, 150, 150).
    output_bytes = n_task_pairs * n_subjects_max * phases * phases * float32_bytes

    # 5-D broadcast temporaries (diff, var_sum, quad_terms, ...): assume
    # four of them coexist — a conservative estimate.
    one_5d_bytes = n_task_pairs * n_subjects_max * phases * phases * n_features * float32_bytes
    intermediate_bytes = 4 * one_5d_bytes

    return (input_bytes + output_bytes + intermediate_bytes) / (1024 ** 3)
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def get_available_gpu_memory_gb():
    """Return usable GPU memory in GB (0.0 when no GPU is available).

    Takes 70% of the currently free device memory as a safety margin and
    never reports less than 0.5 GB so minimal chunking can still proceed.
    Falls back to 5.0 GB (typical RTX headroom) if the device query fails.
    """
    if not GPU_AVAILABLE:
        return 0.0

    try:
        # cp.cuda.Device().mem_info is (free_bytes, total_bytes).
        free_bytes, _total_bytes = cp.cuda.Device().mem_info

        # Use 70% of free memory as safety margin.
        available_gb = (free_bytes * 0.7) / (1024 ** 3)

        return max(0.5, available_gb)  # ensure at least 0.5GB for minimal chunking
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed. Fallback: assume 5GB available for RTX series.
        return 5.0
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def calculate_optimal_chunk_size(total_pairs, n_subjects_max, n_features, target_memory_gb=None):
    """Calculate optimal chunk size based on available GPU memory.

    Binary-searches for the largest number of task pairs whose estimated
    mega-batch footprint (per ``estimate_mega_batch_memory``) still fits
    within ``target_memory_gb``. Returns 1 when no GPU is available.
    """
    if not GPU_AVAILABLE:
        return 1

    if target_memory_gb is None:
        target_memory_gb = get_available_gpu_memory_gb()

    # Binary search over chunk sizes [1, total_pairs].
    low, high, best = 1, total_pairs, 1
    while low <= high:
        candidate = (low + high) // 2
        if estimate_mega_batch_memory(candidate, n_subjects_max, n_features) <= target_memory_gb:
            best = candidate          # fits — try larger
            low = candidate + 1
        else:
            high = candidate - 1      # too big — try smaller

    # Ensure at least 1 task pair per chunk.
    return max(1, best)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def get_available_ram_gb():
    """Get available system RAM in GB.

    Uses ``psutil`` when present; otherwise assumes a conservative 16 GB.
    """
    try:
        import psutil
    except ImportError:
        # Fallback: assume 16GB available (conservative)
        return 16.0
    return psutil.virtual_memory().available / (1024 ** 3)
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def calculate_ram_max_chunk_size(n_subjects_max, n_features, available_ram_gb):
    """Calculate maximum chunk size based on available RAM for numpy arrays.

    A chunk holds four float32 arrays (means1, vars1, means2, vars2), each of
    shape (chunk, n_subjects_max, 150, n_features). Only 70% of the reported
    RAM is treated as usable, and at least one pair per chunk is guaranteed.
    """
    # Bytes needed per task pair: 4 arrays x 4 bytes per float32 element.
    per_pair_bytes = 4 * n_subjects_max * 150 * n_features * 4

    # Keep a 30% head-room on the available RAM.
    usable_bytes = available_ram_gb * 0.7 * (1024 ** 3)

    return max(1, int(usable_bytes / per_pair_bytes))
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def calculate_optimal_chunk_size_dual_constraint(total_pairs, n_subjects_max, n_features):
    """
    Calculate optimal chunk size considering BOTH GPU memory and system RAM constraints.

    Prevents out-of-memory errors by respecting both:
    1. GPU memory limits (for CuPy processing)
    2. System RAM limits (for numpy array allocation)

    CRITICAL: for very large feature counts (>100) the 5-D broadcasting becomes
    prohibitively expensive, so much more conservative estimates are used.

    Returns the minimum chunk size that satisfies both constraints (>= 1).
    """
    if not GPU_AVAILABLE:
        return 1

    gpu_memory_gb = get_available_gpu_memory_gb()
    ram_memory_gb = get_available_ram_gb()

    if n_features > 100:
        print(f"⚠️ Large feature count ({n_features}) detected - using conservative chunking")
        # Broadcast memory grows super-linearly with features: apply an
        # exponential penalty to both budgets (RAM is hit less hard).
        feature_penalty = (n_features / 100) ** 1.5
        gpu_budget = gpu_memory_gb / feature_penalty
        ram_budget = ram_memory_gb / (feature_penalty * 0.5)
    else:
        # Reasonable feature counts: use the raw budgets.
        gpu_budget = gpu_memory_gb
        ram_budget = ram_memory_gb

    gpu_max_chunk = calculate_optimal_chunk_size(total_pairs, n_subjects_max, n_features, gpu_budget)
    ram_max_chunk = calculate_ram_max_chunk_size(n_subjects_max, n_features, ram_budget)

    # The binding constraint wins; never exceed the total pair count.
    optimal_chunk = min(gpu_max_chunk, ram_max_chunk, total_pairs)

    print(f"🔧 Dual-constraint analysis:")
    print(f"   GPU memory: {gpu_memory_gb:.2f} GB → max {gpu_max_chunk} pairs")
    print(f"   RAM memory: {ram_memory_gb:.2f} GB → max {ram_max_chunk} pairs")
    print(f"   Using most restrictive: {optimal_chunk} pairs per chunk")

    if n_features > 100:
        # Hard cap for very large feature counts (rough heuristic).
        max_safe_chunk = max(1, int(50000 / n_features))
        optimal_chunk = min(optimal_chunk, max_safe_chunk)
        if optimal_chunk == max_safe_chunk:
            print(f"   🔒 Capped at {optimal_chunk} pairs due to large feature count")

    return max(1, optimal_chunk)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def compute_overlap_batch_gpu_mega_chunked(all_means1_batch, all_vars1_batch, all_means2_batch, all_vars2_batch,
                                           valid_mask, tol=1e-12, biomechanical_filter=False, progress_callback=None):
    """
    Chunked mega-batch GPU computation: Process task pairs in optimal chunks.

    Automatically determines chunk size based on available GPU memory and processes
    task pairs in chunks while maintaining all subjects per chunk for maximum efficiency.

    Parameters:
        all_means1_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
        all_vars1_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
        all_means2_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
        all_vars2_batch: np.ndarray shape (n_task_pairs, n_subjects_max, 150, n_features)
        valid_mask: np.ndarray shape (n_task_pairs, n_subjects_max) - bool mask for valid subjects
        tol: float, tolerance for variance validity
        biomechanical_filter: bool, apply biomechanical filtering
        progress_callback: callable, progress reporting function (called with 0..1)

    Returns:
        np.ndarray shape (n_task_pairs, n_subjects_max, 150, 150) - overlap values

    Raises:
        RuntimeError: when CuPy/CUDA is not available.
    """
    if not GPU_AVAILABLE:
        raise RuntimeError("CuPy not available for chunked mega-batch GPU computation")

    import time  # hoisted out of the chunk loop (was re-imported per chunk)

    n_task_pairs, n_subjects_max, n_phases, n_features = all_means1_batch.shape

    # Calculate optimal chunk size using dual constraints (GPU + RAM).
    chunk_size = calculate_optimal_chunk_size_dual_constraint(n_task_pairs, n_subjects_max, n_features)

    print(f"🔧 Chunking Strategy:")
    print(f"   Total task pairs: {n_task_pairs:,}")
    print(f"   Optimal chunk size: {chunk_size:,} task pairs")
    print(f"   Number of chunks: {(n_task_pairs + chunk_size - 1) // chunk_size}")

    # Try a single batch first, but catch out-of-memory errors.
    if chunk_size >= n_task_pairs:
        print("🚀 Attempting single mega-batch processing...")
        try:
            return compute_overlap_batch_gpu_mega(
                all_means1_batch, all_vars1_batch, all_means2_batch, all_vars2_batch,
                valid_mask, tol, biomechanical_filter
            )
        except Exception as e:
            if "OutOfMemoryError" in str(type(e)) or "out of memory" in str(e).lower():
                print(f"⚠️ Single batch failed with memory error, forcing chunking...")
                # BUG FIX: the original referenced an undefined name
                # `available_memory` here (NameError on this recovery path).
                # Query the GPU budget explicitly and use max 3GB or 30% of it.
                conservative_memory = min(get_available_gpu_memory_gb() * 0.3, 3.0)
                chunk_size = calculate_optimal_chunk_size(n_task_pairs, n_subjects_max, n_features, conservative_memory)
                chunk_size = max(1, chunk_size // 2)  # further reduce chunk size
                print(f"🔧 Fallback chunk size: {chunk_size} pairs (conservative estimate)")
            else:
                raise  # preserve the original traceback (was `raise e`)

    # Process in chunks.
    print(f"🔄 Processing {n_task_pairs:,} task pairs in chunks of {chunk_size:,}...")
    results = []
    total_chunks = (n_task_pairs + chunk_size - 1) // chunk_size

    for chunk_start in range(0, n_task_pairs, chunk_size):
        chunk_end = min(chunk_start + chunk_size, n_task_pairs)
        chunk_num = len(results) + 1

        print(f"🚀 Processing chunk {chunk_num}/{total_chunks} (task pairs {chunk_start}:{chunk_end})...")

        # Extract chunk data (views, no copies).
        chunk_means1 = all_means1_batch[chunk_start:chunk_end]
        chunk_vars1 = all_vars1_batch[chunk_start:chunk_end]
        chunk_means2 = all_means2_batch[chunk_start:chunk_end]
        chunk_vars2 = all_vars2_batch[chunk_start:chunk_end]
        chunk_valid = valid_mask[chunk_start:chunk_end]

        # Process chunk; fall back to progressive reduction on GPU OOM.
        start_time = time.time()
        try:
            chunk_result = compute_overlap_batch_gpu_mega(
                chunk_means1, chunk_vars1, chunk_means2, chunk_vars2,
                chunk_valid, tol, biomechanical_filter
            )
        except Exception as e:
            if "OutOfMemoryError" in str(type(e)) or "out of memory" in str(e).lower():
                print(f"   ⚠️ Chunk {chunk_num} still too large, attempting progressive reduction...")
                chunk_result = _process_chunk_with_progressive_reduction(
                    chunk_means1, chunk_vars1, chunk_means2, chunk_vars2,
                    chunk_valid, tol, biomechanical_filter, chunk_num
                )
            else:
                raise
        chunk_time = time.time() - start_time

        results.append(chunk_result)

        # Progress reporting (reserve the final 10% for aggregation).
        if progress_callback:
            progress_callback(chunk_end / n_task_pairs * 0.9)

        # Performance metrics: one computation per valid subject x 150x150 phase pair.
        valid_computations = np.sum(chunk_valid) * 150 * 150
        throughput = valid_computations / chunk_time if chunk_time > 0 else 0
        print(f"   ✅ Chunk {chunk_num} complete: {chunk_time:.2f}s, {throughput:,.0f} computations/sec")

        # Release GPU pool memory between chunks.
        if GPU_AVAILABLE:
            cp.get_default_memory_pool().free_all_blocks()

    print("🔧 Combining chunk results...")
    final_result = np.concatenate(results, axis=0)

    if progress_callback:
        progress_callback(1.0)

    print(f"✅ Chunked mega-batch processing complete!")
    print(f"📊 Final result shape: {final_result.shape}")

    return final_result
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
def _process_chunk_with_progressive_reduction(chunk_means1, chunk_vars1, chunk_means2, chunk_vars2,
                                              chunk_valid, tol, biomechanical_filter, chunk_num):
    """
    Process a chunk with progressive size reduction if out-of-memory errors occur.

    Tries progressively smaller sub-chunks (50%, 25%, 12.5%, 6.25%) until the
    GPU succeeds. If even the smallest sub-chunk fails, falls back to CPU
    (Numba) processing per task pair, and finally to zero-filled results.

    Returns:
        np.ndarray shape (chunk_size, n_subjects_max, 150, 150)
    """
    chunk_size = chunk_means1.shape[0]

    # Try progressively smaller sub-chunks: 50%, 25%, 12.5%, down to 1/16th.
    for factor in (0.5, 0.25, 0.125, 0.0625):
        sub_chunk_size = max(1, int(chunk_size * factor))
        print(f"   🔄 Trying sub-chunk size: {sub_chunk_size} pairs ({factor*100:.1f}% of original)")

        try:
            sub_results = []
            for start_idx in range(0, chunk_size, sub_chunk_size):
                end_idx = min(start_idx + sub_chunk_size, chunk_size)

                sub_results.append(compute_overlap_batch_gpu_mega(
                    chunk_means1[start_idx:end_idx],
                    chunk_vars1[start_idx:end_idx],
                    chunk_means2[start_idx:end_idx],
                    chunk_vars2[start_idx:end_idx],
                    chunk_valid[start_idx:end_idx],
                    tol, biomechanical_filter
                ))

                # Clear GPU pool memory between sub-chunks.
                if GPU_AVAILABLE:
                    cp.get_default_memory_pool().free_all_blocks()

            final_result = np.concatenate(sub_results, axis=0)
            print(f"   ✅ Progressive reduction successful with {sub_chunk_size}-pair sub-chunks")
            return final_result

        except Exception as e:
            if "OutOfMemoryError" in str(type(e)) or "out of memory" in str(e).lower():
                print(f"   ❌ Sub-chunk size {sub_chunk_size} still too large")
                continue
            raise  # preserve the original traceback (was `raise e`)

    # Processing one pair at a time on GPU is slower than CPU, so fall back.
    print(f"   ❌ All reduction attempts failed - chunk too large for GPU mega-batch")
    print(f"   💡 Recommendation: Use smaller time windows or switch to sequential processing")
    print(f"   🔄 Falling back to CPU-based processing for this chunk...")

    # Narrowed try scope: only the import can raise ImportError here
    # (original wrapped the whole CPU loop in the except ImportError).
    try:
        from .numba_overlap import compute_overlap_batch_numba_ultra_fast
    except ImportError:
        print(f"   ❌ CPU fallback not available - creating zero results")
        # Last resort: return zeros of the expected shape.
        return np.zeros((chunk_size, chunk_valid.shape[1], 150, 150), dtype=np.float32)

    # Process on CPU using Numba, one task pair at a time.
    cpu_results = []
    for i in range(chunk_size):
        valid_indices = np.where(chunk_valid[i])[0]
        pair_result = np.zeros((1, chunk_valid.shape[1], 150, 150), dtype=np.float32)
        if len(valid_indices) > 0:
            # Only valid subjects are computed; the rest stay zero.
            pair_result[0, valid_indices] = compute_overlap_batch_numba_ultra_fast(
                chunk_means1[i][valid_indices], chunk_vars1[i][valid_indices],
                chunk_means2[i][valid_indices], chunk_vars2[i][valid_indices]
            )
        cpu_results.append(pair_result)

    final_result = np.concatenate(cpu_results, axis=0)
    print(f"   ✅ CPU fallback processing completed")
    return final_result
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
if __name__ == "__main__":
    # Smoke test / benchmark entry point: requires CuPy + CUDA.
    print("🧪 Testing GPU overlap calculation...")

    if not GPU_AVAILABLE:
        print("❌ GPU testing requires CuPy and CUDA")
    else:
        benchmark_gpu_vs_cpu()

        print("\n🚀 Testing mega-batch functionality...")

        # Synthetic statistics for a handful of task pairs and subjects.
        n_task_pairs = 5
        n_subjects_max = 3
        n_features = 4

        stat_shape = (n_task_pairs, n_subjects_max, 150, n_features)
        all_means1 = np.random.randn(*stat_shape).astype(np.float32)
        all_vars1 = np.abs(np.random.randn(*stat_shape)).astype(np.float32) + 0.1
        all_means2 = np.random.randn(*stat_shape).astype(np.float32)
        all_vars2 = np.abs(np.random.randn(*stat_shape)).astype(np.float32) + 0.1
        valid_mask = np.ones((n_task_pairs, n_subjects_max), dtype=bool)

        import time
        start = time.time()
        result = compute_overlap_batch_gpu_mega(all_means1, all_vars1, all_means2, all_vars2, valid_mask)
        end = time.time()

        print(f"✅ Mega-batch result shape: {result.shape}")
        print(f"⏱️ Mega-batch time: {end - start:.4f}s")
        print(f"📊 Throughput: {n_task_pairs * n_subjects_max * 150 * 150 / (end - start):,.0f} computations/sec")
|
multivariate_gaussian_overlap.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
numba_overlap.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Ultra-fast Numba JIT-compiled implementation of multivariate Gaussian overlap calculation.
|
| 4 |
+
This eliminates all Python overhead and runs at near-C speed.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
try:
|
| 9 |
+
import numba
|
| 10 |
+
from numba import jit, prange
|
| 11 |
+
NUMBA_AVAILABLE = True
|
| 12 |
+
except ImportError:
|
| 13 |
+
NUMBA_AVAILABLE = False
|
| 14 |
+
print("Warning: Numba not installed. Install with: pip install numba")
|
| 15 |
+
|
| 16 |
+
if NUMBA_AVAILABLE:
|
| 17 |
+
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
|
| 18 |
+
def compute_overlap_batch_numba(means1_batch, vars1_batch, means2_batch, vars2_batch, tol=1e-12):
|
| 19 |
+
"""
|
| 20 |
+
ULTIMATE PERFORMANCE: "It's just differences, divisions, and exponentials!"
|
| 21 |
+
|
| 22 |
+
Eliminates ALL overhead and just does the core mathematical operations:
|
| 23 |
+
overlap = exp(-0.5 * sum((μ1 - μ2)² / (σ1² + σ2²)))
|
| 24 |
+
|
| 25 |
+
Parameters:
|
| 26 |
+
means1_batch: (n_subjects, 150, n_features) array of means for task 1
|
| 27 |
+
vars1_batch: (n_subjects, 150, n_features) array of variances for task 1
|
| 28 |
+
means2_batch: (n_subjects, 150, n_features) array of means for task 2
|
| 29 |
+
vars2_batch: (n_subjects, 150, n_features) array of variances for task 2
|
| 30 |
+
tol: Tolerance for variance validity
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
overlap_batch: (n_subjects, 150, 150) array of overlap values
|
| 34 |
+
"""
|
| 35 |
+
n_subjects, n_phases, n_features = means1_batch.shape
|
| 36 |
+
|
| 37 |
+
# Pre-allocate output
|
| 38 |
+
overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64)
|
| 39 |
+
|
| 40 |
+
# CRITICAL OPTIMIZATION: Pre-compute ALL validation outside the main loops
|
| 41 |
+
# This eliminates millions of redundant NaN checks
|
| 42 |
+
valid_phases1 = np.zeros((n_subjects, 150), dtype=numba.boolean)
|
| 43 |
+
valid_phases2 = np.zeros((n_subjects, 150), dtype=numba.boolean)
|
| 44 |
+
|
| 45 |
+
# Pre-compute phase validity for all subjects at once
|
| 46 |
+
for s in prange(n_subjects):
|
| 47 |
+
for i in range(150):
|
| 48 |
+
# Check phase validity once per phase
|
| 49 |
+
valid1 = True
|
| 50 |
+
valid2 = True
|
| 51 |
+
|
| 52 |
+
for f in range(n_features):
|
| 53 |
+
if np.isnan(means1_batch[s, i, f]) or np.isnan(vars1_batch[s, i, f]):
|
| 54 |
+
valid1 = False
|
| 55 |
+
if np.isnan(means2_batch[s, i, f]) or np.isnan(vars2_batch[s, i, f]):
|
| 56 |
+
valid2 = False
|
| 57 |
+
|
| 58 |
+
valid_phases1[s, i] = valid1
|
| 59 |
+
valid_phases2[s, i] = valid2
|
| 60 |
+
|
| 61 |
+
# MAIN COMPUTATION: Process only valid phase pairs
|
| 62 |
+
for s in prange(n_subjects):
|
| 63 |
+
for i in range(150):
|
| 64 |
+
if not valid_phases1[s, i]:
|
| 65 |
+
continue
|
| 66 |
+
|
| 67 |
+
# Extract data for phase i once (avoid repeated indexing)
|
| 68 |
+
means1_i = means1_batch[s, i]
|
| 69 |
+
vars1_i = vars1_batch[s, i]
|
| 70 |
+
|
| 71 |
+
for j in range(150):
|
| 72 |
+
if not valid_phases2[s, j]:
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
# Extract data for phase j once
|
| 76 |
+
means2_j = means2_batch[s, j]
|
| 77 |
+
vars2_j = vars2_batch[s, j]
|
| 78 |
+
|
| 79 |
+
# VECTORIZED CORE COMPUTATION - "It's just math!"
|
| 80 |
+
# Calculate: sum((μ1 - μ2)² / (σ1² + σ2²))
|
| 81 |
+
|
| 82 |
+
# Step 1: Vector operations (no loops!)
|
| 83 |
+
diff = means1_i - means2_j # Vector subtraction
|
| 84 |
+
var_sum = vars1_i + vars2_j # Vector addition
|
| 85 |
+
|
| 86 |
+
# Step 2: Check variance validity (vectorized)
|
| 87 |
+
valid_variances = True
|
| 88 |
+
for f in range(n_features):
|
| 89 |
+
if var_sum[f] <= tol:
|
| 90 |
+
valid_variances = False
|
| 91 |
+
break
|
| 92 |
+
|
| 93 |
+
if valid_variances:
|
| 94 |
+
# Step 3: Quadratic form (vectorized)
|
| 95 |
+
quad_terms = diff * diff / var_sum # Element-wise operations
|
| 96 |
+
quad_sum = 0.0
|
| 97 |
+
for f in range(n_features): # Fast accumulation
|
| 98 |
+
quad_sum += quad_terms[f]
|
| 99 |
+
|
| 100 |
+
# Step 4: Exponential with underflow protection
|
| 101 |
+
half_quad = 0.5 * quad_sum
|
| 102 |
+
if half_quad <= 20.0:
|
| 103 |
+
overlap_batch[s, i, j] = np.exp(-half_quad)
|
| 104 |
+
|
| 105 |
+
return overlap_batch
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
|
| 109 |
+
def compute_overlap_batch_numba_ultra_fast(means1_batch, vars1_batch, means2_batch, vars2_batch):
|
| 110 |
+
"""
|
| 111 |
+
ULTRA-FAST MODE: Skip ALL validation for clean data.
|
| 112 |
+
|
| 113 |
+
This is the absolute fastest possible implementation - just pure math!
|
| 114 |
+
Use ONLY when you're certain the data has no NaN values.
|
| 115 |
+
|
| 116 |
+
Returns overlap = exp(-0.5 * sum((μ1 - μ2)² / (σ1² + σ2²)))
|
| 117 |
+
"""
|
| 118 |
+
n_subjects, n_phases, n_features = means1_batch.shape
|
| 119 |
+
overlap_batch = np.zeros((n_subjects, 150, 150), dtype=np.float64)
|
| 120 |
+
|
| 121 |
+
for s in prange(n_subjects):
|
| 122 |
+
for i in range(150):
|
| 123 |
+
means1_i = means1_batch[s, i]
|
| 124 |
+
vars1_i = vars1_batch[s, i]
|
| 125 |
+
|
| 126 |
+
for j in range(150):
|
| 127 |
+
means2_j = means2_batch[s, j]
|
| 128 |
+
vars2_j = vars2_batch[s, j]
|
| 129 |
+
|
| 130 |
+
# Pure mathematical computation - no checks, no validation
|
| 131 |
+
diff = means1_i - means2_j
|
| 132 |
+
var_sum = vars1_i + vars2_j
|
| 133 |
+
quad_terms = diff * diff / var_sum
|
| 134 |
+
|
| 135 |
+
quad_sum = 0.0
|
| 136 |
+
for f in range(n_features):
|
| 137 |
+
quad_sum += quad_terms[f]
|
| 138 |
+
|
| 139 |
+
overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)
|
| 140 |
+
|
| 141 |
+
return overlap_batch
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba_vectorized(means1_batch, vars1_batch, means2_batch, vars2_batch):
    """
    VECTORIZED MODE: Enhanced Numba with better vectorization.

    Processes entire feature rows at once to minimize inner loops and
    maximize cache efficiency. This is the enhanced version that
    "throws more in" while staying on CPU.

    Same formula as the ultra-fast kernel:
    overlap = exp(-0.5 * sum((μ1 - μ2)² / (σ1² + σ2²)))
    No NaN / zero-variance validation is performed.

    Parameters
    ----------
    means1_batch, vars1_batch : float64 arrays, shape (n_subjects, n_phases1, n_features)
    means2_batch, vars2_batch : float64 arrays, shape (n_subjects, n_phases2, n_features)

    Returns
    -------
    float64 array, shape (n_subjects, n_phases1, n_phases2).
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    # Generalized from a hard-coded 150x150 grid to the true input phase
    # counts (identical output for 150-phase data, safe for other sizes).
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)

    for s in prange(n_subjects):
        for i in range(n_phases1):
            means1_i = means1_batch[s, i]  # shape: (n_features,)
            vars1_i = vars1_batch[s, i]

            for j in range(n_phases2):
                means2_j = means2_batch[s, j]
                vars2_j = vars2_batch[s, j]

                # Vectorized elementwise operations over features
                diff = means1_i - means2_j
                var_sum = vars1_i + vars2_j
                quad_terms = diff * diff / var_sum

                # Fast scalar accumulation over features
                quad_sum = 0.0
                for f in range(n_features):
                    quad_sum += quad_terms[f]

                overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)

    return overlap_batch
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def compute_overlap_batch_numba_row_vectorized(means1_batch, vars1_batch, means2_batch, vars2_batch):
    """
    ROW-VECTORIZED MODE: Process entire rows of phase pairs at once.

    This minimizes the innermost loops by computing all j phases for each i,
    giving better cache utilization and more vectorization opportunities.
    Computes fully scalar (no temporary per-pair arrays), which often
    compiles to the tightest inner loop of the three kernels.

    No NaN / zero-variance validation is performed.

    Parameters
    ----------
    means1_batch, vars1_batch : float64 arrays, shape (n_subjects, n_phases1, n_features)
    means2_batch, vars2_batch : float64 arrays, shape (n_subjects, n_phases2, n_features)

    Returns
    -------
    float64 array, shape (n_subjects, n_phases1, n_phases2).
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    # Generalized from a hard-coded 150x150 grid to the true input sizes.
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)

    for s in prange(n_subjects):
        # Hoisted out of the i loop: these views are loop-invariant
        # (the original re-bound them on every i iteration).
        means2_all = means2_batch[s]  # all phase means (n_phases2, n_features)
        vars2_all = vars2_batch[s]    # all phase variances

        for i in range(n_phases1):
            means1_i = means1_batch[s, i]  # current phase means (n_features,)
            vars1_i = vars1_batch[s, i]    # current phase variances

            for j in range(n_phases2):
                # Fully scalar quadratic-form accumulation over features
                quad_sum = 0.0
                for f in range(n_features):
                    diff_f = means1_i[f] - means2_all[j, f]
                    var_sum_f = vars1_i[f] + vars2_all[j, f]
                    quad_sum += diff_f * diff_f / var_sum_f

                overlap_batch[s, i, j] = np.exp(-0.5 * quad_sum)

    return overlap_batch
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@jit(nopython=True, cache=True)
def apply_biomechanical_filter_numba(overlap_batch, means1_batch, vars1_batch,
                                     means2_batch, vars2_batch, tol=1e-12):
    """
    Apply biomechanical filtering in-place using Numba.

    This modifies the overlap_batch array directly for maximum efficiency.
    Only the FIRST feature column (assumed to be torque — TODO confirm with
    callers) is used to decide the filtering.

    Three-level rule per (i, j) phase pair:
      1. Both signals negligible (95% CI inside ±0.1)      -> overlap := 1
      2. One negligible, the other 'ampable' (|mean| > 0.2) -> keep overlap
      3. Otherwise: scale the dissimilarity (1 - overlap) by the maximum of
         a sign-mismatch probability and a mean-difference ramp penalty.

    Parameters
    ----------
    overlap_batch : float64 array (n_subjects, n_phases1, n_phases2), modified in place
    means1_batch, vars1_batch : (n_subjects, n_phases1, n_features)
    means2_batch, vars2_batch : (n_subjects, n_phases2, n_features)
    tol : float, floor for standard deviations to avoid division by zero

    Returns
    -------
    The same overlap_batch array (also mutated in place).
    """
    n_subjects = overlap_batch.shape[0]
    # Generalized from hard-coded 150s: phase counts come from the matrix itself.
    n_phases1 = overlap_batch.shape[1]
    n_phases2 = overlap_batch.shape[2]
    negligible_threshold = 0.1
    ampable_threshold = 0.2
    ci_factor = 1.96  # 95% confidence interval half-width multiplier

    # Polynomial approximation of the standard normal CDF
    # (Abramowitz & Stegun style; scipy.stats.norm is unavailable in nopython
    # mode). Hoisted to function scope: the original redefined this closure
    # inside the innermost (i, j) loop on every iteration.
    def norm_cdf_approx(x):
        t = 1.0 / (1.0 + 0.2316419 * np.abs(x))
        d = 0.3989423 * np.exp(-x * x / 2.0)
        prob = d * t * (0.3193815 + t * (-0.3565638 + t * (1.781478 + t * (-1.821256 + t * 1.330274))))
        if x > 0:
            return 1.0 - prob
        else:
            return prob

    for s in range(n_subjects):
        # Only process first feature (torque) for biomechanical filtering
        for i in range(n_phases1):
            mean1 = means1_batch[s, i, 0]
            var1 = vars1_batch[s, i, 0]

            if np.isnan(mean1) or np.isnan(var1):
                continue  # invalid phase: leave this row untouched

            std1 = np.sqrt(var1)
            ci_lo1 = mean1 - ci_factor * std1
            ci_hi1 = mean1 + ci_factor * std1

            negligible1 = (ci_lo1 >= -negligible_threshold) and (ci_hi1 <= negligible_threshold)
            ampable1 = np.abs(mean1) > ampable_threshold

            for j in range(n_phases2):
                mean2 = means2_batch[s, j, 0]
                var2 = vars2_batch[s, j, 0]

                if np.isnan(mean2) or np.isnan(var2):
                    continue

                std2 = np.sqrt(var2)
                ci_lo2 = mean2 - ci_factor * std2
                ci_hi2 = mean2 + ci_factor * std2

                negligible2 = (ci_lo2 >= -negligible_threshold) and (ci_hi2 <= negligible_threshold)
                ampable2 = np.abs(mean2) > ampable_threshold

                # Three-level filtering
                if negligible1 and negligible2:
                    # Both negligible - declare the phases identical
                    overlap_batch[s, i, j] = 1.0
                elif (negligible1 and ampable2) or (negligible2 and ampable1):
                    # Amplitude conflict - keep original overlap
                    pass
                else:
                    # Sign-reversal case - apply probability-based filtering
                    std1_safe = max(std1, tol)
                    std2_safe = max(std2, tol)

                    Ppos1 = norm_cdf_approx(mean1 / std1_safe)
                    Ppos2 = norm_cdf_approx(mean2 / std2_safe)

                    # Probability that the two signals have opposite signs
                    Pdiff_sign = Ppos1 * (1.0 - Ppos2) + (1.0 - Ppos1) * Ppos2

                    # Mean-difference penalty: linear ramp between the two
                    # thresholds (simplified from a sigmoid)
                    mean_diff = np.abs(mean1 - mean2)
                    s_thresh = 0.2
                    e_thresh = 0.5

                    if mean_diff <= s_thresh:
                        penalty = 0.0
                    elif mean_diff >= e_thresh:
                        penalty = 1.0
                    else:
                        penalty = (mean_diff - s_thresh) / (e_thresh - s_thresh)

                    # Scale the dissimilarity by the combined penalty
                    Pdiff = max(Pdiff_sign, penalty)
                    output_diff = 1.0 - overlap_batch[s, i, j]
                    overlap_batch[s, i, j] = 1.0 - output_diff * Pdiff

    return overlap_batch
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def compute_overlap_batch_fallback(means1_batch, vars1_batch, means2_batch, vars2_batch, tol=1e-12):
    """
    Fallback implementation when Numba is not available.

    This is a simple, clean NumPy implementation without excessive overhead.
    Unlike the ultra-fast Numba kernels, it validates the data: phases with
    any NaN in their means/variances and pairs with a near-zero variance sum
    are skipped (their overlap stays 0), and the exponent is capped to avoid
    underflow warnings.

    Parameters
    ----------
    means1_batch, vars1_batch : float64 arrays, shape (n_subjects, n_phases1, n_features)
    means2_batch, vars2_batch : float64 arrays, shape (n_subjects, n_phases2, n_features)
    tol : float
        Minimum allowed per-feature variance sum.

    Returns
    -------
    float64 array, shape (n_subjects, n_phases1, n_phases2), values in [0, 1].
    """
    n_subjects, n_phases1, n_features = means1_batch.shape
    # Generalized from a hard-coded 150x150 grid to the actual phase counts.
    n_phases2 = means2_batch.shape[1]
    overlap_batch = np.zeros((n_subjects, n_phases1, n_phases2), dtype=np.float64)

    for s in range(n_subjects):
        # Pre-compute NaN masks for this subject (one flag per phase)
        has_nan1 = np.any(np.isnan(means1_batch[s]) | np.isnan(vars1_batch[s]), axis=1)
        has_nan2 = np.any(np.isnan(means2_batch[s]) | np.isnan(vars2_batch[s]), axis=1)

        for i in range(n_phases1):
            if has_nan1[i]:
                continue

            for j in range(n_phases2):
                if has_nan2[j]:
                    continue

                # Direct computation of the quadratic form
                diff = means1_batch[s, i] - means2_batch[s, j]
                sum_var = vars1_batch[s, i] + vars2_batch[s, j]

                # Require every feature's variance sum to be safely positive
                if np.all(sum_var > tol):
                    half_quad = 0.5 * np.sum(diff**2 / sum_var)

                    # Underflow protection: exp(-20) ~ 2e-9, below that the
                    # overlap is effectively zero anyway
                    if half_quad <= 20.0:
                        overlap_batch[s, i, j] = np.exp(-half_quad)

    return overlap_batch
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# Main interface function
|
| 352 |
+
def compute_overlap_batch(means1_batch, vars1_batch, means2_batch, vars2_batch,
                          tol=1e-12, biomechanical_filter=False, ultra_fast=True,
                          vectorized_mode='auto'):
    """
    Main interface for computing batch overlap with multiple vectorization modes.

    Dispatches to a Numba JIT kernel when available (falling back to the
    validated kernel on any kernel error), otherwise uses the plain NumPy
    implementation. The result is clipped into [0, 1] before returning.

    Parameters
    ----------
    means1_batch, vars1_batch : (n_subjects, n_phases, n_features) arrays
    means2_batch, vars2_batch : (n_subjects, n_phases, n_features) arrays
    tol : float
        Numerical floor used by the validated / fallback paths.
    biomechanical_filter : bool
        If True, post-process the overlap with apply_biomechanical_filter_numba.
        NOTE(review): only honored on the Numba path; the NumPy fallback
        silently skips it — confirm whether that is intended.
    ultra_fast : bool
        NOTE(review): currently unused — kernel selection is driven entirely
        by `vectorized_mode`. Kept for backward compatibility with callers.
    vectorized_mode : str
        'auto', 'ultra_fast', 'vectorized', or 'row_vectorized'.

    Returns
    -------
    (n_subjects, n_phases, n_phases) float64 array of overlaps in [0, 1].
    """
    if NUMBA_AVAILABLE:
        # Select best vectorization strategy
        if vectorized_mode == 'auto':
            # Auto-select based on data size
            n_subjects, _, n_features = means1_batch.shape
            if n_features >= 10 or n_subjects >= 15:
                mode = 'row_vectorized'  # Best for larger feature sets
            elif n_features >= 4:
                mode = 'vectorized'      # Good for medium feature sets
            else:
                mode = 'ultra_fast'      # Simple and fast for small feature sets
        else:
            mode = vectorized_mode

        try:
            if mode == 'row_vectorized':
                overlap_batch = compute_overlap_batch_numba_row_vectorized(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
            elif mode == 'vectorized':
                overlap_batch = compute_overlap_batch_numba_vectorized(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
            else:  # ultra_fast
                overlap_batch = compute_overlap_batch_numba_ultra_fast(
                    means1_batch, vars1_batch, means2_batch, vars2_batch)
        except Exception:
            # FIX: this was a bare `except:`, which also swallowed
            # KeyboardInterrupt and SystemExit. Catch only genuine errors,
            # then fall back to the validated (NaN-safe) kernel.
            overlap_batch = compute_overlap_batch_numba(means1_batch, vars1_batch,
                                                        means2_batch, vars2_batch, tol)

        if biomechanical_filter:
            overlap_batch = apply_biomechanical_filter_numba(overlap_batch, means1_batch, vars1_batch,
                                                             means2_batch, vars2_batch, tol)
    else:
        overlap_batch = compute_overlap_batch_fallback(means1_batch, vars1_batch,
                                                       means2_batch, vars2_batch, tol)

    # Final clipping into the valid overlap range (in place, no copy)
    np.clip(overlap_batch, 0.0, 1.0, out=overlap_batch)
    return overlap_batch
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
if __name__ == "__main__":
    # Smoke test: build random Gaussian stats and time the overlap kernel
    # (first call includes JIT compilation when Numba is installed).
    print("Testing Numba overlap calculation...")

    # Synthetic inputs - variances are kept strictly positive.
    n_subjects = 10
    n_features = 20  # e.g., 10 time windows × 2 sensors

    means1 = np.random.randn(n_subjects, 150, n_features)
    vars1 = np.abs(np.random.randn(n_subjects, 150, n_features)) + 0.1
    means2 = np.random.randn(n_subjects, 150, n_features)
    vars2 = np.abs(np.random.randn(n_subjects, 150, n_features)) + 0.1

    import time

    print(f"Numba available: {NUMBA_AVAILABLE}")
    print(f"Computing overlap for {n_subjects} subjects, {n_features} features...")

    # First (possibly compiling) run
    start = time.time()
    result = compute_overlap_batch(means1, vars1, means2, vars2)
    end = time.time()

    print(f"Result shape: {result.shape}")
    print(f"Execution time: {end - start:.3f} seconds")
    print(f"Non-zero elements: {np.count_nonzero(result)}")
    print(f"Max value: {np.max(result):.4f}")
    print(f"Min value: {np.min(result):.4f}")

    if not NUMBA_AVAILABLE:
        print("\n⚠️ Numba not available, using fallback implementation.")
        print("Install Numba for 10-100x speedup: pip install numba")
    else:
        print("\n✅ Numba JIT compilation successful!")
        print("The first run compiles the function, subsequent runs will be much faster.")

        # Second run shows steady-state (already-compiled) performance
        start = time.time()
        result = compute_overlap_batch(means1, vars1, means2, vars2)
        end = time.time()
        print(f"Compiled execution time: {end - start:.3f} seconds")
|
pages/02_Tool.py
CHANGED
|
@@ -7,7 +7,7 @@ import numpy as np
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import seaborn as sns
|
| 9 |
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
| 10 |
-
from multivariate_gaussian_overlap import
|
| 11 |
from plot_similarity import plot_similarity_measure
|
| 12 |
from sensor_illustration import LegIllustration
|
| 13 |
from shared_styling import set_plot_style, apply_cream_theme_to_figure
|
|
@@ -376,19 +376,32 @@ def calculate_overlap_measures(task1, task2, sensors, abstraction_level,
|
|
| 376 |
|
| 377 |
|
| 378 |
# Get the phase windows from session state or use default
|
| 379 |
-
if
|
| 380 |
time_windows = [1]
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
# Clear the progress indicators
|
| 394 |
progress_placeholder.empty()
|
|
@@ -495,7 +508,7 @@ if st.session_state.authenticated:
|
|
| 495 |
task2_name_backend = (task2, task2_incline, task2_speed)
|
| 496 |
|
| 497 |
# This will be passed to calculate_overlap_measures, which then passes its .lower() version
|
| 498 |
-
# to
|
| 499 |
abstraction_level_backend = analysis_detail
|
| 500 |
|
| 501 |
progress_placeholder = st.empty()
|
|
@@ -682,4 +695,4 @@ if st.session_state.authenticated:
|
|
| 682 |
st.markdown("--- Marginal Distributions ---")
|
| 683 |
# Placeholder: Add logic to plot marginal distributions if required.
|
| 684 |
# This was previously outside the main conditional blocks.
|
| 685 |
-
st.write("(Marginal distribution plotting not implemented in this layout yet)")
|
|
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import seaborn as sns
|
| 9 |
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
| 10 |
+
from multivariate_gaussian_overlap import calculate_portrait
|
| 11 |
from plot_similarity import plot_similarity_measure
|
| 12 |
from sensor_illustration import LegIllustration
|
| 13 |
from shared_styling import set_plot_style, apply_cream_theme_to_figure
|
|
|
|
| 376 |
|
| 377 |
|
| 378 |
# Get the phase windows from session state or use default
|
| 379 |
+
if not time_windows:
|
| 380 |
time_windows = [1]
|
| 381 |
+
|
| 382 |
+
# Configure inputs for new calculate_portrait API
|
| 383 |
+
portrait_kwargs = {
|
| 384 |
+
'task1': task1_spec if any(task1_spec) else None,
|
| 385 |
+
'task2': task2_spec if any(task2_spec) else None,
|
| 386 |
+
'time_window': time_windows,
|
| 387 |
+
'progress_callback': lambda x: progress_bar.progress(x),
|
| 388 |
+
'method': 'auto'
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
if use_output_data:
|
| 392 |
+
portrait_kwargs.update({
|
| 393 |
+
'sensors': None,
|
| 394 |
+
'output_sensors': list(sensors),
|
| 395 |
+
'biomechanical_filter': biomechanical_filtering
|
| 396 |
+
})
|
| 397 |
+
else:
|
| 398 |
+
portrait_kwargs.update({
|
| 399 |
+
'sensors': list(sensors),
|
| 400 |
+
'output_sensors': None,
|
| 401 |
+
'biomechanical_filter': False
|
| 402 |
+
})
|
| 403 |
+
|
| 404 |
+
similarity = calculate_portrait(**portrait_kwargs)
|
| 405 |
|
| 406 |
# Clear the progress indicators
|
| 407 |
progress_placeholder.empty()
|
|
|
|
| 508 |
task2_name_backend = (task2, task2_incline, task2_speed)
|
| 509 |
|
| 510 |
# This will be passed to calculate_overlap_measures, which then passes its .lower() version
|
| 511 |
+
# to calculate_portrait.
|
| 512 |
abstraction_level_backend = analysis_detail
|
| 513 |
|
| 514 |
progress_placeholder = st.empty()
|
|
|
|
| 695 |
st.markdown("--- Marginal Distributions ---")
|
| 696 |
# Placeholder: Add logic to plot marginal distributions if required.
|
| 697 |
# This was previously outside the main conditional blocks.
|
| 698 |
+
st.write("(Marginal distribution plotting not implemented in this layout yet)")
|
plot_similarity.py
CHANGED
|
@@ -17,7 +17,6 @@ def plot_similarity_measure(measure_data: np.ndarray, ax: plt.Axes = None,
|
|
| 17 |
fontsize: int = 16,
|
| 18 |
y_label_pad: int = 20,
|
| 19 |
cbar_labels: bool = True,
|
| 20 |
-
cutoff_treshold:float=None,
|
| 21 |
high_level_plot: bool = False):
|
| 22 |
"""Plot similarity measure with consistent styling.
|
| 23 |
|
|
@@ -110,36 +109,13 @@ def plot_similarity_measure(measure_data: np.ndarray, ax: plt.Axes = None,
|
|
| 110 |
else:
|
| 111 |
cbar_obj.set_ticks([])
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
# The actual cutoff will depend on the plot type
|
| 117 |
-
if plot_type == 'input':
|
| 118 |
-
true_cutoff = cutoff_treshold
|
| 119 |
-
elif plot_type == 'output':
|
| 120 |
-
true_cutoff = 1 - cutoff_treshold
|
| 121 |
-
elif plot_type == 'output_biomechanical':
|
| 122 |
-
true_cutoff = 1 - cutoff_treshold
|
| 123 |
-
else: # conflict
|
| 124 |
-
true_cutoff = cutoff_treshold * (1 - cutoff_treshold)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
# Count percent of values above threshold
|
| 128 |
-
if high_level_plot:
|
| 129 |
-
# Subtract diagonal from thresholded values
|
| 130 |
-
mask = np.ones_like(measure_data, dtype=bool)
|
| 131 |
-
np.fill_diagonal(mask, False)
|
| 132 |
-
total = np.sum(mask)
|
| 133 |
-
above = np.sum((measure_data > true_cutoff) & mask)
|
| 134 |
-
else:
|
| 135 |
-
total = measure_data.size
|
| 136 |
-
above = np.sum(measure_data > true_cutoff)
|
| 137 |
-
percent = 100.0 * above / total if total > 0 else 0.0
|
| 138 |
# Format as e.g. "12.3% > 0.8"
|
| 139 |
-
annotation = f"{percent:.1f}%
|
| 140 |
# Place annotation to the right of the colorbar
|
| 141 |
cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
|
| 142 |
-
|
| 143 |
|
| 144 |
# Set aspect ratio to equal
|
| 145 |
ax.set_aspect('equal')
|
|
|
|
| 17 |
fontsize: int = 16,
|
| 18 |
y_label_pad: int = 20,
|
| 19 |
cbar_labels: bool = True,
|
|
|
|
| 20 |
high_level_plot: bool = False):
|
| 21 |
"""Plot similarity measure with consistent styling.
|
| 22 |
|
|
|
|
| 109 |
else:
|
| 110 |
cbar_obj.set_ticks([])
|
| 111 |
|
| 112 |
+
if plot_type != 'output':
|
| 113 |
+
percent = np.mean(measure_data) * 100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
# Format as e.g. "12.3% > 0.8"
|
| 115 |
+
annotation = r"$\tilde C_{total} = $" + f"{percent:.1f}%"
|
| 116 |
# Place annotation to the right of the colorbar
|
| 117 |
cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
|
| 118 |
+
fontsize=fontsize, rotation=90, transform=cbar_obj.ax.transAxes)
|
| 119 |
|
| 120 |
# Set aspect ratio to equal
|
| 121 |
ax.set_aspect('equal')
|
plot_styling.py
CHANGED
|
@@ -1,13 +1,59 @@
|
|
| 1 |
"""
|
| 2 |
Unified styling module for both Streamlit UI and matplotlib plots.
|
| 3 |
Contains all styling definitions to ensure consistency across the application.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import
|
|
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import seaborn as sns
|
| 9 |
import matplotlib.font_manager as fm
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# ==========================
|
| 12 |
# Shared Color Themes
|
| 13 |
# ==========================
|
|
@@ -46,6 +92,40 @@ LIGHT_COLORS = {
|
|
| 46 |
'spine_color': '#E0E0E0',
|
| 47 |
}
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# Dark theme colors - consistent across UI and plots
|
| 50 |
DARK_COLORS = {
|
| 51 |
'background': '#1E1E1E',
|
|
@@ -80,8 +160,12 @@ DARK_COLORS = {
|
|
| 80 |
'spine_color': '#505050',
|
| 81 |
}
|
| 82 |
|
|
|
|
| 83 |
def get_current_colors():
|
| 84 |
"""Get the current color scheme based on session state."""
|
|
|
|
|
|
|
|
|
|
| 85 |
if 'dark_theme' not in st.session_state:
|
| 86 |
st.session_state.dark_theme = False
|
| 87 |
|
|
@@ -89,12 +173,18 @@ def get_current_colors():
|
|
| 89 |
|
| 90 |
def toggle_theme():
|
| 91 |
"""Toggle between light and dark themes."""
|
|
|
|
|
|
|
|
|
|
| 92 |
if 'dark_theme' not in st.session_state:
|
| 93 |
st.session_state.dark_theme = False
|
| 94 |
st.session_state.dark_theme = not st.session_state.dark_theme
|
| 95 |
|
| 96 |
def add_theme_toggle():
|
| 97 |
"""Add a theme toggle button to the sidebar."""
|
|
|
|
|
|
|
|
|
|
| 98 |
with st.sidebar:
|
| 99 |
st.markdown("---")
|
| 100 |
current_theme = "🌙 Dark" if st.session_state.get('dark_theme', False) else "☀️ Light"
|
|
@@ -141,9 +231,18 @@ purple_helix = sns.cubehelix_palette(start=.2, rot=-.4, dark=0, light=0.85,
|
|
| 141 |
my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
|
| 142 |
reverse=True, as_cmap=True)
|
| 143 |
|
| 144 |
-
def get_plot_style():
|
| 145 |
-
"""Get plot style with
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
return {
|
| 149 |
'font_family': PLOT_STYLE_FONT_FAMILY,
|
|
@@ -168,9 +267,13 @@ def get_plot_style():
|
|
| 168 |
'spine_color': theme_colors['spine_color'],
|
| 169 |
}
|
| 170 |
|
| 171 |
-
def set_plot_style():
|
| 172 |
-
"""Set consistent plot styling across all figures
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
plt.rcParams['font.family'] = plot_style['font_family']
|
| 176 |
plt.rcParams['font.size'] = plot_style['font_size']
|
|
@@ -233,6 +336,10 @@ def apply_theme_to_figure(fig, ax=None):
|
|
| 233 |
|
| 234 |
return fig, ax
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
# Legacy function name for backward compatibility
|
| 237 |
def apply_cream_theme_to_figure(fig, ax=None):
|
| 238 |
"""Apply current theme to an existing figure and axes (legacy function name)"""
|
|
@@ -461,19 +568,27 @@ def get_tool_page_css():
|
|
| 461 |
|
| 462 |
def apply_base_styling():
|
| 463 |
"""Apply the base styling to the current Streamlit page."""
|
|
|
|
|
|
|
| 464 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 465 |
|
| 466 |
def apply_home_page_styling():
|
| 467 |
"""Apply styling specific to the home page."""
|
|
|
|
|
|
|
| 468 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 469 |
st.markdown(get_home_page_css(), unsafe_allow_html=True)
|
| 470 |
|
| 471 |
def apply_documentation_page_styling():
|
| 472 |
"""Apply styling specific to the documentation page."""
|
|
|
|
|
|
|
| 473 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 474 |
st.markdown(get_documentation_page_css(), unsafe_allow_html=True)
|
| 475 |
|
| 476 |
def apply_tool_page_styling():
|
| 477 |
"""Apply styling specific to the analysis tool page."""
|
|
|
|
|
|
|
| 478 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 479 |
st.markdown(get_tool_page_css(), unsafe_allow_html=True)
|
|
|
|
| 1 |
"""
|
| 2 |
Unified styling module for both Streamlit UI and matplotlib plots.
|
| 3 |
Contains all styling definitions to ensure consistency across the application.
|
| 4 |
+
|
| 5 |
+
Note: When used outside of Streamlit environment (e.g., in Jupyter notebooks),
|
| 6 |
+
you may see warnings about missing ScriptRunContext or Session state. These
|
| 7 |
+
warnings are harmless and can be safely ignored - the core plotting functions
|
| 8 |
+
(get_plot_style, set_plot_style, PLOT_COLORS) work correctly regardless.
|
| 9 |
"""
|
| 10 |
|
| 11 |
+
import warnings
|
| 12 |
+
import logging
|
| 13 |
import matplotlib.pyplot as plt
|
| 14 |
import seaborn as sns
|
| 15 |
import matplotlib.font_manager as fm
|
| 16 |
|
| 17 |
+
# Suppress Streamlit warnings when running outside streamlit environment
|
| 18 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='streamlit')
|
| 19 |
+
warnings.filterwarnings('ignore', message='.*ScriptRunContext.*')
|
| 20 |
+
warnings.filterwarnings('ignore', message='.*Session state.*')
|
| 21 |
+
warnings.filterwarnings('ignore', message='.*missing ScriptRunContext.*')
|
| 22 |
+
warnings.filterwarnings('ignore', message='.*does not function when running.*')
|
| 23 |
+
warnings.filterwarnings('ignore', module='streamlit.runtime.*')
|
| 24 |
+
warnings.filterwarnings('ignore', module='streamlit.runtime.scriptrunner_utils.*')
|
| 25 |
+
warnings.filterwarnings('ignore', module='streamlit.runtime.state.*')
|
| 26 |
+
|
| 27 |
+
# Suppress Streamlit loggers that generate warnings outside streamlit environment
|
| 28 |
+
logging.getLogger('streamlit.runtime.scriptrunner_utils.script_run_context').setLevel(logging.ERROR)
|
| 29 |
+
logging.getLogger('streamlit.runtime.state.session_state_proxy').setLevel(logging.ERROR)
|
| 30 |
+
logging.getLogger('streamlit').setLevel(logging.ERROR)
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
# Set logging level before importing to suppress initial warnings
|
| 34 |
+
for logger_name in ['streamlit', 'streamlit.runtime', 'streamlit.runtime.scriptrunner_utils',
|
| 35 |
+
'streamlit.runtime.state', 'streamlit.runtime.scriptrunner_utils.script_run_context',
|
| 36 |
+
'streamlit.runtime.state.session_state_proxy']:
|
| 37 |
+
logging.getLogger(logger_name).setLevel(logging.ERROR)
|
| 38 |
+
|
| 39 |
+
import streamlit as st
|
| 40 |
+
_STREAMLIT_AVAILABLE = True
|
| 41 |
+
except ImportError:
|
| 42 |
+
_STREAMLIT_AVAILABLE = False
|
| 43 |
+
# Create a mock streamlit module for non-streamlit environments
|
| 44 |
+
class MockStreamlit:
|
| 45 |
+
class session_state:
|
| 46 |
+
dark_theme = False
|
| 47 |
+
st = MockStreamlit()
|
| 48 |
+
|
| 49 |
+
def _suppress_streamlit_warnings(func):
|
| 50 |
+
"""Decorator to suppress streamlit warnings in functions."""
|
| 51 |
+
def wrapper(*args, **kwargs):
|
| 52 |
+
with warnings.catch_warnings():
|
| 53 |
+
warnings.simplefilter('ignore')
|
| 54 |
+
return func(*args, **kwargs)
|
| 55 |
+
return wrapper
|
| 56 |
+
|
| 57 |
# ==========================
|
| 58 |
# Shared Color Themes
|
| 59 |
# ==========================
|
|
|
|
| 92 |
'spine_color': '#E0E0E0',
|
| 93 |
}
|
| 94 |
|
| 95 |
+
# Paper theme colors - pure white backgrounds for publication
|
| 96 |
+
PAPER_COLORS = {
|
| 97 |
+
'background': '#FFFFFF',
|
| 98 |
+
'figure_background': '#FFFFFF',
|
| 99 |
+
'sidebar_bg_start': '#FAFAFA',
|
| 100 |
+
'sidebar_bg_end': '#F5F5F5',
|
| 101 |
+
'border_light': '#F0F0F0',
|
| 102 |
+
'border_medium': '#E0E0E0',
|
| 103 |
+
'text_primary': '#2C3E50',
|
| 104 |
+
'text_secondary': '#5D6D7E',
|
| 105 |
+
'text_tertiary': '#85929E',
|
| 106 |
+
'text_light': '#A6ACAF',
|
| 107 |
+
'button_bg_start': '#5D6D7E',
|
| 108 |
+
'button_bg_end': '#85929E',
|
| 109 |
+
'button_hover_start': '#85929E',
|
| 110 |
+
'button_hover_end': '#A6ACAF',
|
| 111 |
+
'alert_error_bg': '#FFE6E6',
|
| 112 |
+
'alert_error_border': '#FFAAAA',
|
| 113 |
+
'alert_error_text': '#CC0000',
|
| 114 |
+
'alert_info_bg': '#E6F3FF',
|
| 115 |
+
'alert_info_border': '#99D6FF',
|
| 116 |
+
'alert_info_text': '#0066CC',
|
| 117 |
+
'warning_bg': '#FFF8E1',
|
| 118 |
+
'warning_border': '#FF9800',
|
| 119 |
+
'success_bg': '#E8F5E8',
|
| 120 |
+
'success_border': '#4CAF50',
|
| 121 |
+
'generate_button_bg': '#228B22',
|
| 122 |
+
'generate_button_hover': '#32CD32',
|
| 123 |
+
# Plot-specific colors - pure white for papers
|
| 124 |
+
'axes_background': '#FFFFFF',
|
| 125 |
+
'grid_color': '#F0F0F0',
|
| 126 |
+
'spine_color': '#E0E0E0',
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
# Dark theme colors - consistent across UI and plots
|
| 130 |
DARK_COLORS = {
|
| 131 |
'background': '#1E1E1E',
|
|
|
|
| 160 |
'spine_color': '#505050',
|
| 161 |
}
|
| 162 |
|
| 163 |
+
@_suppress_streamlit_warnings
|
| 164 |
def get_current_colors():
|
| 165 |
"""Get the current color scheme based on session state."""
|
| 166 |
+
if not _STREAMLIT_AVAILABLE:
|
| 167 |
+
return LIGHT_COLORS
|
| 168 |
+
|
| 169 |
if 'dark_theme' not in st.session_state:
|
| 170 |
st.session_state.dark_theme = False
|
| 171 |
|
|
|
|
| 173 |
|
| 174 |
def toggle_theme():
|
| 175 |
"""Toggle between light and dark themes."""
|
| 176 |
+
if not _STREAMLIT_AVAILABLE:
|
| 177 |
+
return
|
| 178 |
+
|
| 179 |
if 'dark_theme' not in st.session_state:
|
| 180 |
st.session_state.dark_theme = False
|
| 181 |
st.session_state.dark_theme = not st.session_state.dark_theme
|
| 182 |
|
| 183 |
def add_theme_toggle():
|
| 184 |
"""Add a theme toggle button to the sidebar."""
|
| 185 |
+
if not _STREAMLIT_AVAILABLE:
|
| 186 |
+
return
|
| 187 |
+
|
| 188 |
with st.sidebar:
|
| 189 |
st.markdown("---")
|
| 190 |
current_theme = "🌙 Dark" if st.session_state.get('dark_theme', False) else "☀️ Light"
|
|
|
|
| 231 |
my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
|
| 232 |
reverse=True, as_cmap=True)
|
| 233 |
|
| 234 |
+
def get_plot_style(style='default'):
|
| 235 |
+
"""Get plot style with specified color theme.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
style: 'default' for cream theme, 'paper' for pure white backgrounds, 'dark' for dark theme
|
| 239 |
+
"""
|
| 240 |
+
if style == 'paper':
|
| 241 |
+
theme_colors = PAPER_COLORS
|
| 242 |
+
elif style == 'dark':
|
| 243 |
+
theme_colors = DARK_COLORS
|
| 244 |
+
else: # default
|
| 245 |
+
theme_colors = get_current_colors()
|
| 246 |
|
| 247 |
return {
|
| 248 |
'font_family': PLOT_STYLE_FONT_FAMILY,
|
|
|
|
| 267 |
'spine_color': theme_colors['spine_color'],
|
| 268 |
}
|
| 269 |
|
| 270 |
+
def set_plot_style(style='default'):
|
| 271 |
+
"""Set consistent plot styling across all figures.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
style: 'default' for cream theme, 'paper' for pure white backgrounds, 'dark' for dark theme
|
| 275 |
+
"""
|
| 276 |
+
plot_style = get_plot_style(style=style)
|
| 277 |
|
| 278 |
plt.rcParams['font.family'] = plot_style['font_family']
|
| 279 |
plt.rcParams['font.size'] = plot_style['font_size']
|
|
|
|
| 336 |
|
| 337 |
return fig, ax
|
| 338 |
|
| 339 |
+
def set_paper_plot_style():
|
| 340 |
+
"""Convenience function to set pure white backgrounds for paper publication."""
|
| 341 |
+
set_plot_style(style='paper')
|
| 342 |
+
|
| 343 |
# Legacy function name for backward compatibility
|
| 344 |
def apply_cream_theme_to_figure(fig, ax=None):
|
| 345 |
"""Apply current theme to an existing figure and axes (legacy function name)"""
|
|
|
|
| 568 |
|
| 569 |
def apply_base_styling():
|
| 570 |
"""Apply the base styling to the current Streamlit page."""
|
| 571 |
+
if not _STREAMLIT_AVAILABLE:
|
| 572 |
+
return
|
| 573 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 574 |
|
| 575 |
def apply_home_page_styling():
|
| 576 |
"""Apply styling specific to the home page."""
|
| 577 |
+
if not _STREAMLIT_AVAILABLE:
|
| 578 |
+
return
|
| 579 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 580 |
st.markdown(get_home_page_css(), unsafe_allow_html=True)
|
| 581 |
|
| 582 |
def apply_documentation_page_styling():
|
| 583 |
"""Apply styling specific to the documentation page."""
|
| 584 |
+
if not _STREAMLIT_AVAILABLE:
|
| 585 |
+
return
|
| 586 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 587 |
st.markdown(get_documentation_page_css(), unsafe_allow_html=True)
|
| 588 |
|
| 589 |
def apply_tool_page_styling():
|
| 590 |
"""Apply styling specific to the analysis tool page."""
|
| 591 |
+
if not _STREAMLIT_AVAILABLE:
|
| 592 |
+
return
|
| 593 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 594 |
st.markdown(get_tool_page_css(), unsafe_allow_html=True)
|
requirements.txt
CHANGED
|
@@ -3,4 +3,5 @@ numpy>=1.21.0
|
|
| 3 |
matplotlib>=3.4.0
|
| 4 |
seaborn>=0.11.0
|
| 5 |
pandas>=1.3.0
|
| 6 |
-
scipy>=1.7.0
|
|
|
|
|
|
| 3 |
matplotlib>=3.4.0
|
| 4 |
seaborn>=0.11.0
|
| 5 |
pandas>=1.3.0
|
| 6 |
+
scipy>=1.7.0
|
| 7 |
+
tqdm>=4.64.0
|
shared_styling.py
CHANGED
|
@@ -1,13 +1,59 @@
|
|
| 1 |
"""
|
| 2 |
Unified styling module for both Streamlit UI and matplotlib plots.
|
| 3 |
Contains all styling definitions to ensure consistency across the application.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import
|
|
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import seaborn as sns
|
| 9 |
import matplotlib.font_manager as fm
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# ==========================
|
| 12 |
# Shared Color Themes
|
| 13 |
# ==========================
|
|
@@ -46,6 +92,40 @@ LIGHT_COLORS = {
|
|
| 46 |
'spine_color': '#E0E0E0',
|
| 47 |
}
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# Dark theme colors - consistent across UI and plots
|
| 50 |
DARK_COLORS = {
|
| 51 |
'background': '#1E1E1E',
|
|
@@ -80,8 +160,12 @@ DARK_COLORS = {
|
|
| 80 |
'spine_color': '#505050',
|
| 81 |
}
|
| 82 |
|
|
|
|
| 83 |
def get_current_colors():
|
| 84 |
"""Get the current color scheme based on session state."""
|
|
|
|
|
|
|
|
|
|
| 85 |
if 'dark_theme' not in st.session_state:
|
| 86 |
st.session_state.dark_theme = False
|
| 87 |
|
|
@@ -89,12 +173,18 @@ def get_current_colors():
|
|
| 89 |
|
| 90 |
def toggle_theme():
|
| 91 |
"""Toggle between light and dark themes."""
|
|
|
|
|
|
|
|
|
|
| 92 |
if 'dark_theme' not in st.session_state:
|
| 93 |
st.session_state.dark_theme = False
|
| 94 |
st.session_state.dark_theme = not st.session_state.dark_theme
|
| 95 |
|
| 96 |
def add_theme_toggle():
|
| 97 |
"""Add a theme toggle button to the sidebar."""
|
|
|
|
|
|
|
|
|
|
| 98 |
with st.sidebar:
|
| 99 |
st.markdown("---")
|
| 100 |
current_theme = "🌙 Dark" if st.session_state.get('dark_theme', False) else "☀️ Light"
|
|
@@ -141,9 +231,18 @@ purple_helix = sns.cubehelix_palette(start=.2, rot=-.4, dark=0, light=0.85,
|
|
| 141 |
my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
|
| 142 |
reverse=True, as_cmap=True)
|
| 143 |
|
| 144 |
-
def get_plot_style():
|
| 145 |
-
"""Get plot style with
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
return {
|
| 149 |
'font_family': PLOT_STYLE_FONT_FAMILY,
|
|
@@ -168,9 +267,13 @@ def get_plot_style():
|
|
| 168 |
'spine_color': theme_colors['spine_color'],
|
| 169 |
}
|
| 170 |
|
| 171 |
-
def set_plot_style():
|
| 172 |
-
"""Set consistent plot styling across all figures
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
plt.rcParams['font.family'] = plot_style['font_family']
|
| 176 |
plt.rcParams['font.size'] = plot_style['font_size']
|
|
@@ -233,6 +336,10 @@ def apply_theme_to_figure(fig, ax=None):
|
|
| 233 |
|
| 234 |
return fig, ax
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
# Legacy function name for backward compatibility
|
| 237 |
def apply_cream_theme_to_figure(fig, ax=None):
|
| 238 |
"""Apply current theme to an existing figure and axes (legacy function name)"""
|
|
@@ -461,19 +568,27 @@ def get_tool_page_css():
|
|
| 461 |
|
| 462 |
def apply_base_styling():
|
| 463 |
"""Apply the base styling to the current Streamlit page."""
|
|
|
|
|
|
|
| 464 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 465 |
|
| 466 |
def apply_home_page_styling():
|
| 467 |
"""Apply styling specific to the home page."""
|
|
|
|
|
|
|
| 468 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 469 |
st.markdown(get_home_page_css(), unsafe_allow_html=True)
|
| 470 |
|
| 471 |
def apply_documentation_page_styling():
|
| 472 |
"""Apply styling specific to the documentation page."""
|
|
|
|
|
|
|
| 473 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 474 |
st.markdown(get_documentation_page_css(), unsafe_allow_html=True)
|
| 475 |
|
| 476 |
def apply_tool_page_styling():
|
| 477 |
"""Apply styling specific to the analysis tool page."""
|
|
|
|
|
|
|
| 478 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 479 |
st.markdown(get_tool_page_css(), unsafe_allow_html=True)
|
|
|
|
| 1 |
"""
|
| 2 |
Unified styling module for both Streamlit UI and matplotlib plots.
|
| 3 |
Contains all styling definitions to ensure consistency across the application.
|
| 4 |
+
|
| 5 |
+
Note: When used outside of Streamlit environment (e.g., in Jupyter notebooks),
|
| 6 |
+
you may see warnings about missing ScriptRunContext or Session state. These
|
| 7 |
+
warnings are harmless and can be safely ignored - the core plotting functions
|
| 8 |
+
(get_plot_style, set_plot_style, PLOT_COLORS) work correctly regardless.
|
| 9 |
"""
|
| 10 |
|
| 11 |
+
import warnings
|
| 12 |
+
import logging
|
| 13 |
import matplotlib.pyplot as plt
|
| 14 |
import seaborn as sns
|
| 15 |
import matplotlib.font_manager as fm
|
| 16 |
|
| 17 |
+
# Suppress Streamlit warnings when running outside streamlit environment
|
| 18 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='streamlit')
|
| 19 |
+
warnings.filterwarnings('ignore', message='.*ScriptRunContext.*')
|
| 20 |
+
warnings.filterwarnings('ignore', message='.*Session state.*')
|
| 21 |
+
warnings.filterwarnings('ignore', message='.*missing ScriptRunContext.*')
|
| 22 |
+
warnings.filterwarnings('ignore', message='.*does not function when running.*')
|
| 23 |
+
warnings.filterwarnings('ignore', module='streamlit.runtime.*')
|
| 24 |
+
warnings.filterwarnings('ignore', module='streamlit.runtime.scriptrunner_utils.*')
|
| 25 |
+
warnings.filterwarnings('ignore', module='streamlit.runtime.state.*')
|
| 26 |
+
|
| 27 |
+
# Suppress Streamlit loggers that generate warnings outside streamlit environment
|
| 28 |
+
logging.getLogger('streamlit.runtime.scriptrunner_utils.script_run_context').setLevel(logging.ERROR)
|
| 29 |
+
logging.getLogger('streamlit.runtime.state.session_state_proxy').setLevel(logging.ERROR)
|
| 30 |
+
logging.getLogger('streamlit').setLevel(logging.ERROR)
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
# Set logging level before importing to suppress initial warnings
|
| 34 |
+
for logger_name in ['streamlit', 'streamlit.runtime', 'streamlit.runtime.scriptrunner_utils',
|
| 35 |
+
'streamlit.runtime.state', 'streamlit.runtime.scriptrunner_utils.script_run_context',
|
| 36 |
+
'streamlit.runtime.state.session_state_proxy']:
|
| 37 |
+
logging.getLogger(logger_name).setLevel(logging.ERROR)
|
| 38 |
+
|
| 39 |
+
import streamlit as st
|
| 40 |
+
_STREAMLIT_AVAILABLE = True
|
| 41 |
+
except ImportError:
|
| 42 |
+
_STREAMLIT_AVAILABLE = False
|
| 43 |
+
# Create a mock streamlit module for non-streamlit environments
|
| 44 |
+
class MockStreamlit:
|
| 45 |
+
class session_state:
|
| 46 |
+
dark_theme = False
|
| 47 |
+
st = MockStreamlit()
|
| 48 |
+
|
| 49 |
+
def _suppress_streamlit_warnings(func):
|
| 50 |
+
"""Decorator to suppress streamlit warnings in functions."""
|
| 51 |
+
def wrapper(*args, **kwargs):
|
| 52 |
+
with warnings.catch_warnings():
|
| 53 |
+
warnings.simplefilter('ignore')
|
| 54 |
+
return func(*args, **kwargs)
|
| 55 |
+
return wrapper
|
| 56 |
+
|
| 57 |
# ==========================
|
| 58 |
# Shared Color Themes
|
| 59 |
# ==========================
|
|
|
|
| 92 |
'spine_color': '#E0E0E0',
|
| 93 |
}
|
| 94 |
|
| 95 |
+
# Paper theme colors - pure white backgrounds for publication
|
| 96 |
+
PAPER_COLORS = {
|
| 97 |
+
'background': '#FFFFFF',
|
| 98 |
+
'figure_background': '#FFFFFF',
|
| 99 |
+
'sidebar_bg_start': '#FAFAFA',
|
| 100 |
+
'sidebar_bg_end': '#F5F5F5',
|
| 101 |
+
'border_light': '#F0F0F0',
|
| 102 |
+
'border_medium': '#E0E0E0',
|
| 103 |
+
'text_primary': '#2C3E50',
|
| 104 |
+
'text_secondary': '#5D6D7E',
|
| 105 |
+
'text_tertiary': '#85929E',
|
| 106 |
+
'text_light': '#A6ACAF',
|
| 107 |
+
'button_bg_start': '#5D6D7E',
|
| 108 |
+
'button_bg_end': '#85929E',
|
| 109 |
+
'button_hover_start': '#85929E',
|
| 110 |
+
'button_hover_end': '#A6ACAF',
|
| 111 |
+
'alert_error_bg': '#FFE6E6',
|
| 112 |
+
'alert_error_border': '#FFAAAA',
|
| 113 |
+
'alert_error_text': '#CC0000',
|
| 114 |
+
'alert_info_bg': '#E6F3FF',
|
| 115 |
+
'alert_info_border': '#99D6FF',
|
| 116 |
+
'alert_info_text': '#0066CC',
|
| 117 |
+
'warning_bg': '#FFF8E1',
|
| 118 |
+
'warning_border': '#FF9800',
|
| 119 |
+
'success_bg': '#E8F5E8',
|
| 120 |
+
'success_border': '#4CAF50',
|
| 121 |
+
'generate_button_bg': '#228B22',
|
| 122 |
+
'generate_button_hover': '#32CD32',
|
| 123 |
+
# Plot-specific colors - pure white for papers
|
| 124 |
+
'axes_background': '#FFFFFF',
|
| 125 |
+
'grid_color': '#F0F0F0',
|
| 126 |
+
'spine_color': '#E0E0E0',
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
# Dark theme colors - consistent across UI and plots
|
| 130 |
DARK_COLORS = {
|
| 131 |
'background': '#1E1E1E',
|
|
|
|
| 160 |
'spine_color': '#505050',
|
| 161 |
}
|
| 162 |
|
| 163 |
+
@_suppress_streamlit_warnings
|
| 164 |
def get_current_colors():
|
| 165 |
"""Get the current color scheme based on session state."""
|
| 166 |
+
if not _STREAMLIT_AVAILABLE:
|
| 167 |
+
return LIGHT_COLORS
|
| 168 |
+
|
| 169 |
if 'dark_theme' not in st.session_state:
|
| 170 |
st.session_state.dark_theme = False
|
| 171 |
|
|
|
|
| 173 |
|
| 174 |
def toggle_theme():
|
| 175 |
"""Toggle between light and dark themes."""
|
| 176 |
+
if not _STREAMLIT_AVAILABLE:
|
| 177 |
+
return
|
| 178 |
+
|
| 179 |
if 'dark_theme' not in st.session_state:
|
| 180 |
st.session_state.dark_theme = False
|
| 181 |
st.session_state.dark_theme = not st.session_state.dark_theme
|
| 182 |
|
| 183 |
def add_theme_toggle():
|
| 184 |
"""Add a theme toggle button to the sidebar."""
|
| 185 |
+
if not _STREAMLIT_AVAILABLE:
|
| 186 |
+
return
|
| 187 |
+
|
| 188 |
with st.sidebar:
|
| 189 |
st.markdown("---")
|
| 190 |
current_theme = "🌙 Dark" if st.session_state.get('dark_theme', False) else "☀️ Light"
|
|
|
|
| 231 |
my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
|
| 232 |
reverse=True, as_cmap=True)
|
| 233 |
|
| 234 |
+
def get_plot_style(style='default'):
|
| 235 |
+
"""Get plot style with specified color theme.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
style: 'default' for cream theme, 'paper' for pure white backgrounds, 'dark' for dark theme
|
| 239 |
+
"""
|
| 240 |
+
if style == 'paper':
|
| 241 |
+
theme_colors = PAPER_COLORS
|
| 242 |
+
elif style == 'dark':
|
| 243 |
+
theme_colors = DARK_COLORS
|
| 244 |
+
else: # default
|
| 245 |
+
theme_colors = get_current_colors()
|
| 246 |
|
| 247 |
return {
|
| 248 |
'font_family': PLOT_STYLE_FONT_FAMILY,
|
|
|
|
| 267 |
'spine_color': theme_colors['spine_color'],
|
| 268 |
}
|
| 269 |
|
| 270 |
+
def set_plot_style(style='default'):
|
| 271 |
+
"""Set consistent plot styling across all figures.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
style: 'default' for cream theme, 'paper' for pure white backgrounds, 'dark' for dark theme
|
| 275 |
+
"""
|
| 276 |
+
plot_style = get_plot_style(style=style)
|
| 277 |
|
| 278 |
plt.rcParams['font.family'] = plot_style['font_family']
|
| 279 |
plt.rcParams['font.size'] = plot_style['font_size']
|
|
|
|
| 336 |
|
| 337 |
return fig, ax
|
| 338 |
|
| 339 |
+
def set_paper_plot_style():
|
| 340 |
+
"""Convenience function to set pure white backgrounds for paper publication."""
|
| 341 |
+
set_plot_style(style='paper')
|
| 342 |
+
|
| 343 |
# Legacy function name for backward compatibility
|
| 344 |
def apply_cream_theme_to_figure(fig, ax=None):
|
| 345 |
"""Apply current theme to an existing figure and axes (legacy function name)"""
|
|
|
|
| 568 |
|
| 569 |
def apply_base_styling():
|
| 570 |
"""Apply the base styling to the current Streamlit page."""
|
| 571 |
+
if not _STREAMLIT_AVAILABLE:
|
| 572 |
+
return
|
| 573 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 574 |
|
| 575 |
def apply_home_page_styling():
|
| 576 |
"""Apply styling specific to the home page."""
|
| 577 |
+
if not _STREAMLIT_AVAILABLE:
|
| 578 |
+
return
|
| 579 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 580 |
st.markdown(get_home_page_css(), unsafe_allow_html=True)
|
| 581 |
|
| 582 |
def apply_documentation_page_styling():
|
| 583 |
"""Apply styling specific to the documentation page."""
|
| 584 |
+
if not _STREAMLIT_AVAILABLE:
|
| 585 |
+
return
|
| 586 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 587 |
st.markdown(get_documentation_page_css(), unsafe_allow_html=True)
|
| 588 |
|
| 589 |
def apply_tool_page_styling():
|
| 590 |
"""Apply styling specific to the analysis tool page."""
|
| 591 |
+
if not _STREAMLIT_AVAILABLE:
|
| 592 |
+
return
|
| 593 |
st.markdown(get_base_css(), unsafe_allow_html=True)
|
| 594 |
st.markdown(get_tool_page_css(), unsafe_allow_html=True)
|
st_logs/dashboard_access.log
ADDED
|
File without changes
|