jmontp commited on
Commit
0ba5e0d
·
1 Parent(s): 844f581

First commit

Browse files
Dockerfile DELETED
@@ -1,21 +0,0 @@
1
- FROM python:3.9-slim
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- software-properties-common \
9
- git \
10
- && rm -rf /var/lib/apt/lists/*
11
-
12
- COPY requirements.txt ./
13
- COPY src/ ./src/
14
-
15
- RUN pip3 install -r requirements.txt
16
-
17
- EXPOSE 8501
18
-
19
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
-
21
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md DELETED
@@ -1,20 +0,0 @@
1
- ---
2
- title: Lower Limb Similarity Analysis
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: Streamlit template space
12
- license: mit
13
- ---
14
-
15
- # Welcome to Streamlit!
16
-
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
-
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
assets/downhill_line_under.png ADDED

Git LFS Details

  • SHA256: ed8d79850c6b283d6c60a80ae46a41385ea9183d31ee4aef1a09974e08b95cb4
  • Pointer size: 130 Bytes
  • Size of remote file: 68.3 kB
assets/level_ground_line_under.png ADDED

Git LFS Details

  • SHA256: 7589f104316a75f39a999fd44dcc8167782de9de804526bb02827bf286df4bd0
  • Pointer size: 130 Bytes
  • Size of remote file: 72 kB
assets/level_ground_no_line_under.png ADDED

Git LFS Details

  • SHA256: b3f6754be0670347aa6ba1fe1476c47066cdf064473a7d1ceff8847ee56ea2c7
  • Pointer size: 130 Bytes
  • Size of remote file: 65.3 kB
cached_data/precalculated_stats.pkl.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:030a03b477c02ead69f0b0c83bfab2133d7f8bb9e2f81ab3ae09cb21b27fbd93
3
+ size 5573257
config.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared configuration constants for the dashboard and preprocessing scripts.

# Sensor channel names: each joint exposes angle and velocity signals for the
# right ('r') and left ('l') side.
AVAILABLE_SENSORS = [
    f'{joint}_{signal}_s_{side}'
    for joint in ('hip', 'knee', 'ankle', 'foot')
    for signal in ('angle', 'vel')
    for side in ('r', 'l')
]

# Task categories selectable in the UI.
AVAILABLE_TASKS = ['decline_walking', 'level_walking', 'incline_walking',
                   'stairs', 'sit_to_stand']

# Right-leg output channels (joint torques and powers).
OUTPUTS_TOTAL = ['hip_torque_s_r', 'knee_torque_s_r', 'ankle_torque_s_r',
                 'hip_power_s_r', 'knee_power_s_r', 'ankle_power_s_r']

# Not strictly needed by preprocess_data_for_hosting.py but kept here so all
# related constants live together.
ANALYSIS_ABSTRACTION_LEVELS = ['High', 'Medium/Low']

# (task_name, incline_deg, speed_m_s) triples used for pre-calculation and
# low-level analysis. Stair and sit-to-stand tasks carry no conditions.
_CONDITION_SPEEDS = (0.8, 1.0, 1.2)
LOW_LEVEL_TASKS = (
    [('stair_descent', None, None),
     ('stair_ascent', None, None),
     ('sit_to_stand', None, None)]
    # Level walking, including the Gtech-specific 0.6 and 1.8 m/s variants.
    + [('level_walking', 0.0, speed) for speed in (0.8, 1.0, 1.2, 0.6, 1.8)]
    # Incline/decline pairs at +/-5 then +/-10 degrees across common speeds;
    # decline inclines are stored as negative values.
    + [(task, grade if task == 'incline_walking' else -grade, speed)
       for grade in (5.0, 10.0)
       for task in ('incline_walking', 'decline_walking')
       for speed in _CONDITION_SPEEDS]
)
dashboard.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import logging
from datetime import datetime
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from multivariate_gaussian_overlap import calculate_similarity_portrait_abstraction
from plot_similarity import plot_similarity_measure
from sensor_illustration import LegIllustration
import io

# Import constants from config.py
from config import AVAILABLE_SENSORS, AVAILABLE_TASKS, OUTPUTS_TOTAL, ANALYSIS_ABSTRACTION_LEVELS

# Set up logging: the access log lives under st_logs/.
if not os.path.exists('st_logs'):
    os.makedirs('st_logs')

# Create logger
logger = logging.getLogger('dashboard_logger')
logger.setLevel(logging.INFO)

# Streamlit re-executes this whole script on every user interaction, so the
# handlers must only be attached once; otherwise each rerun adds another pair
# of handlers and every log record is emitted multiple times.
if not logger.handlers:
    # Create handlers
    file_handler = logging.FileHandler('st_logs/dashboard_access.log')
    console_handler = logging.StreamHandler(sys.stdout)

    # Create formatter and add it to the handlers
    formatter = logging.Formatter('%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    # Add the handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

# Configure the page with custom CSS
st.set_page_config(page_title="Task Similarity Analysis", layout="wide")

# Custom CSS for better visual appeal
st.markdown("""
    <style>
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }
    .stButton>button {
        width: 100%;
        margin-top: 1rem;
        margin-bottom: 1rem;
    }
    .sidebar .sidebar-content {
        padding-top: 1rem;
    }
    hr {
        margin: 1rem 0;
    }
    h1 {
        padding-bottom: 1rem;
        border-bottom: 2px solid #f0f2f6;
    }
    h3 {
        margin-top: 1.5rem;
    }
    /* Add tooltip styles */
    .tooltip-container {
        position: relative;
        display: inline-block;
        width: 100%;
    }

    .tooltip-content {
        visibility: hidden;
        background-color: white;
        color: black;
        text-align: left;
        padding: 10px;
        border-radius: 5px;
        border: 1px solid #ddd;
        position: absolute;
        z-index: 1;
        top: 100%;
        left: 0;
        margin-top: 5px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        width: 200px;
    }

    .tooltip-container:hover .tooltip-content {
        visibility: visible;
    }

    .color-legend {
        display: flex;
        align-items: center;
        margin: 5px 0;
    }

    .color-box {
        width: 15px;
        height: 15px;
        margin-right: 10px;
        border-radius: 3px;
    }
    </style>
""", unsafe_allow_html=True)

# Gate the whole app behind a shared password.
# NOTE(review): the password is hard-coded in source; move it to st.secrets so
# it is not committed to the repository.
if not st.session_state.get("authenticated", False):
    password = st.text_input("Password:", type="password")
    if password == "locolab!":  # Replace with secure password (e.g. st.secrets)
        st.session_state.authenticated = True
        # Log successful login with timestamp and IP
        ip = st.get_client_ip() if hasattr(st, 'get_client_ip') else 'Unknown IP'
        logger.info(f"Successful login from {ip}")
    elif password:  # Only log if a password attempt was made
        # Log the failed attempt. The attempted password itself is deliberately
        # NOT logged: writing plaintext credentials to a log file is a
        # security leak (CWE-532) — legitimate users' typos end up on disk.
        ip = st.get_client_ip() if hasattr(st, 'get_client_ip') else 'Unknown IP'
        logger.warning(f"Failed login attempt from {ip}")
        st.stop()
    else:
        st.stop()
124
+
125
+ # Define available sensors and tasks
126
+ # MOVED TO config.py: AVAILABLE_SENSORS, AVAILABLE_TASKS, OUTPUTS_TOTAL, ANALYSIS_ABSTRACTION_LEVELS
127
+
128
+ ANALYSIS_MODES = ['Similarity Analysis'] # This one seems specific to dashboard
129
+
130
+ # Sidebar controls with better organization
131
+ with st.sidebar:
132
+ st.title("Configuration")
133
+
134
+ # Add the generate button styling and button right after the title
135
+ st.markdown("""
136
+ <style>
137
+ div[data-testid="stButton"] button {
138
+ background-color: #28a745;
139
+ color: white;
140
+ border: none;
141
+ }
142
+ div[data-testid="stButton"] button:hover {
143
+ background-color: #218838;
144
+ color: white;
145
+ border: none;
146
+ }
147
+ </style>
148
+ """, unsafe_allow_html=True)
149
+
150
+ generate_button = st.button("Generate Visualization", key="generate_viz")
151
+
152
+ st.markdown("---") # Visual separator
153
+
154
+ # Modified task selection expander with abstraction controls
155
+ with st.expander("🎯 Task Selection", expanded=True):
156
+
157
+ # --- Analysis Detail moved to top ---
158
+ analysis_detail = st.selectbox(
159
+ "Analysis Detail",
160
+ ["High Level", "Medium Level", "Low Level"],
161
+ index=0, # Default to High Level
162
+ help="High: Compare all tasks globally. Medium: Select specific tasks, compare across all their conditions. Low: Select specific tasks and conditions (incline/speed)."
163
+ )
164
+
165
+ # --- Task Selection ---
166
+ # Disable task selection if High Level is chosen
167
+ tasks_disabled = (analysis_detail == "High Level")
168
+ task1 = st.selectbox("Task 1", AVAILABLE_TASKS, index=0, disabled=tasks_disabled)
169
+ task2 = st.selectbox("Task 2", AVAILABLE_TASKS, index=1, disabled=tasks_disabled)
170
+
171
+ # Convert to old terminology for backend compatibility - THIS IS MODIFIED LATER
172
+ # abstraction_level = "High" if analysis_detail == "Summary View" else "Medium"
173
+
174
+ # Initialize parameters - These will be overwritten based on analysis_detail
175
+ task1_incline = None
176
+ task1_speed = None
177
+ task2_incline = None
178
+ task2_speed = None
179
+
180
+ # Task parameter controls - only show if Low Level (previously Detailed View)
181
+ if analysis_detail == "Low Level":
182
+ st.markdown("**Task Specific Parameters**")
183
+ col1, col2 = st.columns(2)
184
+ with col1:
185
+ # Task 1 parameters
186
+ # Default to "All" which will become None later
187
+ task1_incline_select = "All"
188
+ if task1 != 'level_walking':
189
+ task1_incline_select = st.selectbox(
190
+ f"{task1} Incline",
191
+ options=["All", 5.0, 10.0],
192
+ index=0
193
+ )
194
+
195
+ task1_speed_select = st.selectbox(
196
+ f"{task1} Speed",
197
+ options=["All", 0.8, 1.0, 1.2],
198
+ index=0
199
+ )
200
+ with col2:
201
+ # Task 2 parameters
202
+ task2_incline_select = "All"
203
+ if task2 != 'level_walking':
204
+ task2_incline_select = st.selectbox(
205
+ f"{task2} Incline",
206
+ options=["All", 5.0, 10.0],
207
+ index=0
208
+ )
209
+
210
+ task2_speed_select = st.selectbox(
211
+ f"{task2} Speed",
212
+ options=["All", 0.8, 1.0, 1.2],
213
+ index=0
214
+ )
215
+
216
+ # Convert "All" selections to None for Low Level analysis
217
+ task1_incline = task1_incline_select if task1_incline_select != "All" else None
218
+ task1_speed = task1_speed_select if task1_speed_select != "All" else None
219
+ task2_incline = task2_incline_select if task2_incline_select != "All" else None
220
+ task2_speed = task2_speed_select if task2_speed_select != "All" else None
221
+
222
+ # No specific parameters needed for High or Medium level, they remain None
223
+
224
+
225
+ st.markdown("---") # Visual separator
226
+
227
+ # Sensor selection in an expander (moved second)
228
+ with st.expander("📡 Sensor Configuration", expanded=True):
229
+ # Add sensor illustration at the top of the sensor configuration
230
+ if 'selected_sensors' not in st.session_state:
231
+ st.session_state.selected_sensors = ['hip_angle_s_r']
232
+
233
+ # Create tooltip HTML
234
+ tooltip_html = f"""
235
+ <div class="tooltip-container">
236
+ <div style="width: 100%">
237
+ {{plot_placeholder}}
238
+ </div>
239
+ <div class="tooltip-content">
240
+ <div style="font-weight: bold; margin-bottom: 8px">Color Convention:</div>
241
+ <div class="color-legend">
242
+ <div class="color-box" style="background-color: green"></div>
243
+ <div>Angle Sensors</div>
244
+ </div>
245
+ <div class="color-legend">
246
+ <div class="color-box" style="background-color: blue"></div>
247
+ <div>Velocity Sensors</div>
248
+ </div>
249
+ <div class="color-legend">
250
+ <div class="color-box" style="background-color: red"></div>
251
+ <div>Angle + Velocity</div>
252
+ </div>
253
+ <div class="color-legend">
254
+ <div class="color-box" style="background-color: orange"></div>
255
+ <div>Torque Sensors</div>
256
+ </div>
257
+ </div>
258
+ </div>
259
+ """
260
+
261
+ # Create and display sensor illustration
262
+ sensor_fig = LegIllustration().draw_illustration(
263
+ highlighted_elements=st.session_state.selected_sensors,
264
+ gait_cycle_sections=st.session_state.get('phase_windows', []) # Default if not set
265
+ )
266
+
267
+ # Convert plot to HTML
268
+ from io import BytesIO
269
+ buf = BytesIO()
270
+ sensor_fig.savefig(buf, format='png', transparent=True, bbox_inches='tight')
271
+ buf.seek(0)
272
+ import base64
273
+ plot_html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" style="width: 100%">'
274
+
275
+ # Insert plot into tooltip HTML and display
276
+ st.markdown(tooltip_html.replace('{plot_placeholder}', plot_html), unsafe_allow_html=True)
277
+
278
+ # Existing sensor selection
279
+ selected_sensors = st.multiselect(
280
+ "Select Sensors",
281
+ AVAILABLE_SENSORS,
282
+ default=['hip_angle_s_r']
283
+ )
284
+ # Update session state
285
+ st.session_state.selected_sensors = selected_sensors
286
+
287
+ # Phase window selection
288
+ st.markdown("### Phase Window")
289
+ use_specific_phases = st.toggle(
290
+ "Analyze Specific Gait Phases",
291
+ value=False,
292
+ help="Enable to select specific portions of the gait cycle for analysis",
293
+ key="phase_toggle" # Add key to track changes
294
+ )
295
+
296
+ if use_specific_phases:
297
+ st.markdown("""
298
+ Select gait cycle phases to analyze (0% = current time, -100% = past heel strike).
299
+ Multiple selections allow analysis of combined phases.
300
+ """)
301
+ # Create a list of all possible phase percentages (0 to 100%)
302
+ phase_options = [f"-{i/1.5:.1f}%" for i in range(151)]
303
+
304
+ # Create multiselect dropdown for specific phase offsets
305
+ selected_phase_strings = st.multiselect(
306
+ "Select Phase Offsets",
307
+ phase_options,
308
+ default=[phase_options[0]],
309
+ key="phase_multiselect" # Add key to track changes
310
+ )
311
+
312
+ # Convert selected percentages to phase offsets
313
+ phase_windows = [max(1, -int(float(p.strip('%')) * 1.49)) for p in selected_phase_strings]
314
+ if st.session_state.get('phase_windows') != phase_windows:
315
+ st.session_state.phase_windows = phase_windows
316
+ st.rerun() # Updated from experimental_rerun to rerun
317
+
318
+ else:
319
+ # Slider behavior - create list of all phases up to selected value
320
+ phase_window_percent = st.slider(
321
+ "Window Size (%)",
322
+ 0, 100, 1,
323
+ key="phase_slider" # Add key to track changes
324
+ )
325
+ max_window = max(1, int(phase_window_percent * 1.49))
326
+ phase_windows = list(range(1, max_window + 1))
327
+ if st.session_state.get('phase_windows') != phase_windows:
328
+ st.session_state.phase_windows = phase_windows
329
+ st.rerun() # Updated from experimental_rerun to rerun
330
+
331
+ if phase_window_percent > 0:
332
+ st.caption(f"Current window: {phase_window_percent:.1f}% ({max_window} frames)")
333
+
334
+ st.markdown("---")
335
+
336
+ # Visualization options in an expander (moved third)
337
+ with st.expander("📊 Visualization Options", expanded=True):
338
+ show_output_diff = st.checkbox("Show Output Differences", value=False)
339
+ show_conflict = st.checkbox("Show Input-Output Conflict", value=False)
340
+
341
+ # Add output selection if showing output differences or conflict
342
+ if show_output_diff or show_conflict:
343
+ selected_outputs = st.multiselect(
344
+ "Select Outputs",
345
+ OUTPUTS_TOTAL,
346
+ default=['ankle_torque_s_r'],
347
+ format_func=lambda x: x.replace('_s_r', '').replace('_', ' ').title()
348
+ )
349
+ if not selected_outputs:
350
+ st.warning("⚠️ Please select at least one output.")
351
+ st.stop()
352
+ # Update the OUTPUTS constant to use only selected outputs
353
+ OUTPUTS = selected_outputs
354
+
355
+ st.markdown("---")
356
+
357
+ # Marginal distribution options in an expander (moved last)
358
+ # Only show if neither output differences nor conflict is selected
359
+ if not (show_output_diff or show_conflict):
360
+ with st.expander("📈 Marginal Distribution", expanded=True):
361
+ show_marginals = st.checkbox("Show Marginal Distributions", value=False)
362
+ if show_marginals:
363
+ threshold = st.slider("Similarity Threshold", 0.0, 1.0, 0.5, 0.1)
364
+ else:
365
+ show_marginals = False
366
+
367
+ st.markdown("---") # Add a separator
368
+
369
+ # Main content
370
+ st.title("Task Similarity Analysis Dashboard")
371
+
372
+ if not selected_sensors:
373
+ st.warning("⚠️ Please select at least one sensor.")
374
+ st.stop()
375
+
376
def calculate_overlap_measures(task1, task2, sensors, abstraction_level,
                               task1_incline=None, task1_speed=None,
                               task2_incline=None, task2_speed=None,
                               time_windows=None,
                               use_output_data=False):
    """Compute the similarity measure between two task specifications.

    Wraps calculate_similarity_portrait_abstraction with Streamlit progress
    feedback and error reporting. `time_windows` defaults to [1] when not
    given. Returns the similarity array, or None on any failure (an error
    message is shown in the UI instead of raising).
    """
    # Streamlit placeholders for progress feedback during the computation.
    progress_slot = st.empty()
    progress_bar = progress_slot.progress(0)
    status_slot = st.empty()

    def _clear_progress():
        # Remove the progress widgets regardless of outcome.
        progress_slot.empty()
        status_slot.empty()

    try:
        # decline_walking stores inclines as negative values internally, so a
        # user-supplied (positive) incline has its sign flipped here.
        incline_a = -task1_incline if (task1 == 'decline_walking' and task1_incline is not None) else task1_incline
        incline_b = -task2_incline if (task2 == 'decline_walking' and task2_incline is not None) else task2_incline

        # Task specifications as (name, incline, speed) triples.
        spec_a = (task1, incline_a, task1_speed)
        spec_b = (task2, incline_b, task2_speed)

        # Fall back to the default single-frame window when none supplied.
        windows = [1] if time_windows is None else time_windows

        similarity = calculate_similarity_portrait_abstraction(
            sensors=list(sensors),
            time_window=windows,
            abstraction_level=abstraction_level.lower(),
            task1_name=spec_a,
            task2_name=spec_b,
            output_difference=use_output_data,
            progress_callback=progress_bar.progress
        )

        _clear_progress()

        if similarity is None:
            st.error("Error calculating similarity. Please check your task and sensor configuration.")
            return None

        return similarity

    except ValueError as e:
        _clear_progress()
        st.error(f"This is not a valid task configuration: {e}")
        return None
    except Exception as e:
        _clear_progress()
        st.error(f"Unexpected error: {e}")
        return None
440
+
441
def create_heatmap_matplotlib(data, title, plot_type='input'):
    """Render a single similarity heatmap on a fresh 6x6-inch figure.

    NOTE(review): relies on the module-level `task1`/`task2` widget values for
    the axis labels rather than taking them as parameters.
    Returns the matplotlib Figure.
    """
    figure, axis = plt.subplots(figsize=(6, 6))

    # plot_similarity_measure expects 2-D data; promote a 1-D vector to a
    # single-column matrix.
    measure = data.reshape(-1, 1) if len(data.shape) == 1 else data

    plot_similarity_measure(
        ax=axis,
        measure_data=measure,
        plot_type=plot_type,
        task_x_name=task1.replace('_', ' ').title(),
        task_y_name=task2.replace('_', ' ').title()
    )
    axis.set_title(title)
    return figure
458
+
459
def create_sensor_illustration(selected_sensors):
    """Draw the leg illustration with the given sensors highlighted.

    Gait-cycle shading is taken from st.session_state['phase_windows'] when
    set. Returns the matplotlib Figure; sizing is left to the caller's
    plotting layout.
    """
    phase_sections = st.session_state.get('phase_windows', [])
    return LegIllustration().draw_illustration(
        highlighted_elements=selected_sensors,
        gait_cycle_sections=phase_sections
    )
470
+
471
def create_downloadable_image(fig):
    """Serialize `fig` to PNG bytes for a download button, then close it.

    The caller must not use `fig` afterwards — it is closed here to free
    matplotlib's memory once the pixels are captured.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight', dpi=150)
    plt.close(fig)
    # getvalue() returns the full buffer contents regardless of position.
    return png_buffer.getvalue()
481
+
482
+ # Main section after authentication
483
+ if st.session_state.authenticated:
484
+ if generate_button:
485
+ with st.spinner("Doing some math..."):
486
+ # Determine backend parameters based on UI selections
487
+ # These variables (task1, task2, analysis_detail, task1_incline, etc.)
488
+ # are already defined from the sidebar widgets and their values are
489
+ # current when the button is pressed.
490
+
491
+ if analysis_detail == "High Level":
492
+ # For High Level, task names and conditions are None,
493
+ # task1 and task2 from UI are effectively ignored for backend call.
494
+ task1_name_backend = (None, None, None)
495
+ task2_name_backend = (None, None, None)
496
+ elif analysis_detail == "Medium Level":
497
+ # For Medium Level, specific tasks are chosen, but all their conditions.
498
+ # task1, task2 are the selected task strings from the UI.
499
+ # Incline and speed are None to signify "all conditions" for these tasks.
500
+ task1_name_backend = (task1, None, None)
501
+ task2_name_backend = (task2, None, None)
502
+ else: # "Low Level"
503
+ # For Low Level, specific tasks and their specific conditions are used.
504
+ # task1, task2 are task strings.
505
+ # task1_incline, task1_speed, etc., are specific values (or None if "All" was selected for that condition in the UI).
506
+ task1_name_backend = (task1, task1_incline, task1_speed)
507
+ task2_name_backend = (task2, task2_incline, task2_speed)
508
+
509
+ # This will be passed to calculate_overlap_measures, which then passes its .lower() version
510
+ # to calculate_similarity_portrait_abstraction.
511
+ abstraction_level_backend = analysis_detail
512
+
513
+ progress_placeholder = st.empty()
514
+ progress_placeholder.write(f"Calculating input similarity for {task1_name_backend[0] or 'all tasks'} and {task2_name_backend[0] or 'all tasks'}...")
515
+ # Calculate input similarity using determined backend parameters
516
+ input_similarity = calculate_overlap_measures(
517
+ task1=task1_name_backend[0],
518
+ task2=task2_name_backend[0],
519
+ sensors=selected_sensors,
520
+ abstraction_level=abstraction_level_backend,
521
+ task1_incline=task1_name_backend[1],
522
+ task1_speed=task1_name_backend[2],
523
+ task2_incline=task2_name_backend[1],
524
+ task2_speed=task2_name_backend[2],
525
+ time_windows=st.session_state.get('phase_windows', [])
526
+ )
527
+
528
+ output_difference_main = None
529
+ conflict_main = None
530
+ first_output_sensor = None
531
+
532
+ if input_similarity is not None:
533
+ print("Input similarity calculated, proceeding...")
534
+
535
+ # Calculate output difference and conflict for the FIRST selected output
536
+ if (show_output_diff or show_conflict):
537
+ if selected_outputs: # Check if any outputs were selected
538
+ first_output_sensor = selected_outputs[0]
539
+ progress_placeholder.empty()
540
+ progress_placeholder.write(f"Calculating output difference for {first_output_sensor}...")
541
+
542
+ # Calculate DIFFERENCE using determined backend parameters
543
+ output_difference_main = calculate_overlap_measures(
544
+ task1_name_backend[0], task2_name_backend[0],
545
+ [first_output_sensor], abstraction_level_backend, # USE CORRECT VARIABLE
546
+ task1_incline=task1_name_backend[1], task1_speed=task1_name_backend[2],
547
+ task2_incline=task2_name_backend[1], task2_speed=task2_name_backend[2],
548
+ use_output_data=True,
549
+ time_windows=[1] # Output difference usually uses time_window=1
550
+ )
551
+
552
+ if output_difference_main is not None and show_conflict:
553
+ progress_placeholder.empty()
554
+ progress_placeholder.write(f"Calculating conflict for {first_output_sensor}...")
555
+ # Conflict is input similarity * output difference
556
+ conflict_main = input_similarity * output_difference_main
557
+ else:
558
+ st.warning("⚠️ Please select at least one output sensor to show Output Differences or Conflict.")
559
+ # Ensure plots dependent on these are skipped
560
+ show_output_diff = False
561
+ show_conflict = False
562
+
563
+ # Clear progress text
564
+ progress_placeholder.empty()
565
+
566
+ # --- Start of New Plotting Layout ---
567
+ st.subheader("Similarity Analysis")
568
+
569
+ fig, axs = plt.subplots(2, 3, figsize=(18, 10), constrained_layout=False) # Adjust figsize as needed
570
+
571
+ # --- Row 0: Output Info / Header ---
572
+
573
+ # Ax[0,0]: Output Leg Illustration
574
+ if (show_output_diff or show_conflict) and first_output_sensor:
575
+ leg_illustrator_out = LegIllustration()
576
+ leg_illustrator_out.draw_illustration(ax=axs[0,0], highlighted_elements=[first_output_sensor])
577
+ axs[0,0].set_title(f"Output: {first_output_sensor.replace('_s_r','').replace('_',' ').title()}")
578
+ else:
579
+ axs[0,0].axis('off')
580
+
581
+ # Ax[0,1]: Task Info Text
582
+ axs[0,1].axis('off') # Turn off axis lines and ticks
583
+ task_info_text = f"{abstraction_level_backend} Abstraction\n" # Use abstraction_level_backend
584
+ task1_display_incline = "All" if task1_name_backend[1] is None else task1_name_backend[1]
585
+ task1_display_speed = "All" if task1_name_backend[2] is None else task1_name_backend[2]
586
+ task2_display_incline = "All" if task2_name_backend[1] is None else task2_name_backend[1]
587
+ task2_display_speed = "All" if task2_name_backend[2] is None else task2_name_backend[2]
588
+
589
+ task1_display_name = task1_name_backend[0] or "All Tasks"
590
+ task2_display_name = task2_name_backend[0] or "All Tasks"
591
+
592
+ task_info_text += f"Task 1 (Y): {task1_display_name} (Incline: {task1_display_incline}, Speed: {task1_display_speed})\n"
593
+ task_info_text += f"Task 2 (X): {task2_display_name} (Incline: {task2_display_incline}, Speed: {task2_display_speed})"
594
+ axs[0,1].text(0.5, 0.5, task_info_text, ha='center', va='center', fontsize=10, wrap=True)
595
+
596
+ # Ax[0,2]: Output Difference Heatmap
597
+ if show_output_diff and output_difference_main is not None:
598
+ plot_similarity_measure(
599
+ ax=axs[0,2],
600
+ measure_data=output_difference_main,
601
+ plot_type='output',
602
+ cbar=True,
603
+ cbar_labels=True,
604
+ fontsize=10,
605
+ task_x_name=task2_name_backend[0], # Pass task names for potential use in plot_similarity_measure
606
+ task_y_name=task1_name_backend[0]
607
+ )
608
+ axs[0,2].set_title("Output Difference")
609
+ else:
610
+ axs[0,2].axis('off')
611
+
612
+ # --- Row 1: Input Info / Conflict ---
613
+
614
+ # Ax[1,0]: Input Leg Illustration
615
+ leg_illustrator_in = LegIllustration()
616
+ leg_illustrator_in.draw_illustration(
617
+ ax=axs[1,0],
618
+ highlighted_elements=selected_sensors,
619
+ gait_cycle_sections=st.session_state.get('phase_windows', [])
620
+ )
621
+ axs[1,0].set_title(f"Input Sensors")
622
+
623
+ # Ax[1,1]: Input Similarity Heatmap
624
+ plot_similarity_measure(
625
+ ax=axs[1,1],
626
+ measure_data=input_similarity,
627
+ plot_type='input',
628
+ cbar=True,
629
+ cbar_labels=True,
630
+ fontsize=10,
631
+ task_x_name=task2_name_backend[0],
632
+ task_y_name=task1_name_backend[0]
633
+ )
634
+ axs[1,1].set_title("Input Similarity")
635
+
636
+ # Ax[1,2]: Conflict Heatmap
637
+ if show_conflict and conflict_main is not None:
638
+ plot_similarity_measure(
639
+ ax=axs[1,2],
640
+ measure_data=conflict_main,
641
+ plot_type='conflict',
642
+ cbar=True,
643
+ cbar_labels=True,
644
+ fontsize=10,
645
+ task_x_name=task2_name_backend[0],
646
+ task_y_name=task1_name_backend[0]
647
+ )
648
+ axs[1,2].set_title("Input-Output Conflict")
649
+ else:
650
+ axs[1,2].axis('off')
651
+
652
+ # Adjust layout and display
653
+ plt.tight_layout(pad=1.5)
654
+ st.pyplot(fig)
655
+
656
+ # --- End of New Plotting Layout ---
657
+
658
+ # Download button using the new main figure
659
+ try:
660
+ image_data = create_downloadable_image(fig)
661
+ st.download_button(
662
+ label="Download Plot Layout",
663
+ data=image_data,
664
+ file_name="similarity_layout.png",
665
+ mime="image/png",
666
+ )
667
+ except Exception as e:
668
+ st.error(f"Failed to create downloadable image: {e}")
669
+
670
+ # Handle case where input similarity calculation failed
671
+ else:
672
+ st.error("Error calculating input similarity. Please check your task and sensor configuration.")
673
+
674
+ # Existing logic for marginals if needed (outside the main 2x3 grid)
675
+ if show_marginals:
676
+ st.markdown("--- Marginal Distributions ---")
677
+ # Placeholder: Add logic to plot marginal distributions if required.
678
+ # This was previously outside the main conditional blocks.
679
+ st.write("(Marginal distribution plotting not implemented in this layout yet)")
data_tools.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
def generate_task_info_name(walking_type, incline, walking_speed):
    """
    Generate a standardized 'task_info' name string based on the walking type,
    incline, and walking speed.

    The returned string follows the format:
        Stair tasks:        "{walking_type}" (returned unchanged)
        Level walking:      "{walking_type}_{formatted_speed}"
        Incline/decline:    "{walking_type}_{abs_incline}_deg_{formatted_speed}_m_s"
    where:
        - walking_type is expected to be one of:
          'decline_walking', 'level_walking', 'incline_walking',
          or a stair task containing 'stair'.
        - abs_incline is the magnitude of the incline in degrees (the sign is
          implied by the walking type).
        - formatted_speed is a walking speed string, e.g., 's0x8' for 0.8 m/s,
          's1' for 1.0 m/s, 's1x2' for 1.2 m/s.

    Args:
        walking_type (str): The type of walking.
        incline (float or int or None): Ground incline in degrees.
        walking_speed (float or None): Walking speed in m/s.

    Returns:
        str or None: The generated task_info name, or None when a required
        condition value (incline/speed) is missing (None or NaN).

    Examples:
        >>> generate_task_info_name('decline_walking', -5, 0.8)
        'decline_walking_5_deg_s0x8_m_s'
        >>> generate_task_info_name('level_walking', 0, 1.0)
        'level_walking_s1'
        >>> generate_task_info_name('incline_walking', 5, 1.2)
        'incline_walking_5_deg_s1x2_m_s'
    """

    # Stair tasks carry no incline/speed conditions; the type is the full name.
    if 'stair' in walking_type:
        return walking_type

    # Validate condition values. The None test must come FIRST: np.isnan(None)
    # raises TypeError, so the previous `np.isnan(x) or x is None` ordering
    # crashed on missing values instead of returning None.
    if walking_type != 'level_walking' and (incline is None or np.isnan(incline)):
        print(f"Incline is NaN for task {walking_type} {incline} {walking_speed}")
        return None
    if walking_speed is None or np.isnan(walking_speed):
        print(f"Walking speed is NaN for task {walking_type} {incline} {walking_speed}")
        return None

    def format_walking_speed(speed):
        # Map known speeds to specific string representations
        mapping = {0.8: "s0x8", 1.0: "s1", 1.2: "s1x2"}
        # Use the mapping if available, otherwise default to a generic string
        return mapping.get(speed, f"s{speed}")

    # Format walking speed
    speed_str = format_walking_speed(walking_speed)

    # For level walking, return simpler format without incline info
    if walking_type == 'level_walking':
        return f"{walking_type}_{speed_str}"

    # Format incline magnitude: whole-number values drop the decimal part.
    # The sign is encoded by the walking type, so abs() is used.
    if incline == int(incline):
        incline_str = str(abs(int(incline)))
    else:
        incline_str = str(abs(incline))

    # Build and return the standardized task_info string for incline/decline walking
    task_info = f"{walking_type}_{incline_str}_deg_{speed_str}_m_s"
    return task_info
multivariate_gaussian_overlap.py ADDED
@@ -0,0 +1,947 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# %%
"""
This script is meant to analyze how similar sensor inputs are to each other for
different tasks and potentially different subjects.

Note: Task 1 represents the y axis in the heatmap.
      Task 2 represents the x axis in the heatmap.
"""

import gzip
import itertools
import os
import pickle
from functools import lru_cache, partial
from pathlib import Path
from typing import Callable, Optional, Union

import numpy as np
import pandas as pd
from scipy.stats import norm
from tqdm import tqdm

from config import LOW_LEVEL_TASKS
from data_tools import generate_task_info_name

# Silence debugger file-validation warnings (e.g. when run under pydevd/VS Code).
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'

# Per-sample-size 95% significance thresholds used by apply_binary_threshold().
# Left as None when the CSV is absent; thresholding then cannot be applied.
try:
    SIGNIFICANCE_THRESHOLDS = pd.read_csv('significance_thresholds.csv')
except FileNotFoundError:
    SIGNIFICANCE_THRESHOLDS = None

# Pre-calculated per-(task, incline, speed, subject) statistics produced by
# preprocess_data_for_hosting.py. Left as None when the archive is missing or
# unreadable; downstream functions check for this and bail out gracefully.
# Path assumes the script runs from task-similarity-analysis/ (or is on PYTHONPATH).
HOSTING_STATS_PATH = Path("cached_data/precalculated_stats.pkl.gz")
try:
    with gzip.open(HOSTING_STATS_PATH, "rb") as f:
        HOSTING_STATS = pickle.load(f)
    if HOSTING_STATS is None:  # A pickled None would silently disable all lookups.
        print(f"Warning: {HOSTING_STATS_PATH} loaded as None.")
except FileNotFoundError:
    print(f"ERROR: Could not load {HOSTING_STATS_PATH}. Please run preprocess_data_for_hosting.py first.")
    HOSTING_STATS = None
except Exception as e:
    print(f"Error loading {HOSTING_STATS_PATH}: {e}")
    HOSTING_STATS = None
def calculate_task_statistics(
        task_info_with_subject: tuple,
        selected_sensors: list[str],
        time_window_offsets: list[int],  # phase_windows from dashboard, e.g. [1] or [1, 2, 3]
        verbose: bool = False
) -> Optional[tuple[np.ndarray, np.ndarray, int]]:
    """Pre-compute task statistics for a task/sensor configuration from HOSTING_STATS.

    Args:
        task_info_with_subject: Tuple of (task_name, incline, speed, subject),
            used directly as a key into HOSTING_STATS.
        selected_sensors: List of sensor names to use; each must appear in the
            pre-processed feature order stored with the statistics.
        time_window_offsets: List of 1-indexed phase windows ([1] for current,
            [1, 2] for current and t-1, ...). Window w maps to roll offset w - 1.
        verbose: Whether to print diagnostic output on failure.

    Returns:
        Tuple of (final_means_array, final_covs_array, n_samples), or None when
        the statistics are unavailable or a sensor name is unknown.
        final_means_array shape: (n_phases, len(selected_sensors) * n_windows).
        final_covs_array shape: (n_phases, combined_features, combined_features),
        diagonal. n_samples is the sample count used during pre-calculation.
    """
    if HOSTING_STATS is None:
        if verbose:
            print("Error: HOSTING_STATS not loaded. Cannot calculate task statistics.")
        return None

    try:
        subject_stats = HOSTING_STATS[task_info_with_subject]
        means_all = subject_stats['means']          # (n_phases, num_all_features)
        variances_all = subject_stats['variances']  # (n_phases, num_all_features)
        n_samples = subject_stats['n_samples']
        feature_order = subject_stats['feature_order']
    except KeyError:
        if verbose:
            print(f"Warning: Statistics for {task_info_with_subject} not found in HOSTING_STATS. Skipping.")
        return None
    except TypeError:  # HOSTING_STATS is not a mapping
        if verbose:
            print(f"Warning: HOSTING_STATS is not a valid dictionary for {task_info_with_subject}. Skipping.")
        return None

    # Map the requested sensor names onto columns of the pre-calculated arrays.
    try:
        sensor_cols = [feature_order.index(s) for s in selected_sensors]
    except ValueError as e:
        if verbose:
            print(f"Warning: One or more selected_sensors not found in preprocessed feature_order for "
                  f"{task_info_with_subject}: {e}. Selected: {selected_sensors}, "
                  f"Available: {feature_order}. Skipping.")
        return None

    means_sel = means_all[:, sensor_cols]          # (n_phases, n_sensors)
    variances_sel = variances_all[:, sensor_cols]  # (n_phases, n_sensors)

    # Convert 1-indexed dashboard phase windows into 0-indexed roll offsets:
    # window 1 -> offset 0 (current), window 2 -> offset 1 (t-1), etc.
    windows = time_window_offsets if time_window_offsets else [1]
    roll_offsets = sorted({w - 1 for w in windows if w >= 1})
    if not roll_offsets:  # Defensive: dashboard should always send at least [1]
        if verbose:
            print(f"Warning: No valid roll offsets from time_window_offsets {time_window_offsets}, "
                  f"defaulting to current time (offset 0).")
        roll_offsets = [0]

    # Stack a (possibly rolled) copy of the statistics for each requested window.
    means_parts = []
    variances_parts = []
    for offset in roll_offsets:
        if offset == 0:
            means_parts.append(means_sel.copy())
            variances_parts.append(variances_sel.copy())
        else:
            # Roll along the phase axis so feature columns for t-k line up with t.
            means_parts.append(np.roll(means_sel, shift=offset, axis=0))
            variances_parts.append(np.roll(variances_sel, shift=offset, axis=0))

    final_means_array = np.concatenate(means_parts, axis=1)          # feature axis
    final_variances_array = np.concatenate(variances_parts, axis=1)  # feature axis

    # Build diagonal covariance matrices for every phase point in one vectorized
    # assignment (features are assumed independent, so off-diagonals stay 0).
    n_phases, n_features = final_means_array.shape
    final_covs_array = np.zeros((n_phases, n_features, n_features))
    diag = np.arange(n_features)
    final_covs_array[:, diag, diag] = final_variances_array

    return final_means_array, final_covs_array, n_samples
def apply_binary_threshold(portrait, n_samples_1, n_samples_2):
    """Binarize a similarity portrait in place against the 95% significance level.

    Args:
        portrait (np.ndarray): Similarity portrait; modified in place.
        n_samples_1 (int): Sample count behind the task-1 statistics.
        n_samples_2 (int): Sample count behind the task-2 statistics.

    Returns:
        np.ndarray: The same array, with values < threshold set to 0 and
        values >= threshold set to 1.

    Raises:
        RuntimeError: If the significance-threshold table was never loaded.
    """
    if SIGNIFICANCE_THRESHOLDS is None:
        # Fail loudly instead of the opaque TypeError len(None) would raise.
        raise RuntimeError(
            "significance_thresholds.csv was not loaded; cannot apply a binary threshold."
        )
    # The threshold depends on the smaller sample size of the pair, clipped to
    # the largest sample size the table provides.
    min_samples = min(n_samples_1, n_samples_2, len(SIGNIFICANCE_THRESHOLDS))
    threshold = SIGNIFICANCE_THRESHOLDS.iloc[min_samples - 1]['threshold_95']
    # Binarize in place: below threshold -> 0, at/above threshold -> 1.
    portrait[portrait < threshold] = 0
    portrait[portrait >= threshold] = 1
    return portrait
def process_pair(task_pair, task_1_stats, task_2_stats,
                 binary_threshold=False,
                 match_subjects: bool = False,
                 biomechanical_difference: bool = False):
    """Process a single task pair using pre-computed statistics (vectorized path only).

    Args:
        task_pair: Tuple of two task-config keys present in the stats dicts.
        task_1_stats / task_2_stats: {task_config: {subject: (means, covs, n_samples)}}.
        binary_threshold: If True, binarize each portrait via apply_binary_threshold.
        match_subjects: If True, compare only subjects present in both tasks
            (paired with themselves); otherwise compare every cross combination.
        biomechanical_difference: Forwarded to vectorized_overlap_mine.

    Returns:
        (portraits, subject_pairs) lists in matching order, or None when no
        subject combination produced a portrait.
    """
    task1, task2 = task_pair

    # Decide which (subject_for_task1, subject_for_task2) combinations to compare.
    if match_subjects:
        shared_subjects = set(task_1_stats[task1].keys()) & set(task_2_stats[task2].keys())
        subject_pairs = [(subj, subj) for subj in shared_subjects]
    else:
        subject_pairs = [(s1, s2)
                         for s1 in task_1_stats[task1]
                         for s2 in task_2_stats[task2]]

    # One vectorized overlap portrait per subject combination.
    portraits = []
    for subj1, subj2 in subject_pairs:
        means1, covs1, n_samples1 = task_1_stats[task1][subj1]
        means2, covs2, n_samples2 = task_2_stats[task2][subj2]

        portrait = vectorized_overlap_mine(means1, covs1, means2, covs2,
                                           biomechanical_difference=biomechanical_difference)
        if binary_threshold:
            portrait = apply_binary_threshold(portrait, n_samples1, n_samples2)
        portraits.append(portrait)

    # No valid portraits -> signal the caller to skip this pair.
    if not portraits:
        return None
    return portraits, subject_pairs
def vectorized_overlap_mine(means1, covs1, means2, covs2, tol=1e-12,
                            biomechanical_difference=False):
    """
    Vectorized version of the 'mine' method overlap measure computation.
    Assumes covariance matrices (covs1, covs2) are diagonal.

    Parameters:
        means1 (np.ndarray): Means for task 1 with shape (n1, d).
        covs1 (np.ndarray): Covariance matrices for task 1 with shape (n1, d, d). Must be diagonal.
        means2 (np.ndarray): Means for task 2 with shape (n2, d).
        covs2 (np.ndarray): Covariance matrices for task 2 with shape (n2, d, d). Must be diagonal.
        tol (float): Tolerance value to determine if a combined variance is singular.
        biomechanical_difference (bool): If True, apply output-difference filtering per
            the method: negligible–negligible (D* = 0), amplitude (D* = D), or sign
            reversal (D* = D × P_diff_sign). Only the first feature is examined, so
            this path assumes d == 1.

    Returns:
        np.ndarray: Overlap measures as an (n1, n2) array.
    """
    n1 = means1.shape[0]
    n2 = means2.shape[0]
    # Pairwise mean differences: (n1, n2, d)
    diff = means1[:, None, :] - means2[None, :, :]
    overlap = np.zeros((n1, n2))

    # --- compute raw overlap ---
    # The covariances are diagonal, so (covs1 + covs2) is diagonal too; work
    # with the variance vectors directly instead of full (d, d) matrices.
    vars1 = np.diagonal(covs1, axis1=1, axis2=2)  # (n1, d)
    vars2 = np.diagonal(covs2, axis1=1, axis2=2)  # (n2, d)

    # Pairwise sum of variances, broadcast to (n1, n2, d).
    sum_vars = vars1[:, None, :] + vars2[None, :, :]

    # A pair is usable only when every combined variance exceeds tol.
    non_singular = np.all(sum_vars > tol, axis=2)

    if np.any(non_singular):
        # Quadratic form for diagonal covariance: sum_k diff_k^2 / (var1_k + var2_k).
        quad_form_terms = diff[non_singular] ** 2 / sum_vars[non_singular]  # (N, d)
        quad_form = np.sum(quad_form_terms, axis=1)                         # (N,)

        # Underflow protection: treat exp(-x) as 0 once x exceeds max_exp.
        max_exp = 20.0
        overlap_values = np.zeros_like(quad_form)
        safe_mask = (0.5 * quad_form) <= max_exp
        overlap_values[safe_mask] = np.exp(-0.5 * quad_form[safe_mask])

        # Scatter the computed values back into the (n1, n2) output grid.
        i_idx, j_idx = np.where(non_singular)
        overlap[i_idx, j_idx] = overlap_values

    # --- output-difference filtering (vectorized) ---
    if biomechanical_difference:
        negligible_threshold = 0.1
        ampable_threshold = 0.2

        # Optimized for the scalar case (d == 1): only feature 0 is used.
        std1 = np.sqrt(covs1[:, 0, 0])
        std2 = np.sqrt(covs2[:, 0, 0])
        ci_lo1 = means1[:, 0] - 1.96 * std1
        ci_hi1 = means1[:, 0] + 1.96 * std1
        ci_lo2 = means2[:, 0] - 1.96 * std2
        ci_hi2 = means2[:, 0] + 1.96 * std2

        # P(signal > 0) per phase point, with tol guarding division by ~0 std.
        z1 = means1[:, 0] / np.maximum(std1, tol)
        z2 = means2[:, 0] / np.maximum(std2, tol)
        Ppos1 = norm.cdf(z1); Pneg1 = 1.0 - Ppos1
        Ppos2 = norm.cdf(z2); Pneg2 = 1.0 - Ppos2

        negligible1 = (ci_lo1 >= -negligible_threshold) & (ci_hi1 <= negligible_threshold)
        negligible2 = (ci_lo2 >= -negligible_threshold) & (ci_hi2 <= negligible_threshold)
        ampable1 = np.abs(means1[:, 0]) > ampable_threshold
        ampable2 = np.abs(means2[:, 0]) > ampable_threshold

        i_idx, j_idx = np.nonzero(non_singular)
        if i_idx.size > 0:
            # negligible–negligible → zero difference. Since we compute
            # similarity, set it to 1 (which makes the derived difference 0).
            m0 = negligible1[i_idx] & negligible2[j_idx]
            overlap[i_idx[m0], j_idx[m0]] = 1.0

            # amplitude conflicts (one flat, one large) → skip modification
            m1 = (negligible1[i_idx] & ampable2[j_idx]) \
                | (negligible2[j_idx] & ampable1[i_idx])

            # sign-reversal scaling for the rest: scale the *difference* by the
            # probability the two signals have opposite signs, then convert back.
            m2 = ~(m0 | m1)
            if np.any(m2):
                with np.errstate(all='ignore'):
                    Pdiff = (
                        Ppos1[i_idx[m2]] * Pneg2[j_idx[m2]] +
                        Pneg1[i_idx[m2]] * Ppos2[j_idx[m2]]
                    )
                    overlap[i_idx[m2], j_idx[m2]] = 1 - (1 - overlap[i_idx[m2], j_idx[m2]]) * Pdiff

    # clamp to [0, 1]
    np.clip(overlap, 0.0, 1.0, out=overlap)
    return overlap
def calculate_similarity_portrait_abstraction(sensors: Union[tuple, list], time_window: Union[int, list[int]] = 1,
                                              task1_name: Union[tuple, str, None] = None,
                                              task2_name: Union[tuple, str, None] = None,
                                              verbose: bool = False,
                                              n_jobs: int = -1,  # n_jobs not used currently with sequential loop
                                              return_all_portraits: bool = False,  # return_all_portraits part of results_dict, not directly handled by agg
                                              progress_callback: Optional[Callable[[float], None]] = None,
                                              binary_threshold: bool = False,
                                              output_difference: bool = False,  # Renamed from use_output_data in dashboard
                                              match_subjects: bool = False,
                                              save_results_dict: bool = False,
                                              biomechanical_difference: bool = False,
                                              abstraction_level: Optional[str] = None):  # abstraction_level passed from dashboard
    """
    Calculate input similarity between two tasks using multivariate Gaussian
    based on a given abstraction level. Uses pre-calculated HOSTING_STATS.
    Always uses the vectorized 'mine' overlap calculation method.

    High level: Is there any conflict?
    At this abstraction level, we want to understand if there is any
    conflict for a given sensor configuration. Calculate the similarity
    portrait for each task at the lowest level of abstraction and then
    aggregate the results using the agg_method function (maximum of each
    individual similarity portrait). The total similarity portrait is then
    transposed and maxed back so that it is symmetric.

    Medium/Low level: Is there any similarity between tasks with optional specificity?
    At this abstraction level, we want to understand similarity between tasks
    with optional specificity for incline and speed. For example:
    - Compare all incline 5° to stairs up: task1_name=('incline', 5, None)
    - Compare specific tasks: task1_name=('incline', 5, 0.8)
    - Compare general tasks: task1_name=('incline', None, None)

    Format: task_name = (task:str, incline:float|None, speed:float|None)
    Setting incline/speed to None means match any value for that parameter.

    At each level, the similarity portrait is calculated for each subject and
    leg separately and then aggregated using the agg_method function.

    Arguments:
    ----------
    sensors: tuple or list
        List of sensor names to use for similarity calculation
    time_window: int or list, optional
        List of 1-indexed phase windows (e.g., [1] for current, [1,2] for current and t-1).
        Passed as time_window_offsets to calculate_task_statistics.
    task1_name: tuple, optional
        For low level: tuple of (task_name:str, incline:float|None,
        speed:float|None)
        Set incline/speed to None to match any value
    task2_name: tuple, optional
        Same format as task1_name
    verbose: bool, optional
        Whether to show progress bar during calculations (default=False)
    n_jobs: int, optional
        Number of CPU cores to use. -1 means use all available cores (default=-1). Currently not used.
    return_all_portraits: bool, optional
        If True, returns the full results_dict. If False (default), returns only the aggregated portrait.
    binary_threshold: bool, optional
        If True, threshold each low-level similarity portrait based on the 95%
        confidence interval. Values below the threshold are set to 0, above to 1.
        (default=False)
    output_difference: bool, optional
        If True, return the difference between the similarity portrait and 1
        (default=False)
    match_subjects: bool, optional
        If True, match subjects between tasks (default=False)
    save_results_dict: bool, optional
        If True, save the results_dict to a file. Default False.
    abstraction_level: str, optional
        The abstraction level ("high", "medium", "low"). Used for logging/context, not for filtering within this function anymore.

    Returns:
    --------
    The aggregated (150, 150) similarity portrait, or a
    (portrait, results_dict) tuple when return_all_portraits is True.
    """

    # time_window is expected to be a list of 1-indexed phase windows from dashboard, e.g., [1], [1,2,3]
    # It will be passed as time_window_offsets to calculate_task_statistics
    # If it's an int, convert to list, though dashboard should pass list.
    if isinstance(time_window, int):
        # This interpretation might need review based on how dashboard calls it.
        # If time_window=1 means "current only", it should become [1].
        # If time_window=3 means "current, t-1, t-2", it should become [1,2,3].
        # Let's assume int means number of windows starting from current.
        time_window_offsets_list = list(range(1, time_window + 1)) if time_window > 0 else [1]
    elif isinstance(time_window, list):
        time_window_offsets_list = time_window if time_window else [1]  # Default to [1] if empty list
    else:  # Default to current time if malformed
        time_window_offsets_list = [1]

    # MOVED TO config.py: low_level_tasks
    # Discover the set of subjects present in the pre-calculated statistics.
    # HOSTING_STATS keys are assumed to be 4-tuples (task, incline, speed, subject).
    unique_subjects_from_config_or_data = []
    if HOSTING_STATS:
        # Try to infer subjects from HOSTING_STATS keys if needed, though LOW_LEVEL_TASKS and iterating subjects is primary
        all_subjects_in_stats = set()
        for k_tuple in HOSTING_STATS.keys():
            if isinstance(k_tuple, tuple) and len(k_tuple) == 4:  # (task, inc, spd, subject)
                all_subjects_in_stats.add(k_tuple[3])
        unique_subjects_from_config_or_data = sorted(list(all_subjects_in_stats))
        if not unique_subjects_from_config_or_data and verbose:
            print("Warning: Could not infer subjects from HOSTING_STATS keys.")

    # If subjects could not be inferred or HOSTING_STATS is None, this part will be problematic.
    # The original code used total_data['subject'].unique(). This is no longer available.
    # `preprocess_data_for_hosting.py` iterates all subjects with data, so the
    # HOSTING_STATS keys are expected to cover every subject of interest.

    # Pre-compute statistics for all valid tasks
    if verbose or progress_callback:
        print("Pre-computing task statistics...")
        if progress_callback:
            progress_callback(0.1)  # Show initial progress

    task_1_stats = {}
    task_2_stats = {}

    # Determine subjects to iterate over.
    # If unique_subjects_from_config_or_data is empty, this loop won't run correctly.
    subjects_to_iterate = unique_subjects_from_config_or_data
    if not subjects_to_iterate and verbose:
        print("Warning: subjects_to_iterate is empty. Statistics pre-computation might be skipped or fail.")

    # Select which LOW_LEVEL_TASKS configs match the dashboard's task filters.
    # A None field in task1_name/task2_name acts as a wildcard for that field.
    total_combinations_to_calc_stats_for = 0
    tasks_for_stats_calc_t1 = []
    tasks_for_stats_calc_t2 = []

    for task_info_cfg in LOW_LEVEL_TASKS:  # task_info_cfg is (task_name, incline, speed)
        # Filter based on task1_name and task2_name (from dashboard selection)
        task1_str, task1_incline, task1_speed = task1_name if task1_name else (None, None, None)
        task2_str, task2_incline, task2_speed = task2_name if task2_name else (None, None, None)

        task_1_match = (task1_str is None or task1_str == task_info_cfg[0]) and \
                       (task1_incline is None or task1_incline == task_info_cfg[1]) and \
                       (task1_speed is None or task1_speed == task_info_cfg[2])

        task_2_match = (task2_str is None or task2_str == task_info_cfg[0]) and \
                       (task2_incline is None or task2_incline == task_info_cfg[1]) and \
                       (task2_speed is None or task2_speed == task_info_cfg[2])

        if task_1_match:
            tasks_for_stats_calc_t1.append(task_info_cfg)
            total_combinations_to_calc_stats_for += len(subjects_to_iterate)
        if task_2_match:  # Can overlap with task_1_match
            # Avoid double counting if task_1_match also true
            if not task_1_match or task_info_cfg not in tasks_for_stats_calc_t1:
                tasks_for_stats_calc_t2.append(task_info_cfg)  # only add if not already counted for t1
                total_combinations_to_calc_stats_for += len(subjects_to_iterate)
            elif task_1_match and task_info_cfg not in tasks_for_stats_calc_t2:  # if it matched t1, ensure it's in t2 list for later use
                tasks_for_stats_calc_t2.append(task_info_cfg)

    # NOTE(review): total_combinations_to_calc_stats_for is computed above but
    # never read afterwards; progress uses expected_stat_calcs instead.
    current_stat_calcs = 0
    expected_stat_calcs = 0  # Will be determined by actual found (task_cfg, subj) pairs in HOSTING_STATS

    # Populate task_1_stats: {task_config: {subject: (means, covs, n_samples)}}
    for task_config_t1 in tasks_for_stats_calc_t1:
        task_1_stats[task_config_t1] = {}
        for subject in subjects_to_iterate:  # Iterate all potential subjects
            task_info_w_subject = (*task_config_t1, subject)
            if HOSTING_STATS and task_info_w_subject in HOSTING_STATS:  # Check if this specific combo exists
                expected_stat_calcs += 1  # Count expected calculations
                stats = calculate_task_statistics(
                    task_info_w_subject, sensors, time_window_offsets_list,  # Pass the list of offsets
                    verbose=verbose
                )
                if stats is not None:
                    task_1_stats[task_config_t1][subject] = stats
                    current_stat_calcs += 1
                    if progress_callback and expected_stat_calcs > 0: progress_callback(0.1 + (0.4 * (current_stat_calcs / max(1, expected_stat_calcs*2))))  # *2 for T1 and T2 pass

    # Populate task_2_stats (same structure as task_1_stats)
    for task_config_t2 in tasks_for_stats_calc_t2:
        task_2_stats[task_config_t2] = {}
        for subject in subjects_to_iterate:
            task_info_w_subject = (*task_config_t2, subject)
            if HOSTING_STATS and task_info_w_subject in HOSTING_STATS:
                if task_config_t2 not in tasks_for_stats_calc_t1 or subject not in task_1_stats.get(task_config_t2, {}):  # Avoid re-calculating if already done for T1
                    expected_stat_calcs += 1

                # NOTE(review): the statistics are still recomputed here even
                # when the same (task, subject) was already done for T1; only
                # the expected count is deduplicated above.
                stats = calculate_task_statistics(
                    task_info_w_subject, sensors, time_window_offsets_list,
                    verbose=verbose
                )
                if stats is not None:
                    task_2_stats[task_config_t2][subject] = stats
                    current_stat_calcs += 1  # This current_stat_calcs will go up to expected_stat_calcs*2 potentially
                    if progress_callback and expected_stat_calcs > 0: progress_callback(0.1 + (0.4 * (current_stat_calcs / max(1, expected_stat_calcs*2))))

    # Clean up empty task_configs from stats if no subjects had data
    task_1_stats = {k: v for k, v in task_1_stats.items() if v}
    task_2_stats = {k: v for k, v in task_2_stats.items() if v}

    # Generate valid task pairs from successfully populated stats
    task_pairs = list(itertools.product(task_1_stats.keys(),
                                        task_2_stats.keys()))

    if verbose or progress_callback:
        if progress_callback:
            progress_callback(0.5)  # Half-way point
        print(f"Computing similarity portraits for {len(task_pairs)} task pairs...")

    # Bind the shared arguments once; pairs are then processed sequentially below.
    process_func = partial(process_pair,
                           task_1_stats=task_1_stats,
                           task_2_stats=task_2_stats,
                           binary_threshold=binary_threshold,
                           match_subjects=match_subjects,
                           biomechanical_difference=biomechanical_difference)

    # Add progress tracking if requested
    results_dict = {}
    num_results = 0
    for i in tqdm(range(len(task_pairs)), desc="Processing task pairs",
                  total=len(task_pairs)):
        current_task_pair = task_pairs[i]
        result = process_func(current_task_pair)

        # If the result is None, then skip it
        if result is None:
            # print(f"Result is None for task pair {current_task_pair}")
            continue

        # NOTE(review): leftover debug hook — this branch is a no-op.
        if 'stair' in current_task_pair[0][0] or 'stair' in current_task_pair[1][0]:
            pass

        portraits, subject_pairs = result

        # Key the results by the (task, incline, speed) tuples of both tasks.
        low_level_task_1 = (current_task_pair[0][0],
                            current_task_pair[0][1],
                            current_task_pair[0][2])
        low_level_task_2 = (current_task_pair[1][0],
                            current_task_pair[1][1],
                            current_task_pair[1][2])
        low_level_task_total = (low_level_task_1, low_level_task_2)
        # Add the result to the list of results for this low-level task
        if low_level_task_total not in results_dict:
            results_dict[low_level_task_total] = {}

        if portraits is not None:
            for subject_pair, portrait in zip(subject_pairs, portraits):

                # Take the complement if you are doing the output difference
                if output_difference:
                    results_dict[low_level_task_total][subject_pair] = 1 - portrait
                else:
                    results_dict[low_level_task_total][subject_pair] = portrait
                num_results += 1
        # Update progress if we requested
        if progress_callback:
            progress_callback(0.5 + (0.4 * (i / len(task_pairs))))

    # Final progress update
    if progress_callback:
        progress_callback(0.9)

    # Aggregate the similarity portraits
    if num_results <= 0:
        # raise ValueError("No valid similarity portraits were generated")
        if verbose: print("Warning: No valid similarity portraits were generated. Check task/subject filters and data.")
        return (np.zeros((150, 150)), {}) if return_all_portraits else np.zeros((150, 150))  # Return empty/zero portrait

    # Remove empty results (should be handled by num_results check, but good practice)
    results_dict = {k: v for k, v in results_dict.items() if v and any(item is not None for item in v.values())}
    if not results_dict and num_results > 0:  # Contradiction, means num_results was miscounted or structure is wrong
        if verbose: print("Warning: results_dict is empty despite num_results > 0. Returning zero portrait.")
        return (np.zeros((150, 150)), {}) if return_all_portraits else np.zeros((150, 150))
    if not results_dict and num_results == 0:  # Consistent, no results
        return (np.zeros((150, 150)), {}) if return_all_portraits else np.zeros((150, 150))

    # Get the aggregation of the similarity portraits
    # (low_level_tasks_aggregation is defined elsewhere in this module.)
    total_task_similarity_portrait = low_level_tasks_aggregation(results_dict)

    if progress_callback:
        progress_callback(1.0)  # Complete

    # Optionally persist the per-(task-pair, subject-pair) portraits to disk.
    if save_results_dict:
        directory = 'saved_similarity_portraits'
        if not os.path.exists(directory):
            os.makedirs(directory)
        # Construct filename carefully. time_window_offsets_list might be long.
        # Using a string representation that's filename-safe.
        time_window_str = "-".join(map(str, time_window_offsets_list))
        file_name = f'results_dict_{sorted(sensors)}_time_window_{time_window_str}'
        if biomechanical_difference:
            file_name += f'_biomechanical_difference_{biomechanical_difference}'
        # Ensure filename is not too long and is valid
        file_name = "".join(c if c.isalnum() or c in ('_', '-') else '' for c in file_name)[:100] + ".pkl"

        try:
            with open(os.path.join(directory, file_name), 'wb') as f:
                pickle.dump(results_dict, f)
            if verbose:
                print(f"Saved results_dict to {os.path.join(directory, file_name)}")
        except Exception as e:
            if verbose:
                print(f"Error saving results_dict: {e}")

    if return_all_portraits:
        return total_task_similarity_portrait, results_dict
    else:
        return total_task_similarity_portrait
671
+
672
+
673
def low_level_tasks_aggregation(results_dict: dict[tuple, dict[tuple, np.ndarray]]
                                ) -> np.ndarray:
    """
    Aggregate per-subject similarity portraits across low-level task pairs.

    For each low-level task pair, the portraits of all subject pairs are
    averaged (ignoring NaNs). The final portrait is then the element-wise
    maximum over those per-task averages, also ignoring NaNs so that a NaN
    cell in one task's average cannot poison the aggregate (the previous
    np.max would return NaN there, inconsistent with the nanmean above).

    Arguments:
    -----------
    results_dict : dict
        Maps a (task1, task2) tuple to a dict of
        {subject_pair: similarity portrait (np.ndarray)}.

    Returns:
    --------
    np.ndarray
        Element-wise max over the per-task mean portraits.
    """
    per_task_means = []
    for subject_portraits in results_dict.values():
        # Average across all subject pairs for this low-level task,
        # ignoring NaN entries from missing data.
        per_task_means.append(np.nanmean(list(subject_portraits.values()),
                                         axis=0))

    # Max across tasks: a cell is as similar as the most similar task pair.
    return np.nanmax(per_task_means, axis=0)
700
+
701
+ def load_similarity_portrait(sensor_config: list[str],
702
+ time_window: list[int]=[0], # Default [0] might need to be [1] to match dashboard
703
+ task1_name: tuple=(None, None, None),
704
+ task2_name: tuple=(None, None, None),
705
+ biomechanical_difference: bool = False):
706
+ """
707
+ Load the similarity pickle file and calculate the similarity portrait for
708
+ a given sensor configuration, time window, and task names.
709
+
710
+ Arguments:
711
+ -----------
712
+ sensor_config : list[str]
713
+ The sensor configuration to use.
714
+ time_window : list[int]
715
+ The time window to use.
716
+ task1_name : tuple
717
+ Task name, incline, and speed for the first task. Use None for any value.
718
+ task2_name : tuple
719
+ Task name, incline, and speed for the second task. Use None for any value.
720
+
721
+ Returns:
722
+ --------
723
+ np.ndarray
724
+ """
725
+
726
+ # Get the file name
727
+ file_name = f'results_dict_{sorted(sensor_config)}_time_window_{time_window}'
728
+ if biomechanical_difference:
729
+ file_name += f'_biomechanical_difference_{biomechanical_difference}'
730
+ directory = 'saved_similarity_portraits'
731
+
732
+ # Load the results dict
733
+ with open(os.path.join(directory, f'{file_name}.pkl'), 'rb') as f:
734
+ results_dict = pickle.load(f)
735
+
736
+ # Check if we want to to a lower-levl task pair
737
+ if task1_name != (None, None, None) and task2_name != (None, None, None):
738
+ # Get the similarity portrait
739
+ try:
740
+ subject_portraits = results_dict[(task1_name, task2_name)]
741
+ except KeyError:
742
+ print(f"KeyError: {(task1_name, task2_name)} not found in results dict")
743
+ print(f"Results dict keys: {results_dict.keys()}")
744
+ raise KeyError(f"{(task1_name, task2_name)} not found in results dict")
745
+
746
+ similarity_portrait = np.mean(list(subject_portraits.values()), axis=0)
747
+
748
+ # We want to do the high level aggregation
749
+ else:
750
+ # Get the similarity portrait
751
+ similarity_portrait = low_level_tasks_aggregation(results_dict)
752
+
753
+ if similarity_portrait is None:
754
+ pass
755
+
756
+ return similarity_portrait
757
+
758
+
759
+ def load_results_dict(sensor_config: list[str],
760
+ time_window: list[int]=[0]): # Default [0] might need to be [1]
761
+ """
762
+ Load the results dict for a given sensor configuration, time window, and task names.
763
+ """
764
+ # Get the file name
765
+ file_name = f'results_dict_{sorted(sensor_config)}_time_window_{time_window}'
766
+ directory = 'saved_similarity_portraits'
767
+
768
+ # Load the results dict
769
+ with open(os.path.join(directory, f'{file_name}.pkl'), 'rb') as f:
770
+ results_dict = pickle.load(f)
771
+
772
+ return results_dict
773
+
774
+ def load_input_output_similarity_portrait(sensor_config: list[str],
775
+ output_config: list[str],
776
+ time_window: list[int]=[0], # Default [0] might need to be [1]
777
+ task1_name: tuple=(None, None, None),
778
+ task2_name: tuple=(None, None, None),
779
+ biomechanical_difference: bool = False):
780
+ """
781
+ Load the similarity pickle file and calculate the similarity portrait for
782
+ a given sensor configuration, time window, and task names.
783
+ """
784
+
785
+ # Load the input and output pickle files
786
+ directory = 'saved_similarity_portraits'
787
+ input_file_name = f'results_dict_{sorted(sensor_config)}_time_window_{time_window}'
788
+ output_file_name = f'results_dict_{sorted(output_config)}_time_window_{[0]}'
789
+ if biomechanical_difference:
790
+ output_file_name += f'_biomechanical_difference_{biomechanical_difference}'
791
+
792
+ input_results_dict = None
793
+ output_results_dict = None
794
+
795
+ # Load the input pickle file with error handling
796
+ try:
797
+ with open(os.path.join(directory, f'{input_file_name}.pkl'), 'rb') as f:
798
+ input_results_dict = pickle.load(f)
799
+ except FileNotFoundError:
800
+ print(f"Warning: Input similarity file not found: {input_file_name}.pkl")
801
+ # Return placeholder if input file is crucial and missing
802
+ # return np.zeros((150, 150))
803
+
804
+ # Load the output pickle file with error handling
805
+ try:
806
+ with open(os.path.join(directory, f'{output_file_name}.pkl'), 'rb') as f:
807
+ output_results_dict = pickle.load(f)
808
+ except FileNotFoundError:
809
+ print(f"Warning: Output similarity file not found: {output_file_name}.pkl")
810
+ # Decide if placeholder is needed if output file is missing
811
+ # If both are needed, might return placeholder here
812
+
813
+ # If either dict failed to load, return placeholder
814
+ if input_results_dict is None or output_results_dict is None:
815
+ print(f"Warning: Could not load required similarity data. Returning placeholder zeros.")
816
+ return np.zeros((150, 150))
817
+
818
+ # If the task1_name and task2_name are not None, then we want to do a
819
+ # lower-level task pair
820
+ try:
821
+ if task1_name != (None, None, None) and task2_name != (None, None, None):
822
+ target_key = (task1_name, task2_name)
823
+ input_results_dict = {target_key: input_results_dict[target_key]}
824
+ output_results_dict = {target_key: output_results_dict[target_key]}
825
+ except KeyError:
826
+ print(f"KeyError: {(task1_name, task2_name)} not found in input or output results dicts")
827
+ return None
828
+
829
+ # Calculate the input and output similarity portrait
830
+ input_output_similarity_portrait = input_output_task_aggregation(
831
+ input_results_dict, output_results_dict
832
+ )
833
+
834
+ return input_output_similarity_portrait
835
+
836
+ # Add back input_output_task_aggregation
837
def input_output_task_aggregation(
        input_results_dict: dict[tuple, dict[tuple, np.ndarray]],
        output_results_dict: dict[tuple, dict[tuple, np.ndarray]],
) -> np.ndarray:
    """
    Combine input-similarity and output-difference portraits into a single
    conflict portrait.

    For every low-level task pair present in BOTH dicts, and every subject
    pair present in both per-task dicts, the input and output portraits are
    multiplied element-wise (high only where inputs look alike AND outputs
    differ). Products are averaged across subject pairs (nanmean), and the
    final portrait is the element-wise nanmax across task pairs.

    Keys or subject pairs present in only one dict are silently ignored.
    The previously dead key-mismatch bookkeeping (statistics that were
    computed but only ever used by commented-out prints) has been removed;
    behavior is unchanged.

    Returns:
    --------
    np.ndarray
        The aggregated conflict portrait, or np.zeros((150, 150)) when
        there is nothing valid to aggregate.
    """
    # Only task pairs present in both dicts can be combined.
    common_keys = set(input_results_dict.keys()) & set(output_results_dict.keys())
    if not common_keys:
        return np.zeros((150, 150))

    per_task_portraits = []
    for low_level_task in common_keys:
        input_subjects = input_results_dict[low_level_task]
        output_subjects = output_results_dict[low_level_task]
        if not input_subjects or not output_subjects:
            # No subject data on one side; skip this task pair.
            continue

        # Only subject pairs present in both dicts for this task.
        common_subject_pairs = set(input_subjects.keys()) & \
                               set(output_subjects.keys())
        if not common_subject_pairs:
            continue

        # Element-wise product of input and output portraits per subject.
        subject_products = []
        for subject_pair in common_subject_pairs:
            input_portrait = input_subjects.get(subject_pair)
            output_portrait = output_subjects.get(subject_pair)
            if input_portrait is None or output_portrait is None:
                continue
            # Ensure numpy arrays for element-wise multiplication.
            input_arr = np.array(input_portrait)
            output_arr = np.array(output_portrait)
            # Shapes must agree for an element-wise product; skip otherwise.
            if input_arr.shape == output_arr.shape:
                subject_products.append(input_arr * output_arr)

        if not subject_products:
            continue

        # Mean of the subject-specific input-output products for this task.
        mean_portrait = np.nanmean(np.array(subject_products), axis=0)
        if np.isnan(mean_portrait).all():
            continue
        per_task_portraits.append(mean_portrait)

    if not per_task_portraits:
        return np.zeros((150, 150))

    # Max across task pairs: a cell conflicts if ANY task pair does there.
    final_portrait = np.nanmax(np.array(per_task_portraits), axis=0)
    if np.isnan(final_portrait).all():
        return np.zeros((150, 150))

    return final_portrait
plot_similarity.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visualization functions for similarity measures"""
2
+ from plot_styling import PLOT_STYLE, PLOT_COLORS, set_plot_style
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import seaborn as sns
6
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
7
+ from typing import Optional, Union
8
+ import matplotlib as mpl
9
+ from matplotlib.colors import LinearSegmentedColormap
10
+
11
def plot_similarity_measure(measure_data: np.ndarray, ax: plt.Axes = None,
                            plot_type: str = 'input', cbar: bool = True,
                            task_x_name: str = '', task_y_name: str = '',
                            xlabel: Optional[str] = None, ylabel: Optional[str] = None,
                            cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
                            title: Optional[str] = None,
                            fontsize: int = 12,
                            y_label_pad: int = -13,
                            cbar_labels: bool = True,
                            cutoff_treshold: float = None,
                            high_level_plot: bool = False):
    """Plot similarity measure with consistent styling.

    Args:
        measure_data: 2D array of similarity measures
        ax: matplotlib axis object (a new figure is created when None)
        plot_type: str, one of 'input', 'output', 'output_biomechanical',
            or 'conflict'
        cbar: bool, whether to show colorbar
        task_x_name: str, name of x-axis task (used if xlabel not provided)
        task_y_name: str, name of y-axis task (used if ylabel not provided)
        xlabel: str, optional custom x-axis label
        ylabel: str, optional custom y-axis label
        cmap: str or matplotlib colormap, optional custom colormap
        title: str, optional custom title
        fontsize: int, font size for all text elements
        y_label_pad: int, extra padding applied to the y-axis label
        cbar_labels: bool, whether to show colorbar tick labels
        cutoff_treshold: float, optional cutoff threshold that counts as a
            conflict. When not None (and the colorbar is shown), the percent
            of values above the threshold is drawn next to the colorbar.
        high_level_plot: bool, whether to exclude the diagonal from the
            threshold count (the diagonal is trivially 1 when a task is
            compared to itself).

    Returns:
        The Axes returned by seaborn's heatmap.
    """
    if ax is None:
        fig, ax = plt.subplots()

    # Set color based on plot type
    if cmap is None:
        if plot_type == 'input':
            cmap = PLOT_COLORS['input_similarity']
        elif plot_type == 'output':
            cmap = PLOT_COLORS['output_difference']
        elif plot_type == 'output_biomechanical':
            cmap = PLOT_COLORS['output_biomechanical']
        else:  # conflict
            cmap = PLOT_COLORS['conflict']

    # Create heatmap with flipped y-axis so 0% gait cycle sits at the bottom
    hm = sns.heatmap(np.flipud(measure_data), vmin=0, vmax=1, ax=ax, cmap=cmap,
                     cbar=False, rasterized=False)

    # Configure axes (assumes 150x150 portraits: ticks at 0 and 149)
    ax.set_xticks([0, 149])
    ax.set_xticklabels(['0%', '100%'], fontsize=fontsize)
    ax.set_yticks([0, 149])
    ax.set_yticklabels(['100%', '0%'], fontsize=fontsize)  # Flipped y-axis labels

    # Set labels with consistent padding
    if xlabel is not None:
        ax.set_xlabel(f'Gait Cycle (%)\n{xlabel}', labelpad=PLOT_STYLE['label_pad_x']-3, fontsize=fontsize)
    elif task_x_name:
        ax.set_xlabel(f'Gait Cycle (%)\n{task_x_name}', labelpad=PLOT_STYLE['label_pad_x']-3, fontsize=fontsize)
    else:
        ax.set_xlabel('Gait Cycle (%)', labelpad=PLOT_STYLE['label_pad_x']-3, fontsize=fontsize)

    if ylabel is not None:
        ax.set_ylabel(f'{ylabel}\nGait Cycle (%)', labelpad=PLOT_STYLE['label_pad_y']-13+y_label_pad, fontsize=fontsize)
    elif task_y_name:
        ax.set_ylabel(f'{task_y_name}\nGait Cycle (%)', labelpad=PLOT_STYLE['label_pad_y']-13+y_label_pad, fontsize=fontsize)
    else:
        ax.set_ylabel('Gait Cycle (%)', labelpad=PLOT_STYLE['label_pad_y']-13+y_label_pad, fontsize=fontsize)

    # Configure ticks
    ax.tick_params(axis='both', which='both',
                   length=PLOT_STYLE['tick_length'],
                   width=PLOT_STYLE['tick_width'],
                   pad=PLOT_STYLE['tick_pad'],
                   labelsize=fontsize)

    # Ensure x-axis labels are horizontal
    ax.tick_params(axis='x', rotation=0)

    # Set the title
    if title is not None:
        ax.set_title(title, pad=2, fontsize=fontsize)

    if cbar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cbar_obj = plt.colorbar(hm.collections[0], cax=cax)
        cbar_obj.outline.set_visible(False)
        cbar_obj.ax.yaxis.set_ticks_position('right')
        cbar_obj.ax.tick_params(length=0, labelsize=fontsize)
        if cbar_labels:
            cbar_obj.set_ticks([0, 1])
        else:
            cbar_obj.set_ticks([])

    # Cutoff threshold annotation. This text is anchored to the colorbar's
    # axes, so it can only be drawn when the colorbar exists; previously
    # this branch referenced cbar_obj even with cbar=False, which raised a
    # NameError.
    if cutoff_treshold is not None and cbar:

        # The actual cutoff depends on the plot type (output measures are
        # differences, so their cutoff is the complement).
        if plot_type == 'input':
            true_cutoff = cutoff_treshold
        elif plot_type == 'output':
            true_cutoff = 1 - cutoff_treshold
        elif plot_type == 'output_biomechanical':
            true_cutoff = 1 - cutoff_treshold
        else:  # conflict
            true_cutoff = cutoff_treshold * (1 - cutoff_treshold)

        # Count percent of values above threshold
        if high_level_plot:
            # Exclude the diagonal: same-task comparisons are trivially 1.
            mask = np.ones_like(measure_data, dtype=bool)
            np.fill_diagonal(mask, False)
            total = np.sum(mask)
            above = np.sum((measure_data > true_cutoff) & mask)
        else:
            total = measure_data.size
            above = np.sum(measure_data > true_cutoff)
        percent = 100.0 * above / total if total > 0 else 0.0
        # Format as e.g. "12.3% > 0.8"
        annotation = f"{percent:.1f}% > {true_cutoff:g}"
        # Place annotation to the right of the colorbar
        cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
                         fontsize=fontsize, rotation=90, transform=cbar_obj.ax.transAxes)

    # Set aspect ratio to equal
    ax.set_aspect('equal')

    return hm
147
+
148
+ # Create custom colormaps with transparency for low values
149
def create_transparent_cmap(colors):
    """Build a colormap that fades in from fully transparent white.

    Args:
        colors: sequence of '#RRGGBB' hex strings, ordered light to dark.

    Returns:
        A LinearSegmentedColormap whose first stop is transparent white and
        whose remaining stops are the given colors at full opacity.
    """
    # Start from transparent white so low values vanish on any background.
    stops = [(1, 1, 1, 0)]
    for hex_color in colors:
        # Parse the RR, GG, BB byte pairs and normalize to [0, 1].
        red, green, blue = (int(hex_color[k:k + 2], 16) / 255 for k in (1, 3, 5))
        stops.append((red, green, blue, 1.0))
    return LinearSegmentedColormap.from_list('custom_cmap', stops)
158
+
159
# Define color sequences (light-to-dark sequential palettes as hex RGB)
blue_colors = ['#F7FBFF', '#DEEBF7', '#C6DBEF', '#9ECAE1', '#6BAED6', '#4292C6', '#2171B5']
orange_colors = ['#FFF5EB', '#FEE6CE', '#FDD0A2', '#FDAE6B', '#FD8D3C', '#F16913', '#D94801']
green_colors = ['#F7FCF5', '#E5F5E0', '#C7E9C0', '#A1D99B', '#74C476', '#41AB5D', '#238B45']

sim_cmap = create_transparent_cmap(blue_colors)  # Blue sequential (similarity)
diff_cmap = create_transparent_cmap(orange_colors)  # Orange sequential (difference)
conflict_cmap = create_transparent_cmap(green_colors)  # Green sequential (conflict)
plot_styling.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for consistent plot styling across all visualizations"""
2
+
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ import matplotlib.font_manager as fm
6
+
7
# Add Arial font. The hard-coded msttcorefonts path only exists on systems
# with the MS core fonts installed; a missing file must not make the whole
# module unimportable, so fall back to matplotlib's default font instead.
try:
    fm.fontManager.addfont('/usr/share/fonts/truetype/msttcorefonts/Arial.ttf')
except FileNotFoundError:
    print("Warning: Arial font not found; using matplotlib's default font.")

# Plot styling constants shared by every figure in the project.
PLOT_STYLE = {
    'font_family': 'Arial',
    'font_size': 18,
    'title_size': 20,
    'label_size': 18,
    'tick_size': 15,
    'tick_length': 5,
    'tick_width': 0.5,
    'tick_pad': 2,
    'label_pad_x': -15,
    'label_pad_y': -35,
    'figure_dpi': 300,
    'aspect_ratio': 'equal',
    'subplot_wspace': 0.05,
    'subplot_hspace': 0.1
}

# Color constants: one colormap per similarity-measure type.
PLOT_COLORS = {
    'input_similarity': sns.color_palette('rocket', as_cmap=True),
    'output_difference': sns.cubehelix_palette(start=.2, rot=-.3, dark=0, light=0.85,
                                               reverse=True, as_cmap=True),
    'conflict': sns.cubehelix_palette(start=2, rot=0, dark=0, light=0.85,
                                      reverse=True, as_cmap=True),
    # Cubehelix palette starting dark (dark=0) and ending light purple
    # (light=0.85, start=2.8).
    'output_biomechanical': sns.cubehelix_palette(start=2.8, rot=0.4, dark=0, light=0.85,
                                                  reverse=True, as_cmap=True)
}


purple_helix = sns.cubehelix_palette(start=.2, rot=-.4, dark=0, light=0.85,
                                     reverse=True, as_cmap=True)
my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
                                        reverse=True, as_cmap=True)
45
+
46
def set_plot_style():
    """Apply the shared PLOT_STYLE settings to matplotlib's rcParams."""
    # Map each rcParams key to the PLOT_STYLE entry that feeds it.
    rc_mapping = {
        'font.family': 'font_family',
        'font.size': 'font_size',
        'axes.labelsize': 'label_size',
        'axes.titlesize': 'title_size',
        'xtick.labelsize': 'tick_size',
        'ytick.labelsize': 'tick_size',
        'figure.dpi': 'figure_dpi',
        'figure.subplot.wspace': 'subplot_wspace',
        'figure.subplot.hspace': 'subplot_hspace',
    }
    for rc_name, style_key in rc_mapping.items():
        plt.rcParams[rc_name] = PLOT_STYLE[style_key]
requirements.txt DELETED
@@ -1,3 +0,0 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
sensor_illustration.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from typing import List
5
+ from dataclasses import dataclass
6
+ from plot_styling import set_plot_style
7
+
8
@dataclass
class JointPosition:
    """Stores the position of a joint in the illustration"""
    # Joint identifier, e.g. 'hip_r' or 'ankle_l'.
    name: str
    # Normalized figure coordinates (fractions of the axes, 0-1 range based
    # on the values assigned in LegIllustration.__init__).
    x: float
    y: float
14
+
15
@dataclass
class Segment:
    """Stores the segment information connecting two joints"""
    # Segment name, e.g. 'thigh', 'shank', 'foot'.
    name: str
    # Base joint names (without the '_r'/'_l' side suffix) that the
    # segment connects.
    joint1: str
    joint2: str
21
+
22
+ class LegIllustration:
23
    def __init__(self, mode: str = "level_walking"):
        """Initialize the leg illustration with a background image based on the specified mode

        Args:
            mode: A string specifying the background mode. Options are:
                "level_walking" (default),
                "level_walking_line_under", and
                "decline_walking".
        """
        # Set plot style at initialization
        set_plot_style()
        self.mode = mode # store mode if needed later

        # Map modes to background images
        background_images = {
            "level_walking": "assets/level_ground_no_line_under.png",
            "level_walking_line_under": "assets/level_ground_line_under.png",
            "decline_walking": "assets/downhill_line_under.png"
        }
        if mode in background_images:
            self.background_image_path = background_images[mode]
        else:
            print(f"Warning: Unknown mode '{mode}'. Defaulting to level_walking.")
            self.background_image_path = background_images["level_walking"]

        # Set joint positions based on mode. Coordinates are normalized
        # axes fractions hand-tuned to align with each background image.
        if mode in ["level_walking"]:
            self.image_vertical_offset = 0.02 # Space at bottom for gait cycle bar
            self.joints_right = {
                'hip_r': JointPosition('hip_r', 0.45, 0.825 + self.image_vertical_offset),
                'knee_r': JointPosition('knee_r', 0.65, 0.46 + self.image_vertical_offset),
                'ankle_r': JointPosition('ankle_r', 0.77, 0.075 + self.image_vertical_offset),
                'toe_r': JointPosition('toe_r', 0.92, 0.057 + self.image_vertical_offset),
            }

            self.joints_left = {
                'hip_l': JointPosition('hip_l', 0.35, 0.825 + self.image_vertical_offset),
                'knee_l': JointPosition('knee_l', 0.235, 0.42 + self.image_vertical_offset),
                'ankle_l': JointPosition('ankle_l', 0.06, 0.12 + self.image_vertical_offset),
                'toe_l': JointPosition('toe_l', 0.17, 0.03 + self.image_vertical_offset),
            }

        elif mode == "level_walking_line_under":
            self.image_vertical_offset = 0.01 # Slightly lower offset for this background
            self.joints_right = {
                'hip_r': JointPosition('hip_r', 0.48, 0.82 + self.image_vertical_offset),
                'knee_r': JointPosition('knee_r', 0.62, 0.45 + self.image_vertical_offset),
                'ankle_r': JointPosition('ankle_r', 0.695, 0.08 + self.image_vertical_offset),
                'toe_r': JointPosition('toe_r', 0.80, 0.075 + self.image_vertical_offset),
            }

            self.joints_left = {
                'hip_l': JointPosition('hip_l', 0.38, 0.82 + self.image_vertical_offset),
                'knee_l': JointPosition('knee_l', 0.28, 0.39 + self.image_vertical_offset),
                'ankle_l': JointPosition('ankle_l', 0.18, 0.12 + self.image_vertical_offset),
                'toe_l': JointPosition('toe_l', 0.26, 0.04 + self.image_vertical_offset),
            }

        elif mode == "decline_walking":
            self.image_vertical_offset = 0.01 # Slightly lower offset for decline walking
            self.joints_right = {
                'hip_r': JointPosition('hip_r', 0.63, 0.84 + self.image_vertical_offset),
                'knee_r': JointPosition('knee_r', 0.72, 0.50 + self.image_vertical_offset),
                'ankle_r': JointPosition('ankle_r', 0.795, 0.14 + self.image_vertical_offset),
                'toe_r': JointPosition('toe_r', 0.93, 0.115 + self.image_vertical_offset),
            }

            self.joints_left = {
                'hip_l': JointPosition('hip_l', 0.53, 0.84 + self.image_vertical_offset),
                'knee_l': JointPosition('knee_l', 0.48, 0.55 + self.image_vertical_offset),
                'ankle_l': JointPosition('ankle_l', 0.19, 0.36 + self.image_vertical_offset),
                'toe_l': JointPosition('toe_l', 0.30, 0.27 + self.image_vertical_offset),
            }
        else:
            # Fallback to level_walking settings for unknown mode values
            # (unreachable in practice: unknown modes were remapped above)
            self.image_vertical_offset = 0.09
            self.joints_right = {
                'hip_r': JointPosition('hip_r', 0.47, 0.825 + self.image_vertical_offset),
                'knee_r': JointPosition('knee_r', 0.67, 0.46 + self.image_vertical_offset),
                'ankle_r': JointPosition('ankle_r', 0.79, 0.075 + self.image_vertical_offset),
                'toe_r': JointPosition('toe_r', 0.94, 0.057 + self.image_vertical_offset),
            }
            self.joints_left = {
                'hip_l': JointPosition('hip_l', 0.37, 0.825 + self.image_vertical_offset),
                'knee_l': JointPosition('knee_l', 0.255, 0.42 + self.image_vertical_offset),
                'ankle_l': JointPosition('ankle_l', 0.08, 0.12 + self.image_vertical_offset),
                'toe_l': JointPosition('toe_l', 0.19, 0.03 + self.image_vertical_offset),
            }

        # Define segments (base joint names; side suffix resolved later)
        self.segments = [
            Segment('thigh', 'hip', 'knee'),
            Segment('shank', 'knee', 'ankle'),
            Segment('foot', 'ankle', 'toe')
        ]

        # Add CoP and vGRF parameters
        self.cop_offset = -0.07 # Offset from foot line
        self.vgrf_arrow_height = 0.15 # Height of the force arrow
        self.vgrf_arrow_offset = 0.07 # Horizontal offset from foot

        # Add color mapping for different sensor types
        self.color_map = {
            'angle': 'yellow',
            'velocity': 'blue',
            'angle_velocity': 'purple',
            'torque': 'orange'
        }
131
+
132
+ def _get_joint_name(self, base_name: str, joints: dict) -> str:
133
+ """
134
+ Helper method to get the correct joint name with suffix
135
+ based on the joints dictionary
136
+ """
137
+ suffix = '_r' if 'hip_r' in joints else '_l'
138
+ return f"{base_name}{suffix}"
139
+
140
    def draw_illustration(self,
                          highlighted_elements: List[str] = None,
                          title: str = None,
                          gait_cycle_sections: List[int] | int = None,
                          ax: plt.Axes = None,
                          scale_factor: float = 1.0) -> plt.Figure:
        """Draw the leg illustration with highlighted elements

        Args:
            highlighted_elements: List of elements to highlight
            title: Title for the figure
            gait_cycle_sections: List of gait cycle sections to highlight
            ax: Matplotlib axes to draw on. If None, creates new figure.
            scale_factor: Scale applied to the default 3x6 figure size.

        Returns:
            The new Figure when one was created here, otherwise None.
        """
        # Track whether we created a new figure
        created_new_figure = False

        if ax is None:
            # Use consistent figure size and make background transparent
            fig = plt.figure(figsize=(3 * scale_factor, 6 * scale_factor), facecolor='none')
            ax = fig.add_subplot(111)
            created_new_figure = True
        else:
            fig = ax.figure

        ax.set_facecolor('none') # Make axis background transparent

        # Draw leg illustration with adjusted vertical extent
        # Calculate padding based on scale_factor
        padding = 0 # 0.001 * scale_factor # Increased padding
        try:
            img = plt.imread(self.background_image_path)

            # Keep original image positioning to maintain alignment with joint markers
            ax.imshow(
                img,
                extent=[-padding, 1 + padding, # expanded x-axis
                        self.image_vertical_offset - padding*2, # original y-axis bottom
                        1 + padding*2 + self.image_vertical_offset] # expanded y-axis top
            )

            # Only extend the bottom of the axis limits to make room for phase window % if needed
            bottom_limit = self.image_vertical_offset - padding*2
            if gait_cycle_sections is not None and (
                (isinstance(gait_cycle_sections, list) and len(gait_cycle_sections) > 1) or
                (isinstance(gait_cycle_sections, np.ndarray) and len(gait_cycle_sections) > 1)
            ):
                bottom_limit = -0.08 # Lower bottom limit for phase window %

            # Update the axis limits
            ax.set_xlim(-padding, 1 + padding)
            ax.set_ylim(bottom_limit, 1 + padding*2 + self.image_vertical_offset)
        except Exception as e:
            # Drawing still proceeds without a background image.
            print(f"Warning: Could not load background image: {e}")
            # Also adjust fallback y-limits if loading fails
            if gait_cycle_sections is not None and (
                (isinstance(gait_cycle_sections, list) and len(gait_cycle_sections) > 1) or
                (isinstance(gait_cycle_sections, np.ndarray) and len(gait_cycle_sections) > 1)
            ):
                bottom_limit = -0.08
            else:
                bottom_limit = -padding*2
            ax.set_xlim(-padding, 1 + padding)
            ax.set_ylim(bottom_limit, 1 + padding*2)

        # Draw legs
        self._draw_leg(ax, self.joints_right, self.segments,
                       highlighted_elements, scale_factor)
        self._draw_leg(ax, self.joints_left, self.segments,
                       highlighted_elements, scale_factor)

        if title:
            ax.set_title(title, pad=20)

        # Draw phase window % if more than one section is provided
        if gait_cycle_sections is not None:
            # Convert to list if int
            if isinstance(gait_cycle_sections, int):
                sections = [gait_cycle_sections]
            else:
                sections = gait_cycle_sections
            if len(sections) > 1:
                self._draw_phase_window_percent(ax, sections, scale_factor)

        ax.axis('off')

        # Only apply figure-level settings if we created the figure
        if created_new_figure:
            fig.patch.set_alpha(0.0) # Make figure background transparent
            fig.tight_layout()
            return fig
        return None
232
+
233
+ def _draw_phase_window_percent(self, ax, gait_cycle_sections: List[int], scale_factor: float = 1.0):
234
+ """Draw the phase window percent as a number under the legs"""
235
+ # The gait cycle is 150 points (0-149)
236
+ n_sections = len(gait_cycle_sections)
237
+ if n_sections > 1:
238
+ percent = n_sections / 150 * 100
239
+ # Place the text under the legs, centered
240
+ y_position = -0.04 # Slightly below the image
241
+ fontsize = 32 * scale_factor
242
+ ax.text(0.5, y_position, f"Phase window: {percent:.1f}%", ha='center', va='top', fontsize=fontsize, color='black')
243
+
244
+ def _determine_highlight_color(self, elements: List[str], location: str) -> str:
245
+ """
246
+ Determine highlight color based on the types of sensors present for a specific location
247
+
248
+ Args:
249
+ elements: List of sensor names to analyze
250
+ location: The joint or segment name to check (e.g., 'hip', 'thigh')
251
+
252
+ Returns:
253
+ str: Color name from color_map
254
+ """
255
+ if not elements:
256
+ return 'gray'
257
+
258
+ # Filter elements to only those containing this location
259
+ location_elements = [e for e in elements if location in e]
260
+ if not location_elements:
261
+ return 'gray'
262
+
263
+ has_torque = any('torque' in e for e in location_elements)
264
+ has_angle = any('angle' in e and 'vel' not in e for e in location_elements)
265
+ has_velocity = any('vel' in e for e in location_elements)
266
+
267
+ if has_torque:
268
+ return self.color_map['torque']
269
+ elif has_angle and has_velocity:
270
+ return self.color_map['angle_velocity']
271
+ elif has_velocity:
272
+ return self.color_map['velocity']
273
+ elif has_angle:
274
+ return self.color_map['angle']
275
+
276
+ return 'gray'
277
+
278
def _draw_leg(self, ax, joints: dict, segments: List[Segment],
              highlighted_elements: List[str],
              scale_factor: float = 1.0):
    """Helper method to draw a single leg with highlights.

    Args:
        ax: Axes to draw into.
        joints: Mapping of suffixed joint names ('hip_r', ...) to joint
            objects exposing `.x` / `.y` coordinates.
        segments: Segments (thigh/shank/foot) linking pairs of joints.
        highlighted_elements: Sensor names; only those ending with this
            leg's suffix are considered for coloring.
        scale_factor: Scales joint marker size and segment line width.
    """
    # Which side is this leg? Mirrors _get_joint_name's convention.
    suffix = '_r' if 'hip_r' in joints else '_l'

    # Get all highlighted elements for this leg
    leg_elements = [e for e in (highlighted_elements or [])
                    if e.endswith(suffix)]

    # Calculate figure-relative sizes
    fig_width, fig_height = ax.figure.get_size_inches()
    # Scale marker size. Make it be a 12 marker size on a 3x6 figure
    joint_size = int(np.sqrt(fig_width * fig_height) / 18 * 12 * 3.1 * scale_factor)
    # Scale line width. Make it be a 3 line width on a 3x6 figure
    segment_width = int(np.sqrt(fig_width * fig_height) / 18 * 3 * 3.1 * scale_factor)

    # joint_size = 12
    # segment_width = 3

    # Draw segments
    for segment in segments:
        # Get the correct joint names with suffixes
        joint1_name = self._get_joint_name(segment.joint1, joints)
        joint2_name = self._get_joint_name(segment.joint2, joints)

        joint1 = joints[joint1_name]
        joint2 = joints[joint2_name]
        x = [joint1.x, joint2.x]
        y = [joint1.y, joint2.y]

        # Get the segment name with suffix and check if it's highlighted
        base_segment = segment.name  # 'thigh', 'shank', or 'foot'
        color = self._determine_highlight_color(leg_elements, base_segment)
        ax.plot(x, y, '-', color=color, linewidth=segment_width, alpha=0.8)

    # Draw joints (excluding toe)
    for joint_name, joint in joints.items():
        if not joint_name.startswith('toe'):
            # Check if any sensor for this joint is highlighted
            base_joint = joint_name.split('_')[0]  # 'hip' from 'hip_r'
            color = self._determine_highlight_color(leg_elements, base_joint)
            # Unhighlighted joints are drawn black rather than gray so they
            # stand out against gray segments.
            ax.plot(joint.x, joint.y, 'o',
                    color=color if color != 'gray' else 'black',
                    markersize=joint_size)

    # Draw CoP and vGRF only if mode doesn't have line under
    if self.background_image_path.endswith('no_line_under.png'):
        cop_name = f"cop_y{suffix}"
        cop_h = highlighted_elements and cop_name in highlighted_elements

        grf_name = f"grf_y{suffix}"
        grf_h = highlighted_elements and grf_name in highlighted_elements

        self._draw_cop_and_vgrf(
            ax, joints, cop_highlighted=cop_h, grf_highlighted=grf_h
        )
335
+
336
def _draw_cop_and_vgrf(self, ax, joints,
                       cop_highlighted: bool, grf_highlighted: bool):
    """
    Draw the center of pressure line and
    vertical ground reaction force.

    The CoP is a dashed line parallel to the foot (ankle->toe), offset
    perpendicular to it by `self.cop_offset`; the vGRF is a vertical arrow
    beside the toe. Each is red when highlighted, gray otherwise.
    """
    ankle_name = self._get_joint_name('ankle', joints)
    toe_name = self._get_joint_name('toe', joints)
    ankle = joints[ankle_name]
    toe = joints[toe_name]

    # Calculate CoP line parallel to foot
    dx = toe.x - ankle.x
    dy = toe.y - ankle.y
    length = np.sqrt(dx**2 + dy**2)

    # Unit vector along the foot; assumes ankle and toe never coincide
    # (length > 0) — TODO confirm joint layouts guarantee this.
    dx_norm = dx / length
    dy_norm = dy / length

    # Calculate perpendicular offset for CoP line
    offset_x = -dy_norm * self.cop_offset
    offset_y = dx_norm * self.cop_offset

    # Draw CoP line
    cop_x = [ankle.x + offset_x, toe.x + offset_x]
    cop_y = [ankle.y + offset_y, toe.y + offset_y]
    cop_color = 'red' if cop_highlighted else 'gray'
    ax.plot(cop_x, cop_y, '--', color=cop_color, linewidth=2, alpha=0.8)

    # Draw vGRF arrow (purely vertical: dx=0, dy=self.vgrf_arrow_height)
    arrow_x = toe.x + self.vgrf_arrow_offset
    arrow_base_y = toe.y - 0.02
    grf_color = 'red' if grf_highlighted else 'gray'
    ax.arrow(arrow_x, arrow_base_y,
             0, self.vgrf_arrow_height,
             head_width=0.02,
             head_length=0.04,
             fc=grf_color, ec=grf_color,
             alpha=0.8)
375
+
376
def create_sensor_config_illustrations() -> None:
    """Create and display multiple sensor configurations with different background modes.

    Each entry in ``sensor_configs`` describes one demo figure: which
    sensors to highlight, which gait-cycle samples form the phase window,
    and which background mode the illustration should use. A fresh
    LegIllustration is built per entry because the mode fixes the
    background image at construction time.
    """
    sensor_configs = [
        {
            'name': 'Hip Angle Only (Level Walking)',
            'highlighted_elements': ['hip_angle_s_r', 'grf_y_r'],
            'gait_cycle_sections': [0],
            'mode': 'level_walking'
        },
        {
            'name': 'Thigh Angle Only (Level Walking with Line Under)',
            'highlighted_elements': ['thigh_angle_s_r'],
            'gait_cycle_sections': [25, 26, 27],
            'mode': 'level_walking_line_under'
        },
        {
            'name': 'Hip Joint + Thigh Segment (Level Walking)',
            'highlighted_elements': ['hip_angle_s_r', 'thigh_vel_s_r', 'grf_y_r'],
            'gait_cycle_sections': [50, 51, 52],
            'mode': 'level_walking'
        },
        {
            'name': 'Multi Sensor (Decline Walking)',
            'highlighted_elements': [
                'hip_angle_s_r',    # Joint sensor
                'thigh_vel_s_r',    # Segment sensor
                'knee_torque_s_r',  # Joint sensor
                'shank_angle_s_r',  # Segment sensor
                'hip_angle_s_l',
                'hip_vel_l',
                'cop_r',
                'grf_y_r'
            ],
            'gait_cycle_sections': [74, 75, 76],
            'mode': 'decline_walking'
        }
    ]

    for config in sensor_configs:
        # Instantiate a new LegIllustration for each sensor config with the specified mode
        illustrator = LegIllustration(mode=config.get('mode', 'level_walking'))
        # The returned Figure handle is not needed here (fix: previously it
        # was bound to an unused local); plt.show() displays it.
        illustrator.draw_illustration(
            highlighted_elements=config['highlighted_elements'],
            title=config['name'],
            gait_cycle_sections=config.get('gait_cycle_sections')
        )
        plt.show()
423
+
424
if __name__ == "__main__":
    # Demo entry point: render and display the example sensor configurations.
    create_sensor_config_illustrations()
# %%
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))