jmontp commited on
Commit
a18421b
·
1 Parent(s): 4dc3a93

Add output portrait cache and streamlined tool sync

Browse files
cached_data/output_portraits/ankle_torque_s_r_raw.pkl.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7051294c00afffe8d35bb76e820cca5eedba1f2025e8876510cf0636339eea0e
3
+ size 255530992
cached_data/output_portraits/hip_torque_s_r_raw.pkl.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa0966e7dce85e2d2554911d0e82c03108019500eeec955cd1e34833d809c794
3
+ size 258544069
cached_data/output_portraits/knee_torque_s_r_raw.pkl.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cae696ab3f55b030038d21240aee12477b606c0ff73481ec6566912b82f5bae
3
+ size 270291469
pages/02_Tool.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  import sys
3
  import io
4
  import logging
 
 
5
  from datetime import datetime
6
  from pathlib import Path
7
 
@@ -26,9 +28,15 @@ if str(ROOT_DIR) not in sys.path:
26
  sys.path.insert(0, str(ROOT_DIR))
27
 
28
  try:
29
- from multivariate_gaussian_overlap import calculate_portrait
 
 
 
30
  except ImportError:
31
- from multivariate_gaussian_overlap import calculate_portrait
 
 
 
32
 
33
  try:
34
  from plot_similarity import plot_similarity_measure
@@ -79,6 +87,142 @@ logger.addHandler(console_handler)
79
  # Configure the page
80
  st.set_page_config(page_title="Task Similarity Analysis Tool", layout="wide")
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # Add authentication (temporarily disabled)
83
  st.session_state.setdefault("authenticated", True)
84
  # if not st.session_state.get("authenticated", False):
@@ -218,7 +362,7 @@ with st.sidebar:
218
  phase_window_percent = st.slider(
219
  "Window Size (%)",
220
  0,
221
- 20,
222
  0,
223
  key="phase_slider",
224
  help="Set how much gait history to include. 0% keeps the portraits focused on the current phase only.",
@@ -474,45 +618,86 @@ if st.session_state.authenticated:
474
  if input_result is None:
475
  st.stop()
476
 
477
- output_result = calculate_overlap_measures(
478
- task1=task1_name_backend[0],
479
- task2=task2_name_backend[0],
480
- sensors=None,
481
- abstraction_level=analysis_detail,
482
- task1_incline=task1_name_backend[1],
483
- task1_speed=task1_name_backend[2],
484
- task2_incline=task2_name_backend[1],
485
- task2_speed=task2_name_backend[2],
486
- time_windows=[1],
487
- output_sensors=[first_output_sensor],
488
- biomechanical_filtering=use_biomechanical_filtering,
489
- return_details=True,
490
  )
491
 
492
- if output_result is None:
493
- st.stop()
494
 
495
- conflict_result = calculate_overlap_measures(
496
- task1=task1_name_backend[0],
497
- task2=task2_name_backend[0],
498
- sensors=selected_sensors,
499
- abstraction_level=analysis_detail,
500
- task1_incline=task1_name_backend[1],
501
- task1_speed=task1_name_backend[2],
502
- task2_incline=task2_name_backend[1],
503
- task2_speed=task2_name_backend[2],
504
- time_windows=phase_windows,
505
- output_sensors=[first_output_sensor],
506
- biomechanical_filtering=use_biomechanical_filtering,
507
- return_details=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  )
509
 
510
- if conflict_result is None:
511
- st.stop()
 
 
 
512
 
513
  input_portrait, input_details = input_result
514
- output_portrait, output_details = output_result
515
- conflict_portrait, conflict_details = conflict_result
 
 
 
 
 
 
 
516
 
517
  st.session_state.analysis_results = {
518
  'input_portrait': input_portrait,
@@ -528,6 +713,7 @@ if st.session_state.authenticated:
528
  'analysis_detail': analysis_detail,
529
  'task1_backend': task1_name_backend,
530
  'task2_backend': task2_name_backend,
 
531
  }
532
  st.success("Updated portraits with the latest configuration.")
533
 
@@ -536,6 +722,10 @@ if st.session_state.authenticated:
536
  if analysis_results:
537
  st.subheader("Similarity Analysis")
538
 
 
 
 
 
539
  display_phase_x = st.session_state.get('tool_phase_x_display', int(round(DEFAULT_PHASE_X / 1.5)))
540
  display_phase_y = st.session_state.get('tool_phase_y_display', int(round(DEFAULT_PHASE_Y / 1.5)))
541
 
 
2
  import sys
3
  import io
4
  import logging
5
+ import pickle
6
+ import gzip
7
  from datetime import datetime
8
  from pathlib import Path
9
 
 
28
  sys.path.insert(0, str(ROOT_DIR))
29
 
30
  try:
31
+ from multivariate_gaussian_overlap import (
32
+ calculate_portrait,
33
+ combine_similarity_difference_to_conflict,
34
+ )
35
  except ImportError:
36
+ from multivariate_gaussian_overlap import (
37
+ calculate_portrait,
38
+ combine_similarity_difference_to_conflict,
39
+ )
40
 
41
  try:
42
  from plot_similarity import plot_similarity_measure
 
87
  # Configure the page
88
  st.set_page_config(page_title="Task Similarity Analysis Tool", layout="wide")
89
 
90
+ # Cache helpers for high-level output portraits
91
+ OUTPUT_CACHE_DIR = ROOT_DIR / "cached_data" / "output_portraits"
92
+
93
+
94
+ def _ensure_output_cache_dir() -> Path:
95
+ OUTPUT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
96
+ return OUTPUT_CACHE_DIR
97
+
98
+
99
+ def _get_output_cache_path(output_sensor: str, biomechanical_filtering: bool) -> Path:
100
+ sanitized = output_sensor.replace(os.sep, "_")
101
+ suffix = "_bio" if biomechanical_filtering else "_raw"
102
+ return _ensure_output_cache_dir() / f"{sanitized}{suffix}.pkl.gz"
103
+
104
+
105
+ def load_high_level_output_portrait(output_sensor: str, biomechanical_filtering: bool):
106
+ if biomechanical_filtering:
107
+ return None
108
+
109
+ cache_path = _get_output_cache_path(output_sensor, biomechanical_filtering)
110
+ if not cache_path.exists():
111
+ return None
112
+
113
+ try:
114
+ with gzip.open(cache_path, "rb") as handle:
115
+ payload = pickle.load(handle)
116
+ except (OSError, pickle.PickleError, EOFError) as exc:
117
+ logger.warning(
118
+ "Failed to read cached portrait %s (%s); regenerating.",
119
+ cache_path.name,
120
+ exc,
121
+ )
122
+ cache_path.unlink(missing_ok=True)
123
+ return None
124
+
125
+ if (
126
+ payload.get("output_sensor") != output_sensor
127
+ or payload.get("biomechanical_filtering") != biomechanical_filtering
128
+ ):
129
+ return None
130
+
131
+ return payload.get("portrait"), payload.get("results_dict")
132
+
133
+
134
+ def save_high_level_output_portrait(
135
+ output_sensor: str,
136
+ biomechanical_filtering: bool,
137
+ portrait: np.ndarray,
138
+ results_dict: dict,
139
+ ):
140
+ if biomechanical_filtering:
141
+ return
142
+
143
+ cache_path = _get_output_cache_path(output_sensor, biomechanical_filtering)
144
+ tmp_path = cache_path.with_name(cache_path.name + ".tmp")
145
+ payload = {
146
+ "output_sensor": output_sensor,
147
+ "biomechanical_filtering": biomechanical_filtering,
148
+ "portrait": portrait,
149
+ "results_dict": results_dict,
150
+ "created_at": datetime.utcnow().isoformat(),
151
+ }
152
+
153
+ with gzip.open(tmp_path, "wb") as handle:
154
+ pickle.dump(payload, handle)
155
+
156
+ tmp_path.replace(cache_path)
157
+
158
+
159
+ def _task_matches_spec(spec: tuple, candidate: tuple) -> bool:
160
+ if spec is None or spec == (None, None, None):
161
+ return True
162
+
163
+ if candidate is None:
164
+ return False
165
+
166
+ name, incline, speed = spec
167
+ cand_name, cand_incline, cand_speed = candidate
168
+
169
+ if name and cand_name != name:
170
+ return False
171
+ if incline is not None and cand_incline != incline:
172
+ return False
173
+ if speed is not None and cand_speed != speed:
174
+ return False
175
+
176
+ return True
177
+
178
+
179
+ def filter_output_results_by_tasks(
180
+ results_dict: dict,
181
+ task1_spec: tuple,
182
+ task2_spec: tuple,
183
+ ):
184
+ if not results_dict:
185
+ return None
186
+
187
+ task_pairs = results_dict.get("task_pairs", [])
188
+ portraits = results_dict.get("portraits", [])
189
+
190
+ if not task_pairs or not portraits:
191
+ return None
192
+
193
+ filtered_pairs = []
194
+ filtered_portraits = []
195
+
196
+ for pair, portrait in zip(task_pairs, portraits):
197
+ t1_candidate, t2_candidate = pair
198
+ if _task_matches_spec(task1_spec, t1_candidate) and _task_matches_spec(task2_spec, t2_candidate):
199
+ filtered_pairs.append(pair)
200
+ filtered_portraits.append(portrait)
201
+
202
+ if not filtered_portraits:
203
+ return None
204
+
205
+ if len(filtered_portraits) == 1:
206
+ aggregated_portrait = filtered_portraits[0]
207
+ else:
208
+ aggregated_portrait = np.nanmax(np.stack(filtered_portraits), axis=0)
209
+
210
+ filtered_dict = {
211
+ "task_pairs": filtered_pairs,
212
+ "portraits": filtered_portraits,
213
+ }
214
+
215
+ for optional_key in (
216
+ "task_histogram",
217
+ "n_matched_pairs",
218
+ "unmatched_input_pairs",
219
+ "unmatched_output_pairs",
220
+ ):
221
+ if optional_key in results_dict:
222
+ filtered_dict[optional_key] = results_dict[optional_key]
223
+
224
+ return aggregated_portrait, filtered_dict
225
+
226
  # Add authentication (temporarily disabled)
227
  st.session_state.setdefault("authenticated", True)
228
  # if not st.session_state.get("authenticated", False):
 
362
  phase_window_percent = st.slider(
363
  "Window Size (%)",
364
  0,
365
+ 99,
366
  0,
367
  key="phase_slider",
368
  help="Set how much gait history to include. 0% keeps the portraits focused on the current phase only.",
 
618
  if input_result is None:
619
  st.stop()
620
 
621
+ cached_output = load_high_level_output_portrait(
622
+ first_output_sensor,
623
+ use_biomechanical_filtering,
 
 
 
 
 
 
 
 
 
 
624
  )
625
 
626
+ use_cache = not use_biomechanical_filtering
 
627
 
628
+ if cached_output is None:
629
+ if use_cache:
630
+ logger.info(
631
+ "Caching high-level output portrait for %s (biomech=%s)",
632
+ first_output_sensor,
633
+ use_biomechanical_filtering,
634
+ )
635
+ else:
636
+ logger.info(
637
+ "Biomechanical filtering enabled; skipping cache for %s",
638
+ first_output_sensor,
639
+ )
640
+
641
+ high_output_result = calculate_overlap_measures(
642
+ task1=None,
643
+ task2=None,
644
+ sensors=None,
645
+ abstraction_level="High Level",
646
+ task1_incline=None,
647
+ task1_speed=None,
648
+ task2_incline=None,
649
+ task2_speed=None,
650
+ time_windows=[1],
651
+ output_sensors=[first_output_sensor],
652
+ biomechanical_filtering=use_biomechanical_filtering,
653
+ return_details=True,
654
+ )
655
+
656
+ if high_output_result is None:
657
+ st.stop()
658
+
659
+ high_output_portrait, high_output_details = high_output_result
660
+ if use_cache:
661
+ save_high_level_output_portrait(
662
+ first_output_sensor,
663
+ use_biomechanical_filtering,
664
+ high_output_portrait,
665
+ high_output_details,
666
+ )
667
+ output_cache_status = "Recomputed output portrait (cache refreshed)."
668
+ else:
669
+ output_cache_status = "Biomechanical filtering active; caching disabled."
670
+ else:
671
+ logger.info(
672
+ "Loaded high-level output portrait from cache for %s (biomech=%s)",
673
+ first_output_sensor,
674
+ use_biomechanical_filtering,
675
+ )
676
+ high_output_portrait, high_output_details = cached_output
677
+ output_cache_status = "Loaded output portrait from cache."
678
+
679
+ filtered_output = filter_output_results_by_tasks(
680
+ high_output_details,
681
+ task1_name_backend,
682
+ task2_name_backend,
683
  )
684
 
685
+ if filtered_output is None:
686
+ output_portrait = high_output_portrait
687
+ output_details = high_output_details
688
+ else:
689
+ output_portrait, output_details = filtered_output
690
 
691
  input_portrait, input_details = input_result
692
+
693
+ try:
694
+ conflict_portrait, conflict_details = combine_similarity_difference_to_conflict(
695
+ input_details,
696
+ output_details,
697
+ )
698
+ except ValueError as exc:
699
+ st.error(f"Failed to combine input and output portraits: {exc}")
700
+ st.stop()
701
 
702
  st.session_state.analysis_results = {
703
  'input_portrait': input_portrait,
 
713
  'analysis_detail': analysis_detail,
714
  'task1_backend': task1_name_backend,
715
  'task2_backend': task2_name_backend,
716
+ 'output_cache_status': output_cache_status,
717
  }
718
  st.success("Updated portraits with the latest configuration.")
719
 
 
722
  if analysis_results:
723
  st.subheader("Similarity Analysis")
724
 
725
+ cache_msg = analysis_results.get('output_cache_status')
726
+ if cache_msg:
727
+ st.caption(cache_msg)
728
+
729
  display_phase_x = st.session_state.get('tool_phase_x_display', int(round(DEFAULT_PHASE_X / 1.5)))
730
  display_phase_y = st.session_state.get('tool_phase_y_display', int(round(DEFAULT_PHASE_Y / 1.5)))
731