Spaces:
Runtime error
Runtime error
First commit
Browse files- Dockerfile +0 -21
- README.md +0 -20
- assets/downhill_line_under.png +3 -0
- assets/level_ground_line_under.png +3 -0
- assets/level_ground_no_line_under.png +3 -0
- cached_data/precalculated_stats.pkl.gz +3 -0
- config.py +42 -0
- dashboard.py +679 -0
- data_tools.py +64 -0
- multivariate_gaussian_overlap.py +947 -0
- plot_similarity.py +166 -0
- plot_styling.py +56 -0
- requirements.txt +0 -3
- sensor_illustration.py +426 -0
- src/streamlit_app.py +0 -40
Dockerfile
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
FROM python:3.9-slim
|
| 2 |
-
|
| 3 |
-
WORKDIR /app
|
| 4 |
-
|
| 5 |
-
RUN apt-get update && apt-get install -y \
|
| 6 |
-
build-essential \
|
| 7 |
-
curl \
|
| 8 |
-
software-properties-common \
|
| 9 |
-
git \
|
| 10 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
-
|
| 12 |
-
COPY requirements.txt ./
|
| 13 |
-
COPY src/ ./src/
|
| 14 |
-
|
| 15 |
-
RUN pip3 install -r requirements.txt
|
| 16 |
-
|
| 17 |
-
EXPOSE 8501
|
| 18 |
-
|
| 19 |
-
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 20 |
-
|
| 21 |
-
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Lower Limb Similarity Analysis
|
| 3 |
-
emoji: 🚀
|
| 4 |
-
colorFrom: red
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: docker
|
| 7 |
-
app_port: 8501
|
| 8 |
-
tags:
|
| 9 |
-
- streamlit
|
| 10 |
-
pinned: false
|
| 11 |
-
short_description: Streamlit template space
|
| 12 |
-
license: mit
|
| 13 |
-
---
|
| 14 |
-
|
| 15 |
-
# Welcome to Streamlit!
|
| 16 |
-
|
| 17 |
-
Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
|
| 18 |
-
|
| 19 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 20 |
-
forums](https://discuss.streamlit.io).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assets/downhill_line_under.png
ADDED
|
Git LFS Details
|
assets/level_ground_line_under.png
ADDED
|
Git LFS Details
|
assets/level_ground_no_line_under.png
ADDED
|
Git LFS Details
|
cached_data/precalculated_stats.pkl.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:030a03b477c02ead69f0b0c83bfab2133d7f8bb9e2f81ab3ae09cb21b27fbd93
|
| 3 |
+
size 5573257
|
config.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration file for dashboard and preprocessing scripts
|
| 2 |
+
|
| 3 |
+
AVAILABLE_SENSORS = [
|
| 4 |
+
'hip_angle_s_r', 'hip_angle_s_l', 'hip_vel_s_r', 'hip_vel_s_l',
|
| 5 |
+
'knee_angle_s_r', 'knee_angle_s_l', 'knee_vel_s_r', 'knee_vel_s_l',
|
| 6 |
+
'ankle_angle_s_r', 'ankle_angle_s_l', 'ankle_vel_s_r', 'ankle_vel_s_l',
|
| 7 |
+
'foot_angle_s_r', 'foot_angle_s_l', 'foot_vel_s_r', 'foot_vel_s_l'
|
| 8 |
+
]
|
| 9 |
+
|
| 10 |
+
AVAILABLE_TASKS = ['decline_walking', 'level_walking', 'incline_walking',
|
| 11 |
+
'stairs', 'sit_to_stand']
|
| 12 |
+
|
| 13 |
+
OUTPUTS_TOTAL = ['hip_torque_s_r', 'knee_torque_s_r', 'ankle_torque_s_r',
|
| 14 |
+
'hip_power_s_r', 'knee_power_s_r', 'ankle_power_s_r']
|
| 15 |
+
|
| 16 |
+
# Not strictly needed by preprocess_data_for_hosting.py but good to keep related constants together.
|
| 17 |
+
ANALYSIS_ABSTRACTION_LEVELS = ['High', 'Medium/Low']
|
| 18 |
+
|
| 19 |
+
# Task configurations for pre-calculation and analysis
|
| 20 |
+
LOW_LEVEL_TASKS = [
|
| 21 |
+
('stair_descent', None, None),
|
| 22 |
+
('stair_ascent', None, None),
|
| 23 |
+
('sit_to_stand', None, None),
|
| 24 |
+
('level_walking', 0.0, 0.8),
|
| 25 |
+
('level_walking', 0.0, 1.0),
|
| 26 |
+
('level_walking', 0.0, 1.2),
|
| 27 |
+
# Gtech variants
|
| 28 |
+
('level_walking', 0.0, 0.6),
|
| 29 |
+
('level_walking', 0.0, 1.8),
|
| 30 |
+
('incline_walking', 5.0, 0.8),
|
| 31 |
+
('incline_walking', 5.0, 1.0),
|
| 32 |
+
('incline_walking', 5.0, 1.2),
|
| 33 |
+
('decline_walking', -5.0, 0.8),
|
| 34 |
+
('decline_walking', -5.0, 1.0),
|
| 35 |
+
('decline_walking', -5.0, 1.2),
|
| 36 |
+
('incline_walking', 10.0, 0.8),
|
| 37 |
+
('incline_walking', 10.0, 1.0),
|
| 38 |
+
('incline_walking', 10.0, 1.2),
|
| 39 |
+
('decline_walking', -10.0, 0.8),
|
| 40 |
+
('decline_walking', -10.0, 1.0),
|
| 41 |
+
('decline_walking', -10.0, 1.2),
|
| 42 |
+
]
|
dashboard.py
ADDED
|
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import logging
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
from multivariate_gaussian_overlap import calculate_similarity_portrait_abstraction
|
| 10 |
+
from plot_similarity import plot_similarity_measure
|
| 11 |
+
from sensor_illustration import LegIllustration
|
| 12 |
+
import io
|
| 13 |
+
|
| 14 |
+
# Import constants from config.py
|
| 15 |
+
from config import AVAILABLE_SENSORS, AVAILABLE_TASKS, OUTPUTS_TOTAL, ANALYSIS_ABSTRACTION_LEVELS
|
| 16 |
+
|
| 17 |
+
# Set up logging
|
| 18 |
+
if not os.path.exists('st_logs'):
|
| 19 |
+
os.makedirs('st_logs')
|
| 20 |
+
|
| 21 |
+
# Create logger
|
| 22 |
+
logger = logging.getLogger('dashboard_logger')
|
| 23 |
+
logger.setLevel(logging.INFO)
|
| 24 |
+
|
| 25 |
+
# Create handlers
|
| 26 |
+
file_handler = logging.FileHandler('st_logs/dashboard_access.log')
|
| 27 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
| 28 |
+
|
| 29 |
+
# Create formatter and add it to the handlers
|
| 30 |
+
formatter = logging.Formatter('%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
| 31 |
+
file_handler.setFormatter(formatter)
|
| 32 |
+
console_handler.setFormatter(formatter)
|
| 33 |
+
|
| 34 |
+
# Add the handlers to the logger
|
| 35 |
+
logger.addHandler(file_handler)
|
| 36 |
+
logger.addHandler(console_handler)
|
| 37 |
+
|
| 38 |
+
# Configure the page with custom CSS
|
| 39 |
+
st.set_page_config(page_title="Task Similarity Analysis", layout="wide")
|
| 40 |
+
|
| 41 |
+
# Custom CSS for better visual appeal
|
| 42 |
+
st.markdown("""
|
| 43 |
+
<style>
|
| 44 |
+
.main .block-container {
|
| 45 |
+
padding-top: 2rem;
|
| 46 |
+
padding-bottom: 2rem;
|
| 47 |
+
}
|
| 48 |
+
.stButton>button {
|
| 49 |
+
width: 100%;
|
| 50 |
+
margin-top: 1rem;
|
| 51 |
+
margin-bottom: 1rem;
|
| 52 |
+
}
|
| 53 |
+
.sidebar .sidebar-content {
|
| 54 |
+
padding-top: 1rem;
|
| 55 |
+
}
|
| 56 |
+
hr {
|
| 57 |
+
margin: 1rem 0;
|
| 58 |
+
}
|
| 59 |
+
h1 {
|
| 60 |
+
padding-bottom: 1rem;
|
| 61 |
+
border-bottom: 2px solid #f0f2f6;
|
| 62 |
+
}
|
| 63 |
+
h3 {
|
| 64 |
+
margin-top: 1.5rem;
|
| 65 |
+
}
|
| 66 |
+
/* Add tooltip styles */
|
| 67 |
+
.tooltip-container {
|
| 68 |
+
position: relative;
|
| 69 |
+
display: inline-block;
|
| 70 |
+
width: 100%;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
.tooltip-content {
|
| 74 |
+
visibility: hidden;
|
| 75 |
+
background-color: white;
|
| 76 |
+
color: black;
|
| 77 |
+
text-align: left;
|
| 78 |
+
padding: 10px;
|
| 79 |
+
border-radius: 5px;
|
| 80 |
+
border: 1px solid #ddd;
|
| 81 |
+
position: absolute;
|
| 82 |
+
z-index: 1;
|
| 83 |
+
top: 100%;
|
| 84 |
+
left: 0;
|
| 85 |
+
margin-top: 5px;
|
| 86 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
| 87 |
+
width: 200px;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
.tooltip-container:hover .tooltip-content {
|
| 91 |
+
visibility: visible;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
.color-legend {
|
| 95 |
+
display: flex;
|
| 96 |
+
align-items: center;
|
| 97 |
+
margin: 5px 0;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
.color-box {
|
| 101 |
+
width: 15px;
|
| 102 |
+
height: 15px;
|
| 103 |
+
margin-right: 10px;
|
| 104 |
+
border-radius: 3px;
|
| 105 |
+
}
|
| 106 |
+
</style>
|
| 107 |
+
""", unsafe_allow_html=True)
|
| 108 |
+
|
| 109 |
+
# Add authentication
|
| 110 |
+
if not st.session_state.get("authenticated", False):
|
| 111 |
+
password = st.text_input("Password:", type="password")
|
| 112 |
+
if password == "locolab!": # Replace with secure password
|
| 113 |
+
st.session_state.authenticated = True
|
| 114 |
+
# Log successful login with timestamp and IP
|
| 115 |
+
ip = st.get_client_ip() if hasattr(st, 'get_client_ip') else 'Unknown IP'
|
| 116 |
+
logger.info(f"Successful login from {ip}")
|
| 117 |
+
elif password: # Only log if a password attempt was made
|
| 118 |
+
# Log failed attempt
|
| 119 |
+
ip = st.get_client_ip() if hasattr(st, 'get_client_ip') else 'Unknown IP'
|
| 120 |
+
logger.warning(f"Failed login attempt from {ip} with password: {password}")
|
| 121 |
+
st.stop()
|
| 122 |
+
else:
|
| 123 |
+
st.stop()
|
| 124 |
+
|
| 125 |
+
# Define available sensors and tasks
|
| 126 |
+
# MOVED TO config.py: AVAILABLE_SENSORS, AVAILABLE_TASKS, OUTPUTS_TOTAL, ANALYSIS_ABSTRACTION_LEVELS
|
| 127 |
+
|
| 128 |
+
ANALYSIS_MODES = ['Similarity Analysis'] # This one seems specific to dashboard
|
| 129 |
+
|
| 130 |
+
# Sidebar controls with better organization
|
| 131 |
+
with st.sidebar:
|
| 132 |
+
st.title("Configuration")
|
| 133 |
+
|
| 134 |
+
# Add the generate button styling and button right after the title
|
| 135 |
+
st.markdown("""
|
| 136 |
+
<style>
|
| 137 |
+
div[data-testid="stButton"] button {
|
| 138 |
+
background-color: #28a745;
|
| 139 |
+
color: white;
|
| 140 |
+
border: none;
|
| 141 |
+
}
|
| 142 |
+
div[data-testid="stButton"] button:hover {
|
| 143 |
+
background-color: #218838;
|
| 144 |
+
color: white;
|
| 145 |
+
border: none;
|
| 146 |
+
}
|
| 147 |
+
</style>
|
| 148 |
+
""", unsafe_allow_html=True)
|
| 149 |
+
|
| 150 |
+
generate_button = st.button("Generate Visualization", key="generate_viz")
|
| 151 |
+
|
| 152 |
+
st.markdown("---") # Visual separator
|
| 153 |
+
|
| 154 |
+
# Modified task selection expander with abstraction controls
|
| 155 |
+
with st.expander("🎯 Task Selection", expanded=True):
|
| 156 |
+
|
| 157 |
+
# --- Analysis Detail moved to top ---
|
| 158 |
+
analysis_detail = st.selectbox(
|
| 159 |
+
"Analysis Detail",
|
| 160 |
+
["High Level", "Medium Level", "Low Level"],
|
| 161 |
+
index=0, # Default to High Level
|
| 162 |
+
help="High: Compare all tasks globally. Medium: Select specific tasks, compare across all their conditions. Low: Select specific tasks and conditions (incline/speed)."
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# --- Task Selection ---
|
| 166 |
+
# Disable task selection if High Level is chosen
|
| 167 |
+
tasks_disabled = (analysis_detail == "High Level")
|
| 168 |
+
task1 = st.selectbox("Task 1", AVAILABLE_TASKS, index=0, disabled=tasks_disabled)
|
| 169 |
+
task2 = st.selectbox("Task 2", AVAILABLE_TASKS, index=1, disabled=tasks_disabled)
|
| 170 |
+
|
| 171 |
+
# Convert to old terminology for backend compatibility - THIS IS MODIFIED LATER
|
| 172 |
+
# abstraction_level = "High" if analysis_detail == "Summary View" else "Medium"
|
| 173 |
+
|
| 174 |
+
# Initialize parameters - These will be overwritten based on analysis_detail
|
| 175 |
+
task1_incline = None
|
| 176 |
+
task1_speed = None
|
| 177 |
+
task2_incline = None
|
| 178 |
+
task2_speed = None
|
| 179 |
+
|
| 180 |
+
# Task parameter controls - only show if Low Level (previously Detailed View)
|
| 181 |
+
if analysis_detail == "Low Level":
|
| 182 |
+
st.markdown("**Task Specific Parameters**")
|
| 183 |
+
col1, col2 = st.columns(2)
|
| 184 |
+
with col1:
|
| 185 |
+
# Task 1 parameters
|
| 186 |
+
# Default to "All" which will become None later
|
| 187 |
+
task1_incline_select = "All"
|
| 188 |
+
if task1 != 'level_walking':
|
| 189 |
+
task1_incline_select = st.selectbox(
|
| 190 |
+
f"{task1} Incline",
|
| 191 |
+
options=["All", 5.0, 10.0],
|
| 192 |
+
index=0
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
task1_speed_select = st.selectbox(
|
| 196 |
+
f"{task1} Speed",
|
| 197 |
+
options=["All", 0.8, 1.0, 1.2],
|
| 198 |
+
index=0
|
| 199 |
+
)
|
| 200 |
+
with col2:
|
| 201 |
+
# Task 2 parameters
|
| 202 |
+
task2_incline_select = "All"
|
| 203 |
+
if task2 != 'level_walking':
|
| 204 |
+
task2_incline_select = st.selectbox(
|
| 205 |
+
f"{task2} Incline",
|
| 206 |
+
options=["All", 5.0, 10.0],
|
| 207 |
+
index=0
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
task2_speed_select = st.selectbox(
|
| 211 |
+
f"{task2} Speed",
|
| 212 |
+
options=["All", 0.8, 1.0, 1.2],
|
| 213 |
+
index=0
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Convert "All" selections to None for Low Level analysis
|
| 217 |
+
task1_incline = task1_incline_select if task1_incline_select != "All" else None
|
| 218 |
+
task1_speed = task1_speed_select if task1_speed_select != "All" else None
|
| 219 |
+
task2_incline = task2_incline_select if task2_incline_select != "All" else None
|
| 220 |
+
task2_speed = task2_speed_select if task2_speed_select != "All" else None
|
| 221 |
+
|
| 222 |
+
# No specific parameters needed for High or Medium level, they remain None
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
st.markdown("---") # Visual separator
|
| 226 |
+
|
| 227 |
+
# Sensor selection in an expander (moved second)
|
| 228 |
+
with st.expander("📡 Sensor Configuration", expanded=True):
|
| 229 |
+
# Add sensor illustration at the top of the sensor configuration
|
| 230 |
+
if 'selected_sensors' not in st.session_state:
|
| 231 |
+
st.session_state.selected_sensors = ['hip_angle_s_r']
|
| 232 |
+
|
| 233 |
+
# Create tooltip HTML
|
| 234 |
+
tooltip_html = f"""
|
| 235 |
+
<div class="tooltip-container">
|
| 236 |
+
<div style="width: 100%">
|
| 237 |
+
{{plot_placeholder}}
|
| 238 |
+
</div>
|
| 239 |
+
<div class="tooltip-content">
|
| 240 |
+
<div style="font-weight: bold; margin-bottom: 8px">Color Convention:</div>
|
| 241 |
+
<div class="color-legend">
|
| 242 |
+
<div class="color-box" style="background-color: green"></div>
|
| 243 |
+
<div>Angle Sensors</div>
|
| 244 |
+
</div>
|
| 245 |
+
<div class="color-legend">
|
| 246 |
+
<div class="color-box" style="background-color: blue"></div>
|
| 247 |
+
<div>Velocity Sensors</div>
|
| 248 |
+
</div>
|
| 249 |
+
<div class="color-legend">
|
| 250 |
+
<div class="color-box" style="background-color: red"></div>
|
| 251 |
+
<div>Angle + Velocity</div>
|
| 252 |
+
</div>
|
| 253 |
+
<div class="color-legend">
|
| 254 |
+
<div class="color-box" style="background-color: orange"></div>
|
| 255 |
+
<div>Torque Sensors</div>
|
| 256 |
+
</div>
|
| 257 |
+
</div>
|
| 258 |
+
</div>
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
# Create and display sensor illustration
|
| 262 |
+
sensor_fig = LegIllustration().draw_illustration(
|
| 263 |
+
highlighted_elements=st.session_state.selected_sensors,
|
| 264 |
+
gait_cycle_sections=st.session_state.get('phase_windows', []) # Default if not set
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Convert plot to HTML
|
| 268 |
+
from io import BytesIO
|
| 269 |
+
buf = BytesIO()
|
| 270 |
+
sensor_fig.savefig(buf, format='png', transparent=True, bbox_inches='tight')
|
| 271 |
+
buf.seek(0)
|
| 272 |
+
import base64
|
| 273 |
+
plot_html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" style="width: 100%">'
|
| 274 |
+
|
| 275 |
+
# Insert plot into tooltip HTML and display
|
| 276 |
+
st.markdown(tooltip_html.replace('{plot_placeholder}', plot_html), unsafe_allow_html=True)
|
| 277 |
+
|
| 278 |
+
# Existing sensor selection
|
| 279 |
+
selected_sensors = st.multiselect(
|
| 280 |
+
"Select Sensors",
|
| 281 |
+
AVAILABLE_SENSORS,
|
| 282 |
+
default=['hip_angle_s_r']
|
| 283 |
+
)
|
| 284 |
+
# Update session state
|
| 285 |
+
st.session_state.selected_sensors = selected_sensors
|
| 286 |
+
|
| 287 |
+
# Phase window selection
|
| 288 |
+
st.markdown("### Phase Window")
|
| 289 |
+
use_specific_phases = st.toggle(
|
| 290 |
+
"Analyze Specific Gait Phases",
|
| 291 |
+
value=False,
|
| 292 |
+
help="Enable to select specific portions of the gait cycle for analysis",
|
| 293 |
+
key="phase_toggle" # Add key to track changes
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
if use_specific_phases:
|
| 297 |
+
st.markdown("""
|
| 298 |
+
Select gait cycle phases to analyze (0% = current time, -100% = past heel strike).
|
| 299 |
+
Multiple selections allow analysis of combined phases.
|
| 300 |
+
""")
|
| 301 |
+
# Create a list of all possible phase percentages (0 to 100%)
|
| 302 |
+
phase_options = [f"-{i/1.5:.1f}%" for i in range(151)]
|
| 303 |
+
|
| 304 |
+
# Create multiselect dropdown for specific phase offsets
|
| 305 |
+
selected_phase_strings = st.multiselect(
|
| 306 |
+
"Select Phase Offsets",
|
| 307 |
+
phase_options,
|
| 308 |
+
default=[phase_options[0]],
|
| 309 |
+
key="phase_multiselect" # Add key to track changes
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Convert selected percentages to phase offsets
|
| 313 |
+
phase_windows = [max(1, -int(float(p.strip('%')) * 1.49)) for p in selected_phase_strings]
|
| 314 |
+
if st.session_state.get('phase_windows') != phase_windows:
|
| 315 |
+
st.session_state.phase_windows = phase_windows
|
| 316 |
+
st.rerun() # Updated from experimental_rerun to rerun
|
| 317 |
+
|
| 318 |
+
else:
|
| 319 |
+
# Slider behavior - create list of all phases up to selected value
|
| 320 |
+
phase_window_percent = st.slider(
|
| 321 |
+
"Window Size (%)",
|
| 322 |
+
0, 100, 1,
|
| 323 |
+
key="phase_slider" # Add key to track changes
|
| 324 |
+
)
|
| 325 |
+
max_window = max(1, int(phase_window_percent * 1.49))
|
| 326 |
+
phase_windows = list(range(1, max_window + 1))
|
| 327 |
+
if st.session_state.get('phase_windows') != phase_windows:
|
| 328 |
+
st.session_state.phase_windows = phase_windows
|
| 329 |
+
st.rerun() # Updated from experimental_rerun to rerun
|
| 330 |
+
|
| 331 |
+
if phase_window_percent > 0:
|
| 332 |
+
st.caption(f"Current window: {phase_window_percent:.1f}% ({max_window} frames)")
|
| 333 |
+
|
| 334 |
+
st.markdown("---")
|
| 335 |
+
|
| 336 |
+
# Visualization options in an expander (moved third)
|
| 337 |
+
with st.expander("📊 Visualization Options", expanded=True):
|
| 338 |
+
show_output_diff = st.checkbox("Show Output Differences", value=False)
|
| 339 |
+
show_conflict = st.checkbox("Show Input-Output Conflict", value=False)
|
| 340 |
+
|
| 341 |
+
# Add output selection if showing output differences or conflict
|
| 342 |
+
if show_output_diff or show_conflict:
|
| 343 |
+
selected_outputs = st.multiselect(
|
| 344 |
+
"Select Outputs",
|
| 345 |
+
OUTPUTS_TOTAL,
|
| 346 |
+
default=['ankle_torque_s_r'],
|
| 347 |
+
format_func=lambda x: x.replace('_s_r', '').replace('_', ' ').title()
|
| 348 |
+
)
|
| 349 |
+
if not selected_outputs:
|
| 350 |
+
st.warning("⚠️ Please select at least one output.")
|
| 351 |
+
st.stop()
|
| 352 |
+
# Update the OUTPUTS constant to use only selected outputs
|
| 353 |
+
OUTPUTS = selected_outputs
|
| 354 |
+
|
| 355 |
+
st.markdown("---")
|
| 356 |
+
|
| 357 |
+
# Marginal distribution options in an expander (moved last)
|
| 358 |
+
# Only show if neither output differences nor conflict is selected
|
| 359 |
+
if not (show_output_diff or show_conflict):
|
| 360 |
+
with st.expander("📈 Marginal Distribution", expanded=True):
|
| 361 |
+
show_marginals = st.checkbox("Show Marginal Distributions", value=False)
|
| 362 |
+
if show_marginals:
|
| 363 |
+
threshold = st.slider("Similarity Threshold", 0.0, 1.0, 0.5, 0.1)
|
| 364 |
+
else:
|
| 365 |
+
show_marginals = False
|
| 366 |
+
|
| 367 |
+
st.markdown("---") # Add a separator
|
| 368 |
+
|
| 369 |
+
# Main content
|
| 370 |
+
st.title("Task Similarity Analysis Dashboard")
|
| 371 |
+
|
| 372 |
+
if not selected_sensors:
|
| 373 |
+
st.warning("⚠️ Please select at least one sensor.")
|
| 374 |
+
st.stop()
|
| 375 |
+
|
| 376 |
+
def calculate_overlap_measures(task1, task2, sensors, abstraction_level,
|
| 377 |
+
task1_incline=None, task1_speed=None,
|
| 378 |
+
task2_incline=None, task2_speed=None,
|
| 379 |
+
time_windows=None,
|
| 380 |
+
use_output_data=False):
|
| 381 |
+
"""
|
| 382 |
+
Calculate overlap measures with support for time windowing.
|
| 383 |
+
"""
|
| 384 |
+
# Create a progress placeholder in Streamlit
|
| 385 |
+
progress_placeholder = st.empty()
|
| 386 |
+
progress_bar = progress_placeholder.progress(0)
|
| 387 |
+
status_placeholder = st.empty()
|
| 388 |
+
|
| 389 |
+
try:
|
| 390 |
+
# Show calculation status
|
| 391 |
+
# status_placeholder.text("Pre-computing task statistics...")
|
| 392 |
+
|
| 393 |
+
# Flip sign of incline if task is decline_walking
|
| 394 |
+
if task1 == 'decline_walking' and task1_incline is not None:
|
| 395 |
+
task1_incline = -task1_incline
|
| 396 |
+
if task2 == 'decline_walking' and task2_incline is not None:
|
| 397 |
+
task2_incline = -task2_incline
|
| 398 |
+
|
| 399 |
+
# Define task specifications
|
| 400 |
+
task1_spec = (task1, task1_incline, task1_speed)
|
| 401 |
+
task2_spec = (task2, task2_incline, task2_speed)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# Get the phase windows from session state or use default
|
| 405 |
+
if time_windows is None:
|
| 406 |
+
time_windows = [1]
|
| 407 |
+
|
| 408 |
+
similarity = calculate_similarity_portrait_abstraction(
|
| 409 |
+
sensors=list(sensors),
|
| 410 |
+
time_window=time_windows,
|
| 411 |
+
abstraction_level=abstraction_level.lower(),
|
| 412 |
+
task1_name=task1_spec,
|
| 413 |
+
task2_name=task2_spec,
|
| 414 |
+
output_difference=use_output_data,
|
| 415 |
+
progress_callback=lambda x: progress_bar.progress(x)
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
# Clear the progress indicators
|
| 419 |
+
progress_placeholder.empty()
|
| 420 |
+
status_placeholder.empty()
|
| 421 |
+
|
| 422 |
+
if similarity is None:
|
| 423 |
+
st.error("Error calculating similarity. Please check your task and sensor configuration.")
|
| 424 |
+
return None
|
| 425 |
+
|
| 426 |
+
return similarity
|
| 427 |
+
|
| 428 |
+
except ValueError as e:
|
| 429 |
+
# Clear progress indicators and show error
|
| 430 |
+
progress_placeholder.empty()
|
| 431 |
+
status_placeholder.empty()
|
| 432 |
+
st.error(f"This is not a valid task configuration: {e}")
|
| 433 |
+
return None
|
| 434 |
+
except Exception as e:
|
| 435 |
+
# Clear progress indicators and show error
|
| 436 |
+
progress_placeholder.empty()
|
| 437 |
+
status_placeholder.empty()
|
| 438 |
+
st.error(f"Unexpected error: {e}")
|
| 439 |
+
return None
|
| 440 |
+
|
| 441 |
+
def create_heatmap_matplotlib(data, title, plot_type='input'):
|
| 442 |
+
"""Create a matplotlib heatmap using the library's plotting function"""
|
| 443 |
+
fig, ax = plt.subplots(figsize=(6, 6))
|
| 444 |
+
|
| 445 |
+
# Reshape data to 2D if needed
|
| 446 |
+
if len(data.shape) == 1:
|
| 447 |
+
data = data.reshape(-1, 1) # Convert to 2D column vector
|
| 448 |
+
|
| 449 |
+
plot_similarity_measure(
|
| 450 |
+
ax=ax,
|
| 451 |
+
measure_data=data,
|
| 452 |
+
plot_type=plot_type,
|
| 453 |
+
task_x_name=task1.replace('_', ' ').title(),
|
| 454 |
+
task_y_name=task2.replace('_', ' ').title()
|
| 455 |
+
)
|
| 456 |
+
ax.set_title(title)
|
| 457 |
+
return fig
|
| 458 |
+
|
| 459 |
+
def create_sensor_illustration(selected_sensors):
|
| 460 |
+
"""Create sensor illustration with selected sensors"""
|
| 461 |
+
illustrator = LegIllustration()
|
| 462 |
+
fig = illustrator.draw_illustration(
|
| 463 |
+
highlighted_elements=selected_sensors,
|
| 464 |
+
gait_cycle_sections=st.session_state.get('phase_windows', []) # Default if not set
|
| 465 |
+
)
|
| 466 |
+
# Make figure smaller while maintaining aspect ratio
|
| 467 |
+
# Let the main plotting layout control size
|
| 468 |
+
# fig.set_size_inches(3, 6)
|
| 469 |
+
return fig
|
| 470 |
+
|
| 471 |
+
def create_downloadable_image(fig):
|
| 472 |
+
"""Create a single image from the main figure object."""
|
| 473 |
+
# Remove old logic trying to combine multiple figures
|
| 474 |
+
|
| 475 |
+
# Save the provided figure directly to bytes
|
| 476 |
+
buf = io.BytesIO()
|
| 477 |
+
fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
|
| 478 |
+
plt.close(fig) # Close the figure to free memory
|
| 479 |
+
buf.seek(0)
|
| 480 |
+
return buf.getvalue()
|
| 481 |
+
|
| 482 |
+
# Main section after authentication
|
| 483 |
+
if st.session_state.authenticated:
|
| 484 |
+
if generate_button:
|
| 485 |
+
with st.spinner("Doing some math..."):
|
| 486 |
+
# Determine backend parameters based on UI selections
|
| 487 |
+
# These variables (task1, task2, analysis_detail, task1_incline, etc.)
|
| 488 |
+
# are already defined from the sidebar widgets and their values are
|
| 489 |
+
# current when the button is pressed.
|
| 490 |
+
|
| 491 |
+
if analysis_detail == "High Level":
|
| 492 |
+
# For High Level, task names and conditions are None,
|
| 493 |
+
# task1 and task2 from UI are effectively ignored for backend call.
|
| 494 |
+
task1_name_backend = (None, None, None)
|
| 495 |
+
task2_name_backend = (None, None, None)
|
| 496 |
+
elif analysis_detail == "Medium Level":
|
| 497 |
+
# For Medium Level, specific tasks are chosen, but all their conditions.
|
| 498 |
+
# task1, task2 are the selected task strings from the UI.
|
| 499 |
+
# Incline and speed are None to signify "all conditions" for these tasks.
|
| 500 |
+
task1_name_backend = (task1, None, None)
|
| 501 |
+
task2_name_backend = (task2, None, None)
|
| 502 |
+
else: # "Low Level"
|
| 503 |
+
# For Low Level, specific tasks and their specific conditions are used.
|
| 504 |
+
# task1, task2 are task strings.
|
| 505 |
+
# task1_incline, task1_speed, etc., are specific values (or None if "All" was selected for that condition in the UI).
|
| 506 |
+
task1_name_backend = (task1, task1_incline, task1_speed)
|
| 507 |
+
task2_name_backend = (task2, task2_incline, task2_speed)
|
| 508 |
+
|
| 509 |
+
# This will be passed to calculate_overlap_measures, which then passes its .lower() version
|
| 510 |
+
# to calculate_similarity_portrait_abstraction.
|
| 511 |
+
abstraction_level_backend = analysis_detail
|
| 512 |
+
|
| 513 |
+
progress_placeholder = st.empty()
|
| 514 |
+
progress_placeholder.write(f"Calculating input similarity for {task1_name_backend[0] or 'all tasks'} and {task2_name_backend[0] or 'all tasks'}...")
|
| 515 |
+
# Calculate input similarity using determined backend parameters
|
| 516 |
+
input_similarity = calculate_overlap_measures(
|
| 517 |
+
task1=task1_name_backend[0],
|
| 518 |
+
task2=task2_name_backend[0],
|
| 519 |
+
sensors=selected_sensors,
|
| 520 |
+
abstraction_level=abstraction_level_backend,
|
| 521 |
+
task1_incline=task1_name_backend[1],
|
| 522 |
+
task1_speed=task1_name_backend[2],
|
| 523 |
+
task2_incline=task2_name_backend[1],
|
| 524 |
+
task2_speed=task2_name_backend[2],
|
| 525 |
+
time_windows=st.session_state.get('phase_windows', [])
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
output_difference_main = None
|
| 529 |
+
conflict_main = None
|
| 530 |
+
first_output_sensor = None
|
| 531 |
+
|
| 532 |
+
if input_similarity is not None:
|
| 533 |
+
print("Input similarity calculated, proceeding...")
|
| 534 |
+
|
| 535 |
+
# Calculate output difference and conflict for the FIRST selected output
|
| 536 |
+
if (show_output_diff or show_conflict):
|
| 537 |
+
if selected_outputs: # Check if any outputs were selected
|
| 538 |
+
first_output_sensor = selected_outputs[0]
|
| 539 |
+
progress_placeholder.empty()
|
| 540 |
+
progress_placeholder.write(f"Calculating output difference for {first_output_sensor}...")
|
| 541 |
+
|
| 542 |
+
# Calculate DIFFERENCE using determined backend parameters
|
| 543 |
+
output_difference_main = calculate_overlap_measures(
|
| 544 |
+
task1_name_backend[0], task2_name_backend[0],
|
| 545 |
+
[first_output_sensor], abstraction_level_backend, # USE CORRECT VARIABLE
|
| 546 |
+
task1_incline=task1_name_backend[1], task1_speed=task1_name_backend[2],
|
| 547 |
+
task2_incline=task2_name_backend[1], task2_speed=task2_name_backend[2],
|
| 548 |
+
use_output_data=True,
|
| 549 |
+
time_windows=[1] # Output difference usually uses time_window=1
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
if output_difference_main is not None and show_conflict:
|
| 553 |
+
progress_placeholder.empty()
|
| 554 |
+
progress_placeholder.write(f"Calculating conflict for {first_output_sensor}...")
|
| 555 |
+
# Conflict is input similarity * output difference
|
| 556 |
+
conflict_main = input_similarity * output_difference_main
|
| 557 |
+
else:
|
| 558 |
+
st.warning("⚠️ Please select at least one output sensor to show Output Differences or Conflict.")
|
| 559 |
+
# Ensure plots dependent on these are skipped
|
| 560 |
+
show_output_diff = False
|
| 561 |
+
show_conflict = False
|
| 562 |
+
|
| 563 |
+
# Clear progress text
|
| 564 |
+
progress_placeholder.empty()
|
| 565 |
+
|
| 566 |
+
# --- Start of New Plotting Layout ---
|
| 567 |
+
st.subheader("Similarity Analysis")
|
| 568 |
+
|
| 569 |
+
fig, axs = plt.subplots(2, 3, figsize=(18, 10), constrained_layout=False) # Adjust figsize as needed
|
| 570 |
+
|
| 571 |
+
# --- Row 0: Output Info / Header ---
|
| 572 |
+
|
| 573 |
+
# Ax[0,0]: Output Leg Illustration
|
| 574 |
+
if (show_output_diff or show_conflict) and first_output_sensor:
|
| 575 |
+
leg_illustrator_out = LegIllustration()
|
| 576 |
+
leg_illustrator_out.draw_illustration(ax=axs[0,0], highlighted_elements=[first_output_sensor])
|
| 577 |
+
axs[0,0].set_title(f"Output: {first_output_sensor.replace('_s_r','').replace('_',' ').title()}")
|
| 578 |
+
else:
|
| 579 |
+
axs[0,0].axis('off')
|
| 580 |
+
|
| 581 |
+
# Ax[0,1]: Task Info Text
|
| 582 |
+
axs[0,1].axis('off') # Turn off axis lines and ticks
|
| 583 |
+
task_info_text = f"{abstraction_level_backend} Abstraction\n" # Use abstraction_level_backend
|
| 584 |
+
task1_display_incline = "All" if task1_name_backend[1] is None else task1_name_backend[1]
|
| 585 |
+
task1_display_speed = "All" if task1_name_backend[2] is None else task1_name_backend[2]
|
| 586 |
+
task2_display_incline = "All" if task2_name_backend[1] is None else task2_name_backend[1]
|
| 587 |
+
task2_display_speed = "All" if task2_name_backend[2] is None else task2_name_backend[2]
|
| 588 |
+
|
| 589 |
+
task1_display_name = task1_name_backend[0] or "All Tasks"
|
| 590 |
+
task2_display_name = task2_name_backend[0] or "All Tasks"
|
| 591 |
+
|
| 592 |
+
task_info_text += f"Task 1 (Y): {task1_display_name} (Incline: {task1_display_incline}, Speed: {task1_display_speed})\n"
|
| 593 |
+
task_info_text += f"Task 2 (X): {task2_display_name} (Incline: {task2_display_incline}, Speed: {task2_display_speed})"
|
| 594 |
+
axs[0,1].text(0.5, 0.5, task_info_text, ha='center', va='center', fontsize=10, wrap=True)
|
| 595 |
+
|
| 596 |
+
# Ax[0,2]: Output Difference Heatmap
|
| 597 |
+
if show_output_diff and output_difference_main is not None:
|
| 598 |
+
plot_similarity_measure(
|
| 599 |
+
ax=axs[0,2],
|
| 600 |
+
measure_data=output_difference_main,
|
| 601 |
+
plot_type='output',
|
| 602 |
+
cbar=True,
|
| 603 |
+
cbar_labels=True,
|
| 604 |
+
fontsize=10,
|
| 605 |
+
task_x_name=task2_name_backend[0], # Pass task names for potential use in plot_similarity_measure
|
| 606 |
+
task_y_name=task1_name_backend[0]
|
| 607 |
+
)
|
| 608 |
+
axs[0,2].set_title("Output Difference")
|
| 609 |
+
else:
|
| 610 |
+
axs[0,2].axis('off')
|
| 611 |
+
|
| 612 |
+
# --- Row 1: Input Info / Conflict ---
|
| 613 |
+
|
| 614 |
+
# Ax[1,0]: Input Leg Illustration
|
| 615 |
+
leg_illustrator_in = LegIllustration()
|
| 616 |
+
leg_illustrator_in.draw_illustration(
|
| 617 |
+
ax=axs[1,0],
|
| 618 |
+
highlighted_elements=selected_sensors,
|
| 619 |
+
gait_cycle_sections=st.session_state.get('phase_windows', [])
|
| 620 |
+
)
|
| 621 |
+
axs[1,0].set_title(f"Input Sensors")
|
| 622 |
+
|
| 623 |
+
# Ax[1,1]: Input Similarity Heatmap
|
| 624 |
+
plot_similarity_measure(
|
| 625 |
+
ax=axs[1,1],
|
| 626 |
+
measure_data=input_similarity,
|
| 627 |
+
plot_type='input',
|
| 628 |
+
cbar=True,
|
| 629 |
+
cbar_labels=True,
|
| 630 |
+
fontsize=10,
|
| 631 |
+
task_x_name=task2_name_backend[0],
|
| 632 |
+
task_y_name=task1_name_backend[0]
|
| 633 |
+
)
|
| 634 |
+
axs[1,1].set_title("Input Similarity")
|
| 635 |
+
|
| 636 |
+
# Ax[1,2]: Conflict Heatmap
|
| 637 |
+
if show_conflict and conflict_main is not None:
|
| 638 |
+
plot_similarity_measure(
|
| 639 |
+
ax=axs[1,2],
|
| 640 |
+
measure_data=conflict_main,
|
| 641 |
+
plot_type='conflict',
|
| 642 |
+
cbar=True,
|
| 643 |
+
cbar_labels=True,
|
| 644 |
+
fontsize=10,
|
| 645 |
+
task_x_name=task2_name_backend[0],
|
| 646 |
+
task_y_name=task1_name_backend[0]
|
| 647 |
+
)
|
| 648 |
+
axs[1,2].set_title("Input-Output Conflict")
|
| 649 |
+
else:
|
| 650 |
+
axs[1,2].axis('off')
|
| 651 |
+
|
| 652 |
+
# Adjust layout and display
|
| 653 |
+
plt.tight_layout(pad=1.5)
|
| 654 |
+
st.pyplot(fig)
|
| 655 |
+
|
| 656 |
+
# --- End of New Plotting Layout ---
|
| 657 |
+
|
| 658 |
+
# Download button using the new main figure
|
| 659 |
+
try:
|
| 660 |
+
image_data = create_downloadable_image(fig)
|
| 661 |
+
st.download_button(
|
| 662 |
+
label="Download Plot Layout",
|
| 663 |
+
data=image_data,
|
| 664 |
+
file_name="similarity_layout.png",
|
| 665 |
+
mime="image/png",
|
| 666 |
+
)
|
| 667 |
+
except Exception as e:
|
| 668 |
+
st.error(f"Failed to create downloadable image: {e}")
|
| 669 |
+
|
| 670 |
+
# Handle case where input similarity calculation failed
|
| 671 |
+
else:
|
| 672 |
+
st.error("Error calculating input similarity. Please check your task and sensor configuration.")
|
| 673 |
+
|
| 674 |
+
# Existing logic for marginals if needed (outside the main 2x3 grid)
|
| 675 |
+
if show_marginals:
|
| 676 |
+
st.markdown("--- Marginal Distributions ---")
|
| 677 |
+
# Placeholder: Add logic to plot marginal distributions if required.
|
| 678 |
+
# This was previously outside the main conditional blocks.
|
| 679 |
+
st.write("(Marginal distribution plotting not implemented in this layout yet)")
|
data_tools.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
def generate_task_info_name(walking_type, incline, walking_speed):
|
| 4 |
+
"""
|
| 5 |
+
Generate a standardized 'task_info' name string based on the walking type, incline, and walking speed.
|
| 6 |
+
|
| 7 |
+
The returned string follows the format:
|
| 8 |
+
For level walking: "{walking_type}_{formatted_speed}"
|
| 9 |
+
For incline/decline: "{walking_type}_{incline}_deg_{formatted_speed}"
|
| 10 |
+
where:
|
| 11 |
+
- walking_type is expected to be one of:
|
| 12 |
+
'decline_walking', 'level_walking', or 'incline_walking'
|
| 13 |
+
- incline is a numerical value indicating ground incline in degrees.
|
| 14 |
+
- formatted_speed is a walking speed string, e.g., 's0x8' for 0.8 m/s, 's1' for 1.0 m/s, 's1x2' for 1.2 m/s.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
walking_type (str): The type of walking. Expected values: 'decline_walking', 'level_walking', or 'incline_walking'.
|
| 18 |
+
incline (float or int): Ground incline in degrees.
|
| 19 |
+
walking_speed (float): Walking speed in m/s.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
str: The generated task_info name.
|
| 23 |
+
|
| 24 |
+
Examples:
|
| 25 |
+
>>> generate_task_info_name('decline_walking', -5, 0.8)
|
| 26 |
+
'decline_walking_5_deg_s0x8'
|
| 27 |
+
>>> generate_task_info_name('level_walking', 0, 1.0)
|
| 28 |
+
'level_walking_s1'
|
| 29 |
+
>>> generate_task_info_name('incline_walking', 5, 1.2)
|
| 30 |
+
'incline_walking_5_deg_s1x2'
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
if 'stair' in walking_type:
|
| 34 |
+
return walking_type
|
| 35 |
+
|
| 36 |
+
if walking_type != 'level_walking' and (np.isnan(incline) or incline is None):
|
| 37 |
+
print(f"Incline is NaN for task {walking_type} {incline} {walking_speed}")
|
| 38 |
+
return None
|
| 39 |
+
if np.isnan(walking_speed) or walking_speed is None:
|
| 40 |
+
print(f"Walking speed is NaN for task {walking_type} {incline} {walking_speed}")
|
| 41 |
+
return None
|
| 42 |
+
|
| 43 |
+
def format_walking_speed(speed):
|
| 44 |
+
# Map known speeds to specific string representations
|
| 45 |
+
mapping = {0.8: "s0x8", 1.0: "s1", 1.2: "s1x2"}
|
| 46 |
+
# Use the mapping if available, otherwise default to a generic string
|
| 47 |
+
return mapping.get(speed, f"s{speed}")
|
| 48 |
+
|
| 49 |
+
# Format walking speed
|
| 50 |
+
speed_str = format_walking_speed(walking_speed)
|
| 51 |
+
|
| 52 |
+
# For level walking, return simpler format without incline info
|
| 53 |
+
if walking_type == 'level_walking':
|
| 54 |
+
return f"{walking_type}_{speed_str}"
|
| 55 |
+
|
| 56 |
+
# Format incline: if the value is an integer, do not show decimals.
|
| 57 |
+
if incline == int(incline):
|
| 58 |
+
incline_str = str(abs(int(incline)))
|
| 59 |
+
else:
|
| 60 |
+
incline_str = str(abs(incline))
|
| 61 |
+
|
| 62 |
+
# Build and return the standardized task_info string for incline/decline walking
|
| 63 |
+
task_info = f"{walking_type}_{incline_str}_deg_{speed_str}_m_s"
|
| 64 |
+
return task_info
|
multivariate_gaussian_overlap.py
ADDED
|
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%
|
| 2 |
+
""""
|
| 3 |
+
This script is meant to analyze how similar sensor inputs are to each other for
|
| 4 |
+
different tasks and potentially different subjects
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
Note: Task 1 represents the y axis in the heatmap.
|
| 9 |
+
Task 2 represents the x axis in the heatmap.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import itertools
|
| 15 |
+
from functools import partial
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from typing import Optional, Union, Callable
|
| 18 |
+
import pickle
|
| 19 |
+
from data_tools import generate_task_info_name
|
| 20 |
+
import os
|
| 21 |
+
# from data_loader import total_data # Removed as total_data is no longer used directly
|
| 22 |
+
from functools import lru_cache
|
| 23 |
+
from scipy.stats import norm
|
| 24 |
+
# Import LOW_LEVEL_TASKS from config
|
| 25 |
+
from config import LOW_LEVEL_TASKS
|
| 26 |
+
# Add this at the top of your file, after the imports
|
| 27 |
+
os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'
|
| 28 |
+
|
| 29 |
+
# ADD IMPORTS
|
| 30 |
+
import gzip
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
|
| 33 |
+
# Load significance thresholds if available
|
| 34 |
+
try:
|
| 35 |
+
SIGNIFICANCE_THRESHOLDS = pd.read_csv('significance_thresholds.csv')
|
| 36 |
+
except FileNotFoundError:
|
| 37 |
+
SIGNIFICANCE_THRESHOLDS = None
|
| 38 |
+
|
| 39 |
+
# REMOVE task_subject_indices_file loading
|
| 40 |
+
# Load file with task subject indices
|
| 41 |
+
# try:
|
| 42 |
+
# with open('cached_data/task_subject_indices.pkl', 'rb') as f:
|
| 43 |
+
# task_subject_indices_file = pickle.load(f)
|
| 44 |
+
# except FileNotFoundError:
|
| 45 |
+
# print("Did not find any task subject indices. Current working directory: ",
|
| 46 |
+
# os.getcwd())
|
| 47 |
+
# task_subject_indices_file = None
|
| 48 |
+
|
| 49 |
+
# ADD HOSTING_STATS loading
|
| 50 |
+
HOSTING_STATS_PATH = Path("cached_data/precalculated_stats.pkl.gz") # Assumes script is run from task-similarity-analysis/ or it's in PYTHONPATH
|
| 51 |
+
try:
|
| 52 |
+
with gzip.open(HOSTING_STATS_PATH, "rb") as f:
|
| 53 |
+
HOSTING_STATS = pickle.load(f)
|
| 54 |
+
if HOSTING_STATS is None: # Should not happen if pickle.load succeeds
|
| 55 |
+
print(f"Warning: {HOSTING_STATS_PATH} loaded as None.")
|
| 56 |
+
# else:
|
| 57 |
+
# print(f"Successfully loaded {HOSTING_STATS_PATH}")
|
| 58 |
+
except FileNotFoundError:
|
| 59 |
+
print(f"ERROR: Could not load {HOSTING_STATS_PATH}. Please run preprocess_data_for_hosting.py first.")
|
| 60 |
+
HOSTING_STATS = None
|
| 61 |
+
except Exception as e:
|
| 62 |
+
print(f"Error loading {HOSTING_STATS_PATH}: {e}")
|
| 63 |
+
HOSTING_STATS = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def calculate_task_statistics(
|
| 67 |
+
task_info_with_subject: tuple,
|
| 68 |
+
selected_sensors: list[str],
|
| 69 |
+
time_window_offsets: list[int], # Represents phase_windows from dashboard e.g. [1], or [1,2,3]
|
| 70 |
+
verbose: bool = False
|
| 71 |
+
) -> Optional[tuple[np.ndarray, np.ndarray, int]]:
|
| 72 |
+
"""Pre-compute task statistics for a given task and sensor configuration using HOSTING_STATS.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
task_info_with_subject: Tuple of (task_name, incline, speed, subject)
|
| 76 |
+
selected_sensors: List of sensor names to use
|
| 77 |
+
time_window_offsets: List of 1-indexed phase windows (e.g., [1] for current, [1,2] for current and t-1).
|
| 78 |
+
Effectively, 1 -> roll_offset 0, 2 -> roll_offset 1, etc.
|
| 79 |
+
verbose: Whether to print verbose output
|
| 80 |
+
Returns:
|
| 81 |
+
Tuple of (final_means_array, final_covs_array, n_samples) if stats exist, None otherwise.
|
| 82 |
+
final_means_array shape: (150, len(selected_sensors) * num_actual_time_windows)
|
| 83 |
+
final_covs_array shape: (150, combined_num_features, combined_num_features) (diagonal)
|
| 84 |
+
n_samples: number of samples used for pre-calculation.
|
| 85 |
+
"""
|
| 86 |
+
if HOSTING_STATS is None:
|
| 87 |
+
if verbose:
|
| 88 |
+
print("Error: HOSTING_STATS not loaded. Cannot calculate task statistics.")
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
subject_stats = HOSTING_STATS[task_info_with_subject]
|
| 93 |
+
means_all_features_at_t0 = subject_stats['means'] # (150, num_all_features_precalculated)
|
| 94 |
+
variances_all_features_at_t0 = subject_stats['variances'] # (150, num_all_features_precalculated)
|
| 95 |
+
n_samples = subject_stats['n_samples']
|
| 96 |
+
feature_order_at_preprocessing = subject_stats['feature_order']
|
| 97 |
+
except KeyError:
|
| 98 |
+
if verbose:
|
| 99 |
+
print(f"Warning: Statistics for {task_info_with_subject} not found in HOSTING_STATS. Skipping.")
|
| 100 |
+
return None
|
| 101 |
+
except TypeError: # If HOSTING_STATS is None or not a dict
|
| 102 |
+
if verbose:
|
| 103 |
+
print(f"Warning: HOSTING_STATS is not a valid dictionary for {task_info_with_subject}. Skipping.")
|
| 104 |
+
return None
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
try:
|
| 108 |
+
sensor_indices_in_precalc_order = [feature_order_at_preprocessing.index(s) for s in selected_sensors]
|
| 109 |
+
except ValueError as e:
|
| 110 |
+
if verbose:
|
| 111 |
+
print(f"Warning: One or more selected_sensors not found in preprocessed feature_order for {task_info_with_subject}: {e}. Selected: {selected_sensors}, Available: {feature_order_at_preprocessing}. Skipping.")
|
| 112 |
+
return None
|
| 113 |
+
|
| 114 |
+
means_selected_current_phase = means_all_features_at_t0[:, sensor_indices_in_precalc_order] # (150, len(selected_sensors))
|
| 115 |
+
variances_selected_current_phase = variances_all_features_at_t0[:, sensor_indices_in_precalc_order] # (150, len(selected_sensors))
|
| 116 |
+
|
| 117 |
+
# Determine actual roll offsets based on time_window_offsets from dashboard
|
| 118 |
+
# time_window_offsets (phase_windows from dashboard): [1] means current data (roll offset 0)
|
| 119 |
+
# [1, 2, 3] means current (offset 0), t-1 (offset 1), t-2 (offset 2)
|
| 120 |
+
|
| 121 |
+
actual_roll_offsets_to_apply = [] # stores 0 for current, 1 for t-1, etc.
|
| 122 |
+
|
| 123 |
+
# Ensure time_window_offsets is a list, default to [1] if empty or None (current time only)
|
| 124 |
+
current_time_window_offsets = time_window_offsets if time_window_offsets else [1]
|
| 125 |
+
|
| 126 |
+
for w in sorted(list(set(current_time_window_offsets))): # e.g., w from [1, 2, 3]
|
| 127 |
+
if w >= 1: # Dashboard sends 1-indexed windows
|
| 128 |
+
actual_roll_offsets_to_apply.append(w - 1) # Convert to 0-indexed roll offset
|
| 129 |
+
|
| 130 |
+
# Ensure actual_roll_offsets_to_apply is sorted and unique, typically [0], or [0,1,2] etc.
|
| 131 |
+
actual_roll_offsets_to_apply = sorted(list(set(actual_roll_offsets_to_apply)))
|
| 132 |
+
if not actual_roll_offsets_to_apply: # Should not happen if dashboard sends at least [1]
|
| 133 |
+
if verbose: print(f"Warning: No valid roll offsets from time_window_offsets {time_window_offsets}, defaulting to current time (offset 0).")
|
| 134 |
+
actual_roll_offsets_to_apply = [0]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
final_means_list = []
|
| 138 |
+
final_variances_list = []
|
| 139 |
+
|
| 140 |
+
for roll_offset in actual_roll_offsets_to_apply:
|
| 141 |
+
if roll_offset == 0:
|
| 142 |
+
final_means_list.append(means_selected_current_phase.copy())
|
| 143 |
+
final_variances_list.append(variances_selected_current_phase.copy())
|
| 144 |
+
elif roll_offset > 0:
|
| 145 |
+
rolled_means = np.roll(means_selected_current_phase, shift=roll_offset, axis=0) # roll along phase axis
|
| 146 |
+
rolled_variances = np.roll(variances_selected_current_phase, shift=roll_offset, axis=0)
|
| 147 |
+
final_means_list.append(rolled_means)
|
| 148 |
+
final_variances_list.append(rolled_variances)
|
| 149 |
+
# else: negative roll_offset could mean future data, not typically used here.
|
| 150 |
+
|
| 151 |
+
if not final_means_list:
|
| 152 |
+
if verbose:
|
| 153 |
+
print(f"Critical Error: final_means_list is empty after processing time_window_offsets for {task_info_with_subject}. This should not happen.")
|
| 154 |
+
return None # Or raise error
|
| 155 |
+
|
| 156 |
+
final_means_array = np.concatenate(final_means_list, axis=1) # Concatenate along feature axis
|
| 157 |
+
final_variances_array = np.concatenate(final_variances_list, axis=1) # Concatenate along feature axis
|
| 158 |
+
|
| 159 |
+
num_combined_features = final_means_array.shape[1]
|
| 160 |
+
final_covs_array = np.zeros((150, num_combined_features, num_combined_features))
|
| 161 |
+
for i in range(150): # For each of the 150 phase points
|
| 162 |
+
final_covs_array[i] = np.diag(final_variances_array[i]) # Variances are for the combined features
|
| 163 |
+
|
| 164 |
+
return final_means_array, final_covs_array, n_samples
|
| 165 |
+
|
| 166 |
+
def apply_binary_threshold(portrait, n_samples_1, n_samples_2):
|
| 167 |
+
# Get minimum number of samples for this task pair
|
| 168 |
+
min_samples = min(n_samples_1, n_samples_2)
|
| 169 |
+
# Clip to max available threshold sample size
|
| 170 |
+
min_samples = min(min_samples, len(SIGNIFICANCE_THRESHOLDS))
|
| 171 |
+
# Get threshold for this sample size
|
| 172 |
+
threshold = SIGNIFICANCE_THRESHOLDS.iloc[min_samples-1]['threshold_95']
|
| 173 |
+
# Apply binary threshold
|
| 174 |
+
portrait[portrait < threshold] = 0
|
| 175 |
+
portrait[portrait >= threshold] = 1
|
| 176 |
+
return portrait
|
| 177 |
+
|
| 178 |
+
def process_pair(task_pair, task_1_stats, task_2_stats,
|
| 179 |
+
binary_threshold=False,
|
| 180 |
+
match_subjects: bool = False,
|
| 181 |
+
biomechanical_difference: bool = False):
|
| 182 |
+
"""Process a single task pair using pre-computed statistics (vectorized path only)."""
|
| 183 |
+
task1, task2 = task_pair
|
| 184 |
+
|
| 185 |
+
# Initialize list to collect portraits for all subject combinations
|
| 186 |
+
subject_pairs = []
|
| 187 |
+
portraits = []
|
| 188 |
+
|
| 189 |
+
# Handle subject matching option
|
| 190 |
+
if match_subjects:
|
| 191 |
+
# Only compare the same subjects between tasks
|
| 192 |
+
common_subjects = set(task_1_stats[task1].keys()) & set(task_2_stats[task2].keys())
|
| 193 |
+
for subject in common_subjects:
|
| 194 |
+
subject_pairs.append((subject, subject))
|
| 195 |
+
means1, covs1, n_samples1 = task_1_stats[task1][subject]
|
| 196 |
+
means2, covs2, n_samples2 = task_2_stats[task2][subject]
|
| 197 |
+
|
| 198 |
+
# Calculate portrait for this subject - ALWAYS USE VECTORIZED PATH
|
| 199 |
+
portrait = vectorized_overlap_mine(means1, covs1, means2, covs2,
|
| 200 |
+
biomechanical_difference=biomechanical_difference)
|
| 201 |
+
|
| 202 |
+
if binary_threshold:
|
| 203 |
+
portrait = apply_binary_threshold(portrait, n_samples1, n_samples2)
|
| 204 |
+
|
| 205 |
+
portraits.append(portrait)
|
| 206 |
+
else:
|
| 207 |
+
# Compare all possible subject combinations between the two tasks
|
| 208 |
+
for subject1, stats1 in task_1_stats[task1].items():
|
| 209 |
+
for subject2, stats2 in task_2_stats[task2].items():
|
| 210 |
+
subject_pairs.append((subject1, subject2))
|
| 211 |
+
means1, covs1, n_samples1 = stats1
|
| 212 |
+
means2, covs2, n_samples2 = stats2
|
| 213 |
+
|
| 214 |
+
# Calculate portrait for this subject pair - ALWAYS USE VECTORIZED PATH
|
| 215 |
+
portrait = vectorized_overlap_mine(means1, covs1, means2, covs2,
|
| 216 |
+
biomechanical_difference=biomechanical_difference)
|
| 217 |
+
|
| 218 |
+
if binary_threshold:
|
| 219 |
+
portrait = apply_binary_threshold(portrait, n_samples1, n_samples2)
|
| 220 |
+
|
| 221 |
+
portraits.append(portrait)
|
| 222 |
+
|
| 223 |
+
# If no valid portraits were created, return None
|
| 224 |
+
if len(portraits) == 0:
|
| 225 |
+
return None
|
| 226 |
+
elif 'stair' in task1[0] or 'stair' in task2[0]:
|
| 227 |
+
pass
|
| 228 |
+
return portraits, subject_pairs
|
| 229 |
+
|
| 230 |
+
def vectorized_overlap_mine(means1, covs1, means2, covs2, tol=1e-12,
|
| 231 |
+
biomechanical_difference=False):
|
| 232 |
+
"""
|
| 233 |
+
Vectorized version of the 'mine' method overlap measure computation.
|
| 234 |
+
Assumes covariance matrices (covs1, covs2) are diagonal.
|
| 235 |
+
|
| 236 |
+
Parameters:
|
| 237 |
+
means1 (np.ndarray): Array of means for task 1 with shape (150, d).
|
| 238 |
+
covs1 (np.ndarray): Array of covariance matrices for task 1 with shape (150, d, d). Must be diagonal.
|
| 239 |
+
means2 (np.ndarray): Array of means for task 2 with shape (150, d).
|
| 240 |
+
covs2 (np.ndarray): Array of covariance matrices for task 2 with shape (150, d, d). Must be diagonal.
|
| 241 |
+
tol (float): Tolerance value to determine if a matrix is singular.
|
| 242 |
+
biomechanical_difference (bool): If True, apply output-difference filtering per
|
| 243 |
+
the method: negligible–negligible (D* = 0), amplitude (D* = D), or sign reversal
|
| 244 |
+
(D* = D × P_diff_sign).
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
np.ndarray: Overlap measures as a (150, 150) array.
|
| 248 |
+
"""
|
| 249 |
+
d = means1.shape[1]
|
| 250 |
+
# Pairwise mean differences: (150,150,d)
|
| 251 |
+
diff = means1[:, None, :] - means2[None, :, :]
|
| 252 |
+
# Combined covariances: (150,150,d,d)
|
| 253 |
+
S = covs1[:, None, :, :] + covs2[None, :, :, :]
|
| 254 |
+
# output
|
| 255 |
+
overlap = np.zeros((150, 150))
|
| 256 |
+
|
| 257 |
+
# --- compute raw overlap ---
|
| 258 |
+
# Logic is now only for diagonal covariance
|
| 259 |
+
# S_diag was originally `np.diagonal(S, axis1=2, axis2=3)`
|
| 260 |
+
# Since S = covs1 + covs2 and covs are diagonal, S is diagonal.
|
| 261 |
+
# S_diag elements are just var1 + var2 for each feature.
|
| 262 |
+
# covs1 and covs2 are (150, d, d), but diagonal. We can extract their diagonals first.
|
| 263 |
+
vars1 = np.diagonal(covs1, axis1=1, axis2=2) # Shape (150, d)
|
| 264 |
+
vars2 = np.diagonal(covs2, axis1=1, axis2=2) # Shape (150, d)
|
| 265 |
+
|
| 266 |
+
# Pairwise sum of variances: (150, 150, d)
|
| 267 |
+
# vars1[:, None, :] gives (150, 1, d)
|
| 268 |
+
# vars2[None, :, :] gives (1, 150, d)
|
| 269 |
+
# Broadcasting sum_vars to (150, 150, d)
|
| 270 |
+
sum_vars = vars1[:, None, :] + vars2[None, :, :]
|
| 271 |
+
|
| 272 |
+
non_singular = np.all(sum_vars > tol, axis=2) # Check if all diagonal elements of (cov1+cov2) are > tol for each pair
|
| 273 |
+
|
| 274 |
+
if np.any(non_singular):
|
| 275 |
+
# Compute the quadratic form efficiently for diagonal matrices
|
| 276 |
+
# diff is (150,150,d). diff[non_singular] is (N,d) where N is number of non_singular pairs.
|
| 277 |
+
# sum_vars[non_singular] is (N,d).
|
| 278 |
+
quad_form_terms = diff[non_singular]**2 / sum_vars[non_singular] # (N, d)
|
| 279 |
+
quad_form = np.sum(quad_form_terms, axis=1) # (N,)
|
| 280 |
+
|
| 281 |
+
# Compute the overlap measure with underflow protection
|
| 282 |
+
max_exp = 20.0 # Maximum value before underflow in np.exp
|
| 283 |
+
|
| 284 |
+
overlap_values = np.zeros_like(quad_form)
|
| 285 |
+
|
| 286 |
+
# Indices where exponent is too large
|
| 287 |
+
large_exp_mask = (0.5 * quad_form) > max_exp
|
| 288 |
+
|
| 289 |
+
# Indices where exponent is not too large
|
| 290 |
+
non_large_exp_mask = ~large_exp_mask
|
| 291 |
+
|
| 292 |
+
overlap_values[non_large_exp_mask] = np.exp(-0.5 * quad_form[non_large_exp_mask])
|
| 293 |
+
# For large exponents, overlap_values remains 0 (default)
|
| 294 |
+
|
| 295 |
+
# Assign computed values to the overlap array
|
| 296 |
+
i_idx, j_idx = np.where(non_singular)
|
| 297 |
+
overlap[i_idx, j_idx] = overlap_values
|
| 298 |
+
|
| 299 |
+
# --- output‐difference filtering (vectorized) ---
|
| 300 |
+
if biomechanical_difference:
|
| 301 |
+
negligible_threshold = 0.1
|
| 302 |
+
ampable_threshold = 0.2
|
| 303 |
+
|
| 304 |
+
# optimized for scalar case (d=1)
|
| 305 |
+
std1 = np.sqrt(covs1[:, 0, 0])
|
| 306 |
+
std2 = np.sqrt(covs2[:, 0, 0])
|
| 307 |
+
ci_lo1 = means1[:, 0] - 1.96 * std1
|
| 308 |
+
ci_hi1 = means1[:, 0] + 1.96 * std1
|
| 309 |
+
ci_lo2 = means2[:, 0] - 1.96 * std2
|
| 310 |
+
ci_hi2 = means2[:, 0] + 1.96 * std2
|
| 311 |
+
|
| 312 |
+
z1 = means1[:, 0] / np.maximum(std1, tol)
|
| 313 |
+
z2 = means2[:, 0] / np.maximum(std2, tol)
|
| 314 |
+
Ppos1 = norm.cdf(z1); Pneg1 = 1.0 - Ppos1
|
| 315 |
+
Ppos2 = norm.cdf(z2); Pneg2 = 1.0 - Ppos2
|
| 316 |
+
|
| 317 |
+
negligible1 = (ci_lo1 >= -negligible_threshold) & (ci_hi1 <= negligible_threshold)
|
| 318 |
+
negligible2 = (ci_lo2 >= -negligible_threshold) & (ci_hi2 <= negligible_threshold)
|
| 319 |
+
ampable1 = np.abs(means1[:, 0]) > ampable_threshold
|
| 320 |
+
ampable2 = np.abs(means2[:, 0]) > ampable_threshold
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
i_idx, j_idx = np.nonzero(non_singular)
|
| 324 |
+
if i_idx.size > 0:
|
| 325 |
+
# negligible–negligible → zero overlap
|
| 326 |
+
m0 = negligible1[i_idx] & negligible2[j_idx]
|
| 327 |
+
# Since we calculate similarity, we set the similarity to 1
|
| 328 |
+
# which is turns makes the difference 0
|
| 329 |
+
overlap[i_idx[m0], j_idx[m0]] = 1.0
|
| 330 |
+
|
| 331 |
+
# amplitude conflicts → skip modification
|
| 332 |
+
m1 = (negligible1[i_idx] & ampable2[j_idx]) \
|
| 333 |
+
| (negligible2[j_idx] & ampable1[i_idx])
|
| 334 |
+
|
| 335 |
+
# sign-reversal scaling for the rest
|
| 336 |
+
m2 = ~(m0 | m1)
|
| 337 |
+
if np.any(m2):
|
| 338 |
+
with np.errstate(all='ignore'):
|
| 339 |
+
Pdiff = (
|
| 340 |
+
Ppos1[i_idx[m2]] * Pneg2[j_idx[m2]] +
|
| 341 |
+
Pneg1[i_idx[m2]] * Ppos2[j_idx[m2]]
|
| 342 |
+
)
|
| 343 |
+
# Since we calculate similarity, we have to scale the overlap
|
| 344 |
+
# and then convert it back to a similarity measure
|
| 345 |
+
overlap[i_idx[m2], j_idx[m2]] = 1-(1-overlap[i_idx[m2], j_idx[m2]])*Pdiff
|
| 346 |
+
|
| 347 |
+
# clamp to [0,1]
|
| 348 |
+
np.clip(overlap, 0.0, 1.0, out=overlap)
|
| 349 |
+
return overlap
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def calculate_similarity_portrait_abstraction(sensors: Union[tuple, list], time_window: Union[int, list[int]] = 1,
|
| 353 |
+
task1_name: Union[tuple, str, None] = None,
|
| 354 |
+
task2_name: Union[tuple, str, None] = None,
|
| 355 |
+
verbose: bool = False,
|
| 356 |
+
n_jobs: int = -1, # n_jobs not used currently with sequential loop
|
| 357 |
+
return_all_portraits: bool = False, # return_all_portraits part of results_dict, not directly handled by agg
|
| 358 |
+
progress_callback: Optional[Callable[[float], None]] = None,
|
| 359 |
+
binary_threshold: bool = False,
|
| 360 |
+
output_difference: bool = False, # Renamed from use_output_data in dashboard
|
| 361 |
+
match_subjects: bool = False,
|
| 362 |
+
save_results_dict: bool = False,
|
| 363 |
+
biomechanical_difference: bool = False,
|
| 364 |
+
abstraction_level: Optional[str] = None): # abstraction_level passed from dashboard
|
| 365 |
+
"""
|
| 366 |
+
Calculate input similarity between two tasks using multivariate Gaussian
|
| 367 |
+
based on a given abstraction level. Uses pre-calculated HOSTING_STATS.
|
| 368 |
+
Always uses the vectorized 'mine' overlap calculation method.
|
| 369 |
+
|
| 370 |
+
High level: Is there any conflict?
|
| 371 |
+
At this abstraction level, we want to understand if there is any
|
| 372 |
+
conflict for a given sensor configuration. Calculate the similarity
|
| 373 |
+
portrait for each task at the lowest level of abstraction and then
|
| 374 |
+
aggregate the results using the agg_method function (maximum of each
|
| 375 |
+
individual similarity portrait). The total similarity portrait is then
|
| 376 |
+
transposed and maxed back so that it is symmetric.
|
| 377 |
+
|
| 378 |
+
Medium/Low level: Is there any similarity between tasks with optional specificity?
|
| 379 |
+
At this abstraction level, we want to understand similarity between tasks
|
| 380 |
+
with optional specificity for incline and speed. For example:
|
| 381 |
+
- Compare all incline 5° to stairs up: task1_name=('incline', 5, None)
|
| 382 |
+
- Compare specific tasks: task1_name=('incline', 5, 0.8)
|
| 383 |
+
- Compare general tasks: task1_name=('incline', None, None)
|
| 384 |
+
|
| 385 |
+
Format: task_name = (task:str, incline:float|None, speed:float|None)
|
| 386 |
+
Setting incline/speed to None means match any value for that parameter.
|
| 387 |
+
|
| 388 |
+
At each level, the similarity portrait is calculated for each subject and
|
| 389 |
+
leg seperately and then aggregated using the agg_method function.
|
| 390 |
+
|
| 391 |
+
Arguments:
|
| 392 |
+
----------
|
| 393 |
+
sensors: tuple or list
|
| 394 |
+
List of sensor names to use for similarity calculation
|
| 395 |
+
time_window: int or list, optional
|
| 396 |
+
List of 1-indexed phase windows (e.g., [1] for current, [1,2] for current and t-1).
|
| 397 |
+
Passed as time_window_offsets to calculate_task_statistics.
|
| 398 |
+
task1_name: tuple, optional
|
| 399 |
+
For low level: tuple of (task_name:str, incline:float|None,
|
| 400 |
+
speed:float|None)
|
| 401 |
+
Set incline/speed to None to match any value
|
| 402 |
+
task2_name: tuple, optional
|
| 403 |
+
Same format as task1_name
|
| 404 |
+
verbose: bool, optional
|
| 405 |
+
Whether to show progress bar during calculations (default=False)
|
| 406 |
+
n_jobs: int, optional
|
| 407 |
+
Number of CPU cores to use. -1 means use all available cores (default=-1). Currently not used.
|
| 408 |
+
|
| 409 |
+
return_all_portraits: bool, optional
|
| 410 |
+
If True, returns the full results_dict. If False (default), returns only the aggregated portrait.
|
| 411 |
+
binary_threshold: bool, optional
|
| 412 |
+
If True, threshold each low-level similarity portrait based on the 95%
|
| 413 |
+
confidence interval. Values below the threshold are set to 0, above to 1.
|
| 414 |
+
(default=False)
|
| 415 |
+
output_difference: bool, optional
|
| 416 |
+
If True, return the difference between the similarity portrait and 1
|
| 417 |
+
(default=False)
|
| 418 |
+
match_subjects: bool, optional
|
| 419 |
+
If True, match subjects between tasks (default=False)
|
| 420 |
+
save_results_dict: bool, optional
|
| 421 |
+
If True, save the results_dict to a file. Default False.
|
| 422 |
+
abstraction_level: str, optional
|
| 423 |
+
The abstraction level ("high", "medium", "low"). Used for logging/context, not for filtering within this function anymore.
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
# time_window is expected to be a list of 1-indexed phase windows from dashboard, e.g., [1], [1,2,3]
|
| 427 |
+
# It will be passed as time_window_offsets to calculate_task_statistics
|
| 428 |
+
# If it's an int, convert to list, though dashboard should pass list.
|
| 429 |
+
if isinstance(time_window, int):
|
| 430 |
+
# This interpretation might need review based on how dashboard calls it.
|
| 431 |
+
# If time_window=1 means "current only", it should become [1].
|
| 432 |
+
# If time_window=3 means "current, t-1, t-2", it should become [1,2,3].
|
| 433 |
+
# Let's assume int means number of windows starting from current.
|
| 434 |
+
time_window_offsets_list = list(range(1, time_window + 1)) if time_window > 0 else [1]
|
| 435 |
+
elif isinstance(time_window, list):
|
| 436 |
+
time_window_offsets_list = time_window if time_window else [1] # Default to [1] if empty list
|
| 437 |
+
else: # Default to current time if malformed
|
| 438 |
+
time_window_offsets_list = [1]
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
# MOVED TO config.py: low_level_tasks
|
| 442 |
+
unique_subjects_from_config_or_data = []
|
| 443 |
+
if HOSTING_STATS:
|
| 444 |
+
# Try to infer subjects from HOSTING_STATS keys if needed, though LOW_LEVEL_TASKS and iterating subjects is primary
|
| 445 |
+
all_subjects_in_stats = set()
|
| 446 |
+
for k_tuple in HOSTING_STATS.keys():
|
| 447 |
+
if isinstance(k_tuple, tuple) and len(k_tuple) == 4: # (task, inc, spd, subject)
|
| 448 |
+
all_subjects_in_stats.add(k_tuple[3])
|
| 449 |
+
unique_subjects_from_config_or_data = sorted(list(all_subjects_in_stats))
|
| 450 |
+
if not unique_subjects_from_config_or_data and verbose:
|
| 451 |
+
print("Warning: Could not infer subjects from HOSTING_STATS keys.")
|
| 452 |
+
|
| 453 |
+
# If subjects could not be inferred or HOSTING_STATS is None, this part will be problematic.
|
| 454 |
+
# The original code used total_data['subject'].unique(). This is no longer available.
|
| 455 |
+
# For now, let's assume LOW_LEVEL_TASKS and iterating these subjects is the way.
|
| 456 |
+
# The `preprocess_data_for_hosting.py` iterates `unique_subjects = total_data['subject'].unique()`.
|
| 457 |
+
# So `HOSTING_STATS` keys will contain all subjects that had data.
|
| 458 |
+
# `unique_subjects_from_config_or_data` should capture these.
|
| 459 |
+
# If this list is empty and HOSTING_STATS is populated, it means keys are not as expected.
|
| 460 |
+
|
| 461 |
+
# Fallback if subjects list is empty but stats exist (e.g. if keys are not tuples of 4)
|
| 462 |
+
# This part needs robust subject discovery if not relying on a predefined list.
|
| 463 |
+
# For now, assume unique_subjects_from_config_or_data is sufficient or that errors will propagate if it's empty.
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# Pre-compute statistics for all valid tasks
|
| 467 |
+
if verbose or progress_callback:
|
| 468 |
+
print("Pre-computing task statistics...")
|
| 469 |
+
if progress_callback:
|
| 470 |
+
progress_callback(0.1) # Show initial progress
|
| 471 |
+
|
| 472 |
+
task_1_stats = {}
|
| 473 |
+
task_2_stats = {}
|
| 474 |
+
|
| 475 |
+
# Determine subjects to iterate over.
|
| 476 |
+
# If unique_subjects_from_config_or_data is empty, this loop won't run correctly.
|
| 477 |
+
subjects_to_iterate = unique_subjects_from_config_or_data
|
| 478 |
+
if not subjects_to_iterate and verbose:
|
| 479 |
+
print("Warning: subjects_to_iterate is empty. Statistics pre-computation might be skipped or fail.")
|
| 480 |
+
# Potentially, if LOW_LEVEL_TASKS is what defines the scope, and subjects are per task_config in HOSTING_STATS
|
| 481 |
+
# then subject iteration should be nested inside task_config iteration if HOSTING_STATS keys are task_config specific.
|
| 482 |
+
# The current `preprocess_data_for_hosting.py` creates keys `(task_config_tuple, subject_str)`.
|
| 483 |
+
|
| 484 |
+
total_combinations_to_calc_stats_for = 0
|
| 485 |
+
tasks_for_stats_calc_t1 = []
|
| 486 |
+
tasks_for_stats_calc_t2 = []
|
| 487 |
+
|
| 488 |
+
for task_info_cfg in LOW_LEVEL_TASKS: # task_info_cfg is (task_name, incline, speed)
|
| 489 |
+
# Filter based on task1_name and task2_name (from dashboard selection)
|
| 490 |
+
task1_str, task1_incline, task1_speed = task1_name if task1_name else (None, None, None)
|
| 491 |
+
task2_str, task2_incline, task2_speed = task2_name if task2_name else (None, None, None)
|
| 492 |
+
|
| 493 |
+
task_1_match = (task1_str is None or task1_str == task_info_cfg[0]) and \
|
| 494 |
+
(task1_incline is None or task1_incline == task_info_cfg[1]) and \
|
| 495 |
+
(task1_speed is None or task1_speed == task_info_cfg[2])
|
| 496 |
+
|
| 497 |
+
task_2_match = (task2_str is None or task2_str == task_info_cfg[0]) and \
|
| 498 |
+
(task2_incline is None or task2_incline == task_info_cfg[1]) and \
|
| 499 |
+
(task2_speed is None or task2_speed == task_info_cfg[2])
|
| 500 |
+
|
| 501 |
+
if task_1_match:
|
| 502 |
+
tasks_for_stats_calc_t1.append(task_info_cfg)
|
| 503 |
+
total_combinations_to_calc_stats_for += len(subjects_to_iterate)
|
| 504 |
+
if task_2_match: # Can overlap with task_1_match
|
| 505 |
+
# Avoid double counting if task_1_match also true
|
| 506 |
+
if not task_1_match or task_info_cfg not in tasks_for_stats_calc_t1:
|
| 507 |
+
tasks_for_stats_calc_t2.append(task_info_cfg) # only add if not already counted for t1
|
| 508 |
+
total_combinations_to_calc_stats_for += len(subjects_to_iterate)
|
| 509 |
+
elif task_1_match and task_info_cfg not in tasks_for_stats_calc_t2: # if it matched t1, ensure it's in t2 list for later use
|
| 510 |
+
tasks_for_stats_calc_t2.append(task_info_cfg)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# Recalculate total_combinations based on unique tasks for T1 and T2 sets
|
| 514 |
+
# This is complex if subjects vary per task_info in HOSTING_STATS.
|
| 515 |
+
# The simple approach: iterate LOW_LEVEL_TASKS, then subjects found in HOSTING_STATS for that task_info.
|
| 516 |
+
|
| 517 |
+
current_stat_calcs = 0
|
| 518 |
+
expected_stat_calcs = 0 # Will be determined by actual found (task_cfg, subj) pairs in HOSTING_STATS
|
| 519 |
+
|
| 520 |
+
# Populate task_1_stats
|
| 521 |
+
for task_config_t1 in tasks_for_stats_calc_t1:
|
| 522 |
+
task_1_stats[task_config_t1] = {}
|
| 523 |
+
for subject in subjects_to_iterate: # Iterate all potential subjects
|
| 524 |
+
task_info_w_subject = (*task_config_t1, subject)
|
| 525 |
+
if HOSTING_STATS and task_info_w_subject in HOSTING_STATS: # Check if this specific combo exists
|
| 526 |
+
expected_stat_calcs +=1 # Count expected calculations
|
| 527 |
+
stats = calculate_task_statistics(
|
| 528 |
+
task_info_w_subject, sensors, time_window_offsets_list, # Pass the list of offsets
|
| 529 |
+
verbose=verbose
|
| 530 |
+
)
|
| 531 |
+
if stats is not None:
|
| 532 |
+
task_1_stats[task_config_t1][subject] = stats
|
| 533 |
+
current_stat_calcs+=1
|
| 534 |
+
if progress_callback and expected_stat_calcs > 0 : progress_callback(0.1 + (0.4 * (current_stat_calcs / max(1,expected_stat_calcs*2)))) # *2 for T1 and T2 pass
|
| 535 |
+
|
| 536 |
+
# Populate task_2_stats
|
| 537 |
+
for task_config_t2 in tasks_for_stats_calc_t2:
|
| 538 |
+
task_2_stats[task_config_t2] = {}
|
| 539 |
+
for subject in subjects_to_iterate:
|
| 540 |
+
task_info_w_subject = (*task_config_t2, subject)
|
| 541 |
+
if HOSTING_STATS and task_info_w_subject in HOSTING_STATS:
|
| 542 |
+
if task_config_t2 not in tasks_for_stats_calc_t1 or subject not in task_1_stats.get(task_config_t2,{}): # Avoid re-calculating if already done for T1
|
| 543 |
+
expected_stat_calcs +=1
|
| 544 |
+
|
| 545 |
+
stats = calculate_task_statistics(
|
| 546 |
+
task_info_w_subject, sensors, time_window_offsets_list,
|
| 547 |
+
verbose=verbose
|
| 548 |
+
)
|
| 549 |
+
if stats is not None:
|
| 550 |
+
task_2_stats[task_config_t2][subject] = stats
|
| 551 |
+
current_stat_calcs+=1 # This current_stat_calcs will go up to expected_stat_calcs*2 potentially
|
| 552 |
+
if progress_callback and expected_stat_calcs > 0: progress_callback(0.1 + (0.4 * (current_stat_calcs / max(1,expected_stat_calcs*2))))
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
# Clean up empty task_configs from stats if no subjects had data
|
| 556 |
+
task_1_stats = {k: v for k, v in task_1_stats.items() if v}
|
| 557 |
+
task_2_stats = {k: v for k, v in task_2_stats.items() if v}
|
| 558 |
+
|
| 559 |
+
# Generate valid task pairs from successfully populated stats
|
| 560 |
+
task_pairs = list(itertools.product(task_1_stats.keys(),
|
| 561 |
+
task_2_stats.keys()))
|
| 562 |
+
|
| 563 |
+
if verbose or progress_callback:
|
| 564 |
+
if progress_callback:
|
| 565 |
+
progress_callback(0.5) # Half-way point
|
| 566 |
+
print(f"Computing similarity portraits for {len(task_pairs)} task pairs...")
|
| 567 |
+
|
| 568 |
+
# Process task pairs in parallel
|
| 569 |
+
process_func = partial(process_pair,
|
| 570 |
+
task_1_stats=task_1_stats,
|
| 571 |
+
task_2_stats=task_2_stats,
|
| 572 |
+
binary_threshold=binary_threshold,
|
| 573 |
+
match_subjects=match_subjects,
|
| 574 |
+
biomechanical_difference=biomechanical_difference)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
# Add progress tracking if requested
|
| 578 |
+
results_dict = {}
|
| 579 |
+
num_results = 0
|
| 580 |
+
for i in tqdm(range(len(task_pairs)), desc="Processing task pairs",
|
| 581 |
+
total=len(task_pairs)):
|
| 582 |
+
current_task_pair = task_pairs[i]
|
| 583 |
+
result = process_func(current_task_pair)
|
| 584 |
+
|
| 585 |
+
# If the result is None, then skip it
|
| 586 |
+
if result is None:
|
| 587 |
+
# print(f"Result is None for task pair {current_task_pair}")
|
| 588 |
+
continue
|
| 589 |
+
|
| 590 |
+
if 'stair' in current_task_pair[0][0] or 'stair' in current_task_pair[1][0]:
|
| 591 |
+
pass
|
| 592 |
+
|
| 593 |
+
portraits, subject_pairs = result
|
| 594 |
+
|
| 595 |
+
low_level_task_1 = (current_task_pair[0][0],
|
| 596 |
+
current_task_pair[0][1],
|
| 597 |
+
current_task_pair[0][2])
|
| 598 |
+
low_level_task_2 = (current_task_pair[1][0],
|
| 599 |
+
current_task_pair[1][1],
|
| 600 |
+
current_task_pair[1][2])
|
| 601 |
+
low_level_task_total = (low_level_task_1, low_level_task_2)
|
| 602 |
+
# Add the result to the list of results for this low-level task
|
| 603 |
+
if low_level_task_total not in results_dict:
|
| 604 |
+
results_dict[low_level_task_total] = {}
|
| 605 |
+
|
| 606 |
+
if portraits is not None:
|
| 607 |
+
for subject_pair, portrait in zip(subject_pairs, portraits):
|
| 608 |
+
|
| 609 |
+
# Take the compliment if you are doing the output difference
|
| 610 |
+
if output_difference:
|
| 611 |
+
results_dict[low_level_task_total][subject_pair] = 1 - portrait
|
| 612 |
+
else:
|
| 613 |
+
results_dict[low_level_task_total][subject_pair] = portrait
|
| 614 |
+
num_results += 1
|
| 615 |
+
# Update progress if we requested
|
| 616 |
+
if progress_callback:
|
| 617 |
+
progress_callback(0.5 + (0.4 * (i / len(task_pairs))))
|
| 618 |
+
|
| 619 |
+
# Final progress update
|
| 620 |
+
if progress_callback:
|
| 621 |
+
progress_callback(0.9)
|
| 622 |
+
|
| 623 |
+
# Aggregate the similarity portraits
|
| 624 |
+
if num_results <= 0:
|
| 625 |
+
# raise ValueError("No valid similarity portraits were generated")
|
| 626 |
+
if verbose: print("Warning: No valid similarity portraits were generated. Check task/subject filters and data.")
|
| 627 |
+
return (np.zeros((150,150)), {}) if return_all_portraits else np.zeros((150,150)) # Return empty/zero portrait
|
| 628 |
+
|
| 629 |
+
# Remove empty results (should be handled by num_results check, but good practice)
|
| 630 |
+
results_dict = {k: v for k, v in results_dict.items() if v and any(item is not None for item in v.values())}
|
| 631 |
+
if not results_dict and num_results > 0: # Contradiction, means num_results was miscounted or structure is wrong
|
| 632 |
+
if verbose: print("Warning: results_dict is empty despite num_results > 0. Returning zero portrait.")
|
| 633 |
+
return (np.zeros((150,150)), {}) if return_all_portraits else np.zeros((150,150))
|
| 634 |
+
if not results_dict and num_results == 0: # Consistent, no results
|
| 635 |
+
return (np.zeros((150,150)), {}) if return_all_portraits else np.zeros((150,150))
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# Get the aggregation of the similarity portraits
|
| 639 |
+
total_task_similarity_portrait = low_level_tasks_aggregation(results_dict)
|
| 640 |
+
|
| 641 |
+
if progress_callback:
|
| 642 |
+
progress_callback(1.0) # Complete
|
| 643 |
+
|
| 644 |
+
# Add back save_results_dict logic
|
| 645 |
+
if save_results_dict:
|
| 646 |
+
directory = 'saved_similarity_portraits'
|
| 647 |
+
if not os.path.exists(directory):
|
| 648 |
+
os.makedirs(directory)
|
| 649 |
+
# Construct filename carefully. time_window_offsets_list might be long.
|
| 650 |
+
# Using a hash or a string representation that's filename-safe.
|
| 651 |
+
time_window_str = "-".join(map(str, time_window_offsets_list))
|
| 652 |
+
file_name = f'results_dict_{sorted(sensors)}_time_window_{time_window_str}'
|
| 653 |
+
if biomechanical_difference:
|
| 654 |
+
file_name += f'_biomechanical_difference_{biomechanical_difference}'
|
| 655 |
+
# Ensure filename is not too long and is valid
|
| 656 |
+
file_name = "".join(c if c.isalnum() or c in ('_', '-') else '' for c in file_name)[:100] + ".pkl"
|
| 657 |
+
|
| 658 |
+
try:
|
| 659 |
+
with open(os.path.join(directory, file_name), 'wb') as f:
|
| 660 |
+
pickle.dump(results_dict, f)
|
| 661 |
+
if verbose:
|
| 662 |
+
print(f"Saved results_dict to {os.path.join(directory, file_name)}")
|
| 663 |
+
except Exception as e:
|
| 664 |
+
if verbose:
|
| 665 |
+
print(f"Error saving results_dict: {e}")
|
| 666 |
+
|
| 667 |
+
if return_all_portraits:
|
| 668 |
+
return total_task_similarity_portrait, results_dict
|
| 669 |
+
else:
|
| 670 |
+
return total_task_similarity_portrait
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def low_level_tasks_aggregation(results_dict: dict[tuple, list[np.ndarray]]
|
| 674 |
+
) -> np.ndarray:
|
| 675 |
+
"""
|
| 676 |
+
Aggregate the similarity portraits for low-level tasks.
|
| 677 |
+
First, we will get the average across all subjects for each low-level task.
|
| 678 |
+
Then, we will get the max of the average similarity portraits.
|
| 679 |
+
"""
|
| 680 |
+
total_task_similarity_portraits = []
|
| 681 |
+
for low_level_task in results_dict.keys():
|
| 682 |
+
|
| 683 |
+
# Get the average across all subjects for each subject pair
|
| 684 |
+
results_list = list(results_dict[low_level_task].values())
|
| 685 |
+
|
| 686 |
+
# Get the average across all subjects for each low-level task
|
| 687 |
+
low_level_task_portraits = np.nanmean(results_list, axis=0)
|
| 688 |
+
|
| 689 |
+
# Append the average similarity portrait to the list
|
| 690 |
+
total_task_similarity_portraits.append(low_level_task_portraits)
|
| 691 |
+
|
| 692 |
+
# Get the max of the average similarity portraits
|
| 693 |
+
total_task_similarity_portrait = np.max(total_task_similarity_portraits,
|
| 694 |
+
axis=0)
|
| 695 |
+
|
| 696 |
+
if total_task_similarity_portrait is None:
|
| 697 |
+
pass
|
| 698 |
+
|
| 699 |
+
return total_task_similarity_portrait
|
| 700 |
+
|
| 701 |
+
def load_similarity_portrait(sensor_config: list[str],
|
| 702 |
+
time_window: list[int]=[0], # Default [0] might need to be [1] to match dashboard
|
| 703 |
+
task1_name: tuple=(None, None, None),
|
| 704 |
+
task2_name: tuple=(None, None, None),
|
| 705 |
+
biomechanical_difference: bool = False):
|
| 706 |
+
"""
|
| 707 |
+
Load the similarity pickle file and calculate the similarity portrait for
|
| 708 |
+
a given sensor configuration, time window, and task names.
|
| 709 |
+
|
| 710 |
+
Arguments:
|
| 711 |
+
-----------
|
| 712 |
+
sensor_config : list[str]
|
| 713 |
+
The sensor configuration to use.
|
| 714 |
+
time_window : list[int]
|
| 715 |
+
The time window to use.
|
| 716 |
+
task1_name : tuple
|
| 717 |
+
Task name, incline, and speed for the first task. Use None for any value.
|
| 718 |
+
task2_name : tuple
|
| 719 |
+
Task name, incline, and speed for the second task. Use None for any value.
|
| 720 |
+
|
| 721 |
+
Returns:
|
| 722 |
+
--------
|
| 723 |
+
np.ndarray
|
| 724 |
+
"""
|
| 725 |
+
|
| 726 |
+
# Get the file name
|
| 727 |
+
file_name = f'results_dict_{sorted(sensor_config)}_time_window_{time_window}'
|
| 728 |
+
if biomechanical_difference:
|
| 729 |
+
file_name += f'_biomechanical_difference_{biomechanical_difference}'
|
| 730 |
+
directory = 'saved_similarity_portraits'
|
| 731 |
+
|
| 732 |
+
# Load the results dict
|
| 733 |
+
with open(os.path.join(directory, f'{file_name}.pkl'), 'rb') as f:
|
| 734 |
+
results_dict = pickle.load(f)
|
| 735 |
+
|
| 736 |
+
# Check if we want to to a lower-levl task pair
|
| 737 |
+
if task1_name != (None, None, None) and task2_name != (None, None, None):
|
| 738 |
+
# Get the similarity portrait
|
| 739 |
+
try:
|
| 740 |
+
subject_portraits = results_dict[(task1_name, task2_name)]
|
| 741 |
+
except KeyError:
|
| 742 |
+
print(f"KeyError: {(task1_name, task2_name)} not found in results dict")
|
| 743 |
+
print(f"Results dict keys: {results_dict.keys()}")
|
| 744 |
+
raise KeyError(f"{(task1_name, task2_name)} not found in results dict")
|
| 745 |
+
|
| 746 |
+
similarity_portrait = np.mean(list(subject_portraits.values()), axis=0)
|
| 747 |
+
|
| 748 |
+
# We want to do the high level aggregation
|
| 749 |
+
else:
|
| 750 |
+
# Get the similarity portrait
|
| 751 |
+
similarity_portrait = low_level_tasks_aggregation(results_dict)
|
| 752 |
+
|
| 753 |
+
if similarity_portrait is None:
|
| 754 |
+
pass
|
| 755 |
+
|
| 756 |
+
return similarity_portrait
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def load_results_dict(sensor_config: list[str],
|
| 760 |
+
time_window: list[int]=[0]): # Default [0] might need to be [1]
|
| 761 |
+
"""
|
| 762 |
+
Load the results dict for a given sensor configuration, time window, and task names.
|
| 763 |
+
"""
|
| 764 |
+
# Get the file name
|
| 765 |
+
file_name = f'results_dict_{sorted(sensor_config)}_time_window_{time_window}'
|
| 766 |
+
directory = 'saved_similarity_portraits'
|
| 767 |
+
|
| 768 |
+
# Load the results dict
|
| 769 |
+
with open(os.path.join(directory, f'{file_name}.pkl'), 'rb') as f:
|
| 770 |
+
results_dict = pickle.load(f)
|
| 771 |
+
|
| 772 |
+
return results_dict
|
| 773 |
+
|
| 774 |
+
def load_input_output_similarity_portrait(sensor_config: list[str],
|
| 775 |
+
output_config: list[str],
|
| 776 |
+
time_window: list[int]=[0], # Default [0] might need to be [1]
|
| 777 |
+
task1_name: tuple=(None, None, None),
|
| 778 |
+
task2_name: tuple=(None, None, None),
|
| 779 |
+
biomechanical_difference: bool = False):
|
| 780 |
+
"""
|
| 781 |
+
Load the similarity pickle file and calculate the similarity portrait for
|
| 782 |
+
a given sensor configuration, time window, and task names.
|
| 783 |
+
"""
|
| 784 |
+
|
| 785 |
+
# Load the input and output pickle files
|
| 786 |
+
directory = 'saved_similarity_portraits'
|
| 787 |
+
input_file_name = f'results_dict_{sorted(sensor_config)}_time_window_{time_window}'
|
| 788 |
+
output_file_name = f'results_dict_{sorted(output_config)}_time_window_{[0]}'
|
| 789 |
+
if biomechanical_difference:
|
| 790 |
+
output_file_name += f'_biomechanical_difference_{biomechanical_difference}'
|
| 791 |
+
|
| 792 |
+
input_results_dict = None
|
| 793 |
+
output_results_dict = None
|
| 794 |
+
|
| 795 |
+
# Load the input pickle file with error handling
|
| 796 |
+
try:
|
| 797 |
+
with open(os.path.join(directory, f'{input_file_name}.pkl'), 'rb') as f:
|
| 798 |
+
input_results_dict = pickle.load(f)
|
| 799 |
+
except FileNotFoundError:
|
| 800 |
+
print(f"Warning: Input similarity file not found: {input_file_name}.pkl")
|
| 801 |
+
# Return placeholder if input file is crucial and missing
|
| 802 |
+
# return np.zeros((150, 150))
|
| 803 |
+
|
| 804 |
+
# Load the output pickle file with error handling
|
| 805 |
+
try:
|
| 806 |
+
with open(os.path.join(directory, f'{output_file_name}.pkl'), 'rb') as f:
|
| 807 |
+
output_results_dict = pickle.load(f)
|
| 808 |
+
except FileNotFoundError:
|
| 809 |
+
print(f"Warning: Output similarity file not found: {output_file_name}.pkl")
|
| 810 |
+
# Decide if placeholder is needed if output file is missing
|
| 811 |
+
# If both are needed, might return placeholder here
|
| 812 |
+
|
| 813 |
+
# If either dict failed to load, return placeholder
|
| 814 |
+
if input_results_dict is None or output_results_dict is None:
|
| 815 |
+
print(f"Warning: Could not load required similarity data. Returning placeholder zeros.")
|
| 816 |
+
return np.zeros((150, 150))
|
| 817 |
+
|
| 818 |
+
# If the task1_name and task2_name are not None, then we want to do a
|
| 819 |
+
# lower-level task pair
|
| 820 |
+
try:
|
| 821 |
+
if task1_name != (None, None, None) and task2_name != (None, None, None):
|
| 822 |
+
target_key = (task1_name, task2_name)
|
| 823 |
+
input_results_dict = {target_key: input_results_dict[target_key]}
|
| 824 |
+
output_results_dict = {target_key: output_results_dict[target_key]}
|
| 825 |
+
except KeyError:
|
| 826 |
+
print(f"KeyError: {(task1_name, task2_name)} not found in input or output results dicts")
|
| 827 |
+
return None
|
| 828 |
+
|
| 829 |
+
# Calculate the input and output similarity portrait
|
| 830 |
+
input_output_similarity_portrait = input_output_task_aggregation(
|
| 831 |
+
input_results_dict, output_results_dict
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
return input_output_similarity_portrait
|
| 835 |
+
|
| 836 |
+
# Add back input_output_task_aggregation
|
| 837 |
+
def input_output_task_aggregation(
|
| 838 |
+
input_results_dict: dict[tuple, list[np.ndarray]],
|
| 839 |
+
output_results_dict: dict[tuple, list[np.ndarray]],
|
| 840 |
+
) -> np.ndarray:
|
| 841 |
+
"""
|
| 842 |
+
Aggregate the similarity portraits for input and output tasks.
|
| 843 |
+
"""
|
| 844 |
+
# Check if the input and output results dicts have the same keys
|
| 845 |
+
if input_results_dict.keys() != output_results_dict.keys():
|
| 846 |
+
input_keys = set(input_results_dict.keys())
|
| 847 |
+
output_keys = set(output_results_dict.keys())
|
| 848 |
+
matching_keys = input_keys & output_keys
|
| 849 |
+
input_only_keys = input_keys - output_keys
|
| 850 |
+
output_only_keys = output_keys - input_keys
|
| 851 |
+
|
| 852 |
+
num_matching_keys = len(matching_keys)
|
| 853 |
+
num_input_only_keys = len(input_only_keys)
|
| 854 |
+
num_output_only_keys = len(output_only_keys)
|
| 855 |
+
|
| 856 |
+
# This print might be too verbose for regular operation
|
| 857 |
+
# print(f"Keys do not match: {num_matching_keys} keys matching, {num_input_only_keys} keys only in input, {num_output_only_keys} keys only in output")
|
| 858 |
+
|
| 859 |
+
# # Save the set of matching and not matching keys to a single file (for debugging)
|
| 860 |
+
# with open('matching_and_not_matching_keys.txt', 'w') as f:
|
| 861 |
+
# f.write(f"Matching keys:\n")
|
| 862 |
+
# for key in matching_keys:
|
| 863 |
+
# f.write(f"{key}\n")
|
| 864 |
+
|
| 865 |
+
# f.write(f"\nNot matching keys:\n")
|
| 866 |
+
# if input_only_keys:
|
| 867 |
+
# f.write(f"Keys only in input dictionary:\n")
|
| 868 |
+
# for key in input_only_keys:
|
| 869 |
+
# f.write(f"{key}\n")
|
| 870 |
+
|
| 871 |
+
# if output_only_keys:
|
| 872 |
+
# f.write(f"\nKeys only in output dictionary:\n")
|
| 873 |
+
# for key in output_only_keys:
|
| 874 |
+
# f.write(f"{key}\n")
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
common_keys = set(input_results_dict.keys()) & set(output_results_dict.keys())
|
| 878 |
+
if not common_keys:
|
| 879 |
+
# print("Warning: No common keys between input and output results dicts for aggregation.")
|
| 880 |
+
return np.zeros((150,150)) # Return a default empty portrait
|
| 881 |
+
|
| 882 |
+
total_task_similarity_portraits = []
|
| 883 |
+
|
| 884 |
+
# Get the average across all subjects for each low-level task
|
| 885 |
+
for low_level_task in common_keys:
|
| 886 |
+
if not input_results_dict[low_level_task] or not output_results_dict[low_level_task]:
|
| 887 |
+
# print(f"Warning: Empty subject data for low_level_task {low_level_task} in input or output dict. Skipping.")
|
| 888 |
+
continue
|
| 889 |
+
|
| 890 |
+
# Get the common subjects between the input and output results dicts for this specific low_level_task
|
| 891 |
+
common_subject_pairs = set(input_results_dict[low_level_task].keys()) & \
|
| 892 |
+
set(output_results_dict[low_level_task].keys())
|
| 893 |
+
|
| 894 |
+
if not common_subject_pairs:
|
| 895 |
+
# print(f"Warning: No common subject pairs for low_level_task {low_level_task}. Skipping.")
|
| 896 |
+
continue
|
| 897 |
+
|
| 898 |
+
# Take the product of the input and output similarity portraits for
|
| 899 |
+
# each subject pair
|
| 900 |
+
subject_input_output_portraits = []
|
| 901 |
+
for subject_pair in common_subject_pairs:
|
| 902 |
+
input_portrait = input_results_dict[low_level_task].get(subject_pair)
|
| 903 |
+
output_portrait = output_results_dict[low_level_task].get(subject_pair)
|
| 904 |
+
|
| 905 |
+
if input_portrait is not None and output_portrait is not None:
|
| 906 |
+
# Ensure they are numpy arrays for element-wise multiplication
|
| 907 |
+
input_portrait_np = np.array(input_portrait)
|
| 908 |
+
output_portrait_np = np.array(output_portrait)
|
| 909 |
+
if input_portrait_np.shape == output_portrait_np.shape:
|
| 910 |
+
subject_input_output_portraits.append(
|
| 911 |
+
input_portrait_np * output_portrait_np
|
| 912 |
+
)
|
| 913 |
+
# else:
|
| 914 |
+
# print(f"Warning: Shape mismatch for subject pair {subject_pair}, task {low_level_task}. Input: {input_portrait_np.shape}, Output: {output_portrait_np.shape}. Skipping.")
|
| 915 |
+
# else:
|
| 916 |
+
# print(f"Skipping subject pair {subject_pair} for low-level task {low_level_task} because one of the portraits is None")
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
if not subject_input_output_portraits:
|
| 920 |
+
# print(f"Warning: No valid subject_input_output_portraits for low_level_task {low_level_task}. Skipping.")
|
| 921 |
+
continue
|
| 922 |
+
|
| 923 |
+
# Get the mean of the subject-specific input-output product portraits
|
| 924 |
+
# Changed from max to mean as per original intent for product aggregation before overall max
|
| 925 |
+
mean_subject_input_output_portrait = np.nanmean(
|
| 926 |
+
np.array(subject_input_output_portraits), axis=0
|
| 927 |
+
)
|
| 928 |
+
if np.isnan(mean_subject_input_output_portrait).all():
|
| 929 |
+
# print(f"Warning: Mean subject input-output portrait is all NaN for {low_level_task}. Skipping.")
|
| 930 |
+
continue
|
| 931 |
+
|
| 932 |
+
# Append the aggregated similarity portrait for this low_level_task to the list
|
| 933 |
+
total_task_similarity_portraits.append(mean_subject_input_output_portrait)
|
| 934 |
+
|
| 935 |
+
if not total_task_similarity_portraits:
|
| 936 |
+
# print("Warning: No task similarity portraits to aggregate after processing all common keys.")
|
| 937 |
+
return np.zeros((150,150))
|
| 938 |
+
|
| 939 |
+
# Get the max across all low_level_task aggregated portraits
|
| 940 |
+
final_aggregated_portrait = np.nanmax(
|
| 941 |
+
np.array(total_task_similarity_portraits), axis=0
|
| 942 |
+
)
|
| 943 |
+
if np.isnan(final_aggregated_portrait).all():
|
| 944 |
+
# print("Warning: Final aggregated portrait is all NaN. Returning zeros.")
|
| 945 |
+
return np.zeros((150,150))
|
| 946 |
+
|
| 947 |
+
return final_aggregated_portrait
|
plot_similarity.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Visualization functions for similarity measures"""
|
| 2 |
+
from plot_styling import PLOT_STYLE, PLOT_COLORS, set_plot_style
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
| 7 |
+
from typing import Optional, Union
|
| 8 |
+
import matplotlib as mpl
|
| 9 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 10 |
+
|
| 11 |
+
def plot_similarity_measure(measure_data: np.ndarray, ax: plt.Axes = None,
|
| 12 |
+
plot_type: str = 'input', cbar: bool = True,
|
| 13 |
+
task_x_name: str = '', task_y_name: str = '',
|
| 14 |
+
xlabel: Optional[str] = None, ylabel: Optional[str] = None,
|
| 15 |
+
cmap: Optional[Union[str, mpl.colors.Colormap]] = None,
|
| 16 |
+
title: Optional[str] = None,
|
| 17 |
+
fontsize: int = 12,
|
| 18 |
+
y_label_pad: int = -13,
|
| 19 |
+
cbar_labels: bool = True,
|
| 20 |
+
cutoff_treshold:float=None,
|
| 21 |
+
high_level_plot: bool = False):
|
| 22 |
+
"""Plot similarity measure with consistent styling.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
ax: matplotlib axis object
|
| 26 |
+
measure_data: 2D array of similarity measures
|
| 27 |
+
plot_type: str, one of 'input', 'output', or 'conflict'
|
| 28 |
+
cbar: bool, whether to show colorbar
|
| 29 |
+
task_x_name: str, name of x-axis task (used if xlabel not provided)
|
| 30 |
+
task_y_name: str, name of y-axis task (used if ylabel not provided)
|
| 31 |
+
xlabel: str, optional custom x-axis label
|
| 32 |
+
ylabel: str, optional custom y-axis label
|
| 33 |
+
cmap: str or matplotlib colormap, optional custom colormap
|
| 34 |
+
title: str, optional custom title
|
| 35 |
+
fontsize: int, font size for all text elements
|
| 36 |
+
cbar_labels: bool, whether to show colorbar labels
|
| 37 |
+
cutoff_treshold: float, optional cutoff threshold that will count
|
| 38 |
+
as a conflict. If it is not None, the amount of values above this
|
| 39 |
+
threshold will be added as a percent text next to the colorbar.
|
| 40 |
+
If it is None, no text will be added.
|
| 41 |
+
high_level_plot: bool, whether to plot in a high-level format. This just
|
| 42 |
+
means that we subtract the diagonal from the thresholded values, since
|
| 43 |
+
this will trivially be 1 due to the fact that we compare the same task
|
| 44 |
+
to itself.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
if ax is None:
|
| 48 |
+
fig, ax = plt.subplots()
|
| 49 |
+
|
| 50 |
+
# Set color based on plot type
|
| 51 |
+
if cmap is None:
|
| 52 |
+
if plot_type == 'input':
|
| 53 |
+
cmap = PLOT_COLORS['input_similarity']
|
| 54 |
+
elif plot_type == 'output':
|
| 55 |
+
cmap = PLOT_COLORS['output_difference']
|
| 56 |
+
elif plot_type == 'output_biomechanical':
|
| 57 |
+
cmap = PLOT_COLORS['output_biomechanical']
|
| 58 |
+
else: # conflict
|
| 59 |
+
cmap = PLOT_COLORS['conflict']
|
| 60 |
+
|
| 61 |
+
# Create heatmap with flipped y-axis
|
| 62 |
+
hm = sns.heatmap(np.flipud(measure_data), vmin=0, vmax=1, ax=ax, cmap=cmap,
|
| 63 |
+
cbar=False, rasterized=False)
|
| 64 |
+
|
| 65 |
+
# Configure axes
|
| 66 |
+
ax.set_xticks([0, 149])
|
| 67 |
+
ax.set_xticklabels(['0%', '100%'], fontsize=fontsize)
|
| 68 |
+
ax.set_yticks([0, 149])
|
| 69 |
+
ax.set_yticklabels(['100%', '0%'], fontsize=fontsize) # Flipped y-axis labels
|
| 70 |
+
|
| 71 |
+
# Set labels with consistent padding
|
| 72 |
+
if xlabel is not None:
|
| 73 |
+
ax.set_xlabel(f'Gait Cycle (%)\n{xlabel}', labelpad=PLOT_STYLE['label_pad_x']-3, fontsize=fontsize)
|
| 74 |
+
elif task_x_name:
|
| 75 |
+
ax.set_xlabel(f'Gait Cycle (%)\n{task_x_name}', labelpad=PLOT_STYLE['label_pad_x']-3, fontsize=fontsize)
|
| 76 |
+
else:
|
| 77 |
+
ax.set_xlabel('Gait Cycle (%)', labelpad=PLOT_STYLE['label_pad_x']-3, fontsize=fontsize)
|
| 78 |
+
|
| 79 |
+
if ylabel is not None:
|
| 80 |
+
ax.set_ylabel(f'{ylabel}\nGait Cycle (%)', labelpad=PLOT_STYLE['label_pad_y']-13+y_label_pad, fontsize=fontsize)
|
| 81 |
+
elif task_y_name:
|
| 82 |
+
ax.set_ylabel(f'{task_y_name}\nGait Cycle (%)', labelpad=PLOT_STYLE['label_pad_y']-13+y_label_pad, fontsize=fontsize)
|
| 83 |
+
else:
|
| 84 |
+
ax.set_ylabel('Gait Cycle (%)', labelpad=PLOT_STYLE['label_pad_y']-13+y_label_pad, fontsize=fontsize)
|
| 85 |
+
|
| 86 |
+
# Configure ticks
|
| 87 |
+
ax.tick_params(axis='both', which='both',
|
| 88 |
+
length=PLOT_STYLE['tick_length'],
|
| 89 |
+
width=PLOT_STYLE['tick_width'],
|
| 90 |
+
pad=PLOT_STYLE['tick_pad'],
|
| 91 |
+
labelsize=fontsize)
|
| 92 |
+
|
| 93 |
+
# Ensure x-axis labels are horizontal
|
| 94 |
+
ax.tick_params(axis='x', rotation=0)
|
| 95 |
+
|
| 96 |
+
# Set the title
|
| 97 |
+
if title is not None:
|
| 98 |
+
ax.set_title(title, pad=2, fontsize=fontsize)
|
| 99 |
+
|
| 100 |
+
if cbar:
|
| 101 |
+
divider = make_axes_locatable(ax)
|
| 102 |
+
cax = divider.append_axes("right", size="5%", pad=0.05) # Increased from 2% to 5%
|
| 103 |
+
cbar_obj = plt.colorbar(hm.collections[0], cax=cax)
|
| 104 |
+
cbar_obj.outline.set_visible(False)
|
| 105 |
+
cbar_obj.ax.yaxis.set_ticks_position('right')
|
| 106 |
+
cbar_obj.ax.tick_params(length=0, labelsize=fontsize)
|
| 107 |
+
if cbar_labels:
|
| 108 |
+
cbar_obj.set_ticks([0, 1])
|
| 109 |
+
else:
|
| 110 |
+
cbar_obj.set_ticks([])
|
| 111 |
+
|
| 112 |
+
# Implement cutoff threshold annotation
|
| 113 |
+
if cutoff_treshold is not None:
|
| 114 |
+
|
| 115 |
+
# The actual cutoff will depend on the plot type
|
| 116 |
+
if plot_type == 'input':
|
| 117 |
+
true_cutoff = cutoff_treshold
|
| 118 |
+
elif plot_type == 'output':
|
| 119 |
+
true_cutoff = 1 - cutoff_treshold
|
| 120 |
+
elif plot_type == 'output_biomechanical':
|
| 121 |
+
true_cutoff = 1 - cutoff_treshold
|
| 122 |
+
else: # conflict
|
| 123 |
+
true_cutoff = cutoff_treshold * (1 - cutoff_treshold)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Count percent of values above threshold
|
| 127 |
+
if high_level_plot:
|
| 128 |
+
# Subtract diagonal from thresholded values
|
| 129 |
+
mask = np.ones_like(measure_data, dtype=bool)
|
| 130 |
+
np.fill_diagonal(mask, False)
|
| 131 |
+
total = np.sum(mask)
|
| 132 |
+
above = np.sum((measure_data > true_cutoff) & mask)
|
| 133 |
+
else:
|
| 134 |
+
total = measure_data.size
|
| 135 |
+
above = np.sum(measure_data > true_cutoff)
|
| 136 |
+
percent = 100.0 * above / total if total > 0 else 0.0
|
| 137 |
+
# Format as e.g. "12.3% > 0.8"
|
| 138 |
+
annotation = f"{percent:.1f}% > {true_cutoff:g}"
|
| 139 |
+
# Place annotation to the right of the colorbar
|
| 140 |
+
cbar_obj.ax.text(1.2, 0.5, annotation, va='center', ha='left',
|
| 141 |
+
fontsize=fontsize, rotation=90, transform=cbar_obj.ax.transAxes)
|
| 142 |
+
|
| 143 |
+
# Set aspect ratio to equal
|
| 144 |
+
ax.set_aspect('equal')
|
| 145 |
+
|
| 146 |
+
return hm
|
| 147 |
+
|
| 148 |
+
# Create custom colormaps with transparency for low values
|
| 149 |
+
def create_transparent_cmap(colors):
|
| 150 |
+
# Convert hex colors to RGB tuples with alpha
|
| 151 |
+
rgb_colors = [(1,1,1,0)] # Start with transparent white
|
| 152 |
+
for hex_color in colors:
|
| 153 |
+
r = int(hex_color[1:3], 16)/255
|
| 154 |
+
g = int(hex_color[3:5], 16)/255
|
| 155 |
+
b = int(hex_color[5:7], 16)/255
|
| 156 |
+
rgb_colors.append((r, g, b, 1.0))
|
| 157 |
+
return LinearSegmentedColormap.from_list('custom_cmap', rgb_colors)
|
| 158 |
+
|
| 159 |
+
# Define color sequences
|
| 160 |
+
blue_colors = ['#F7FBFF', '#DEEBF7', '#C6DBEF', '#9ECAE1', '#6BAED6', '#4292C6', '#2171B5']
|
| 161 |
+
orange_colors = ['#FFF5EB', '#FEE6CE', '#FDD0A2', '#FDAE6B', '#FD8D3C', '#F16913', '#D94801']
|
| 162 |
+
green_colors = ['#F7FCF5', '#E5F5E0', '#C7E9C0', '#A1D99B', '#74C476', '#41AB5D', '#238B45']
|
| 163 |
+
|
| 164 |
+
sim_cmap = create_transparent_cmap(blue_colors) # Blue sequential
|
| 165 |
+
diff_cmap = create_transparent_cmap(orange_colors) # Peach sequential
|
| 166 |
+
conflict_cmap = create_transparent_cmap(green_colors) # Purple sequential
|
plot_styling.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module for consistent plot styling across all visualizations"""
|
| 2 |
+
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import seaborn as sns
|
| 5 |
+
import matplotlib.font_manager as fm
|
| 6 |
+
|
| 7 |
+
# Add Arial font
|
| 8 |
+
fm.fontManager.addfont('/usr/share/fonts/truetype/msttcorefonts/Arial.ttf')
|
| 9 |
+
|
| 10 |
+
# Plot styling constants
|
| 11 |
+
PLOT_STYLE = {
|
| 12 |
+
'font_family': 'Arial',
|
| 13 |
+
'font_size': 18,
|
| 14 |
+
'title_size': 20,
|
| 15 |
+
'label_size': 18,
|
| 16 |
+
'tick_size': 15,
|
| 17 |
+
'tick_length': 5,
|
| 18 |
+
'tick_width': 0.5,
|
| 19 |
+
'tick_pad': 2,
|
| 20 |
+
'label_pad_x': -15,
|
| 21 |
+
'label_pad_y': -35,
|
| 22 |
+
'figure_dpi': 300,
|
| 23 |
+
'aspect_ratio': 'equal',
|
| 24 |
+
'subplot_wspace': 0.05,
|
| 25 |
+
'subplot_hspace': 0.1
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
# Color constants
|
| 29 |
+
PLOT_COLORS = {
|
| 30 |
+
'input_similarity': sns.color_palette('rocket', as_cmap=True),
|
| 31 |
+
'output_difference': sns.cubehelix_palette(start=.2, rot=-.3, dark=0, light=0.85,
|
| 32 |
+
reverse=True, as_cmap=True),
|
| 33 |
+
'conflict': sns.cubehelix_palette(start=2, rot=0, dark=0, light=0.85,
|
| 34 |
+
reverse=True, as_cmap=True),
|
| 35 |
+
# Define a cubehelix palette starting dark (dark=0) and ending light purple (light=0.85, start=2.8)
|
| 36 |
+
'output_biomechanical': sns.cubehelix_palette(start=2.8, rot=0.4, dark=0, light=0.85,
|
| 37 |
+
reverse=True, as_cmap=True)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
purple_helix = sns.cubehelix_palette(start=.2, rot=-.4, dark=0, light=0.85,
|
| 42 |
+
reverse=True, as_cmap=True)
|
| 43 |
+
my_purple_helix = sns.cubehelix_palette(start=.2, rot=-.1, dark=0, light=0.85,
|
| 44 |
+
reverse=True, as_cmap=True)
|
| 45 |
+
|
| 46 |
+
def set_plot_style():
|
| 47 |
+
"""Set consistent plot styling across all figures"""
|
| 48 |
+
plt.rcParams['font.family'] = PLOT_STYLE['font_family']
|
| 49 |
+
plt.rcParams['font.size'] = PLOT_STYLE['font_size']
|
| 50 |
+
plt.rcParams['axes.labelsize'] = PLOT_STYLE['label_size']
|
| 51 |
+
plt.rcParams['axes.titlesize'] = PLOT_STYLE['title_size']
|
| 52 |
+
plt.rcParams['xtick.labelsize'] = PLOT_STYLE['tick_size']
|
| 53 |
+
plt.rcParams['ytick.labelsize'] = PLOT_STYLE['tick_size']
|
| 54 |
+
plt.rcParams['figure.dpi'] = PLOT_STYLE['figure_dpi']
|
| 55 |
+
plt.rcParams['figure.subplot.wspace'] = PLOT_STYLE['subplot_wspace']
|
| 56 |
+
plt.rcParams['figure.subplot.hspace'] = PLOT_STYLE['subplot_hspace']
|
requirements.txt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
altair
|
| 2 |
-
pandas
|
| 3 |
-
streamlit
|
|
|
|
|
|
|
|
|
|
|
|
sensor_illustration.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import List
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from plot_styling import set_plot_style
|
| 7 |
+
|
| 8 |
+
@dataclass
|
| 9 |
+
class JointPosition:
|
| 10 |
+
"""Stores the position of a joint in the illustration"""
|
| 11 |
+
name: str
|
| 12 |
+
x: float
|
| 13 |
+
y: float
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class Segment:
|
| 17 |
+
"""Stores the segment information connecting two joints"""
|
| 18 |
+
name: str
|
| 19 |
+
joint1: str
|
| 20 |
+
joint2: str
|
| 21 |
+
|
| 22 |
+
class LegIllustration:
|
| 23 |
+
def __init__(self, mode: str = "level_walking"):
|
| 24 |
+
"""Initialize the leg illustration with a background image based on the specified mode
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
mode: A string specifying the background mode. Options are:
|
| 28 |
+
"level_walking" (default),
|
| 29 |
+
"level_walking_line_under", and
|
| 30 |
+
"decline_walking".
|
| 31 |
+
"""
|
| 32 |
+
# Set plot style at initialization
|
| 33 |
+
set_plot_style()
|
| 34 |
+
self.mode = mode # store mode if needed later
|
| 35 |
+
|
| 36 |
+
# Map modes to background images
|
| 37 |
+
background_images = {
|
| 38 |
+
"level_walking": "assets/level_ground_no_line_under.png",
|
| 39 |
+
"level_walking_line_under": "assets/level_ground_line_under.png",
|
| 40 |
+
"decline_walking": "assets/downhill_line_under.png"
|
| 41 |
+
}
|
| 42 |
+
if mode in background_images:
|
| 43 |
+
self.background_image_path = background_images[mode]
|
| 44 |
+
else:
|
| 45 |
+
print(f"Warning: Unknown mode '{mode}'. Defaulting to level_walking.")
|
| 46 |
+
self.background_image_path = background_images["level_walking"]
|
| 47 |
+
|
| 48 |
+
# Set joint positions based on mode
|
| 49 |
+
if mode in ["level_walking"]:
|
| 50 |
+
self.image_vertical_offset = 0.02 # Space at bottom for gait cycle bar
|
| 51 |
+
self.joints_right = {
|
| 52 |
+
'hip_r': JointPosition('hip_r', 0.45, 0.825 + self.image_vertical_offset),
|
| 53 |
+
'knee_r': JointPosition('knee_r', 0.65, 0.46 + self.image_vertical_offset),
|
| 54 |
+
'ankle_r': JointPosition('ankle_r', 0.77, 0.075 + self.image_vertical_offset),
|
| 55 |
+
'toe_r': JointPosition('toe_r', 0.92, 0.057 + self.image_vertical_offset),
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
self.joints_left = {
|
| 59 |
+
'hip_l': JointPosition('hip_l', 0.35, 0.825 + self.image_vertical_offset),
|
| 60 |
+
'knee_l': JointPosition('knee_l', 0.235, 0.42 + self.image_vertical_offset),
|
| 61 |
+
'ankle_l': JointPosition('ankle_l', 0.06, 0.12 + self.image_vertical_offset),
|
| 62 |
+
'toe_l': JointPosition('toe_l', 0.17, 0.03 + self.image_vertical_offset),
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
elif mode == "level_walking_line_under":
|
| 66 |
+
self.image_vertical_offset = 0.01 # Slightly lower offset for decline walking
|
| 67 |
+
self.joints_right = {
|
| 68 |
+
'hip_r': JointPosition('hip_r', 0.48, 0.82 + self.image_vertical_offset),
|
| 69 |
+
'knee_r': JointPosition('knee_r', 0.62, 0.45 + self.image_vertical_offset),
|
| 70 |
+
'ankle_r': JointPosition('ankle_r', 0.695, 0.08 + self.image_vertical_offset),
|
| 71 |
+
'toe_r': JointPosition('toe_r', 0.80, 0.075 + self.image_vertical_offset),
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
self.joints_left = {
|
| 75 |
+
'hip_l': JointPosition('hip_l', 0.38, 0.82 + self.image_vertical_offset),
|
| 76 |
+
'knee_l': JointPosition('knee_l', 0.28, 0.39 + self.image_vertical_offset),
|
| 77 |
+
'ankle_l': JointPosition('ankle_l', 0.18, 0.12 + self.image_vertical_offset),
|
| 78 |
+
'toe_l': JointPosition('toe_l', 0.26, 0.04 + self.image_vertical_offset),
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
elif mode == "decline_walking":
|
| 82 |
+
self.image_vertical_offset = 0.01 # Slightly lower offset for decline walking
|
| 83 |
+
self.joints_right = {
|
| 84 |
+
'hip_r': JointPosition('hip_r', 0.63, 0.84 + self.image_vertical_offset),
|
| 85 |
+
'knee_r': JointPosition('knee_r', 0.72, 0.50 + self.image_vertical_offset),
|
| 86 |
+
'ankle_r': JointPosition('ankle_r', 0.795, 0.14 + self.image_vertical_offset),
|
| 87 |
+
'toe_r': JointPosition('toe_r', 0.93, 0.115 + self.image_vertical_offset),
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
self.joints_left = {
|
| 91 |
+
'hip_l': JointPosition('hip_l', 0.53, 0.84 + self.image_vertical_offset),
|
| 92 |
+
'knee_l': JointPosition('knee_l', 0.48, 0.55 + self.image_vertical_offset),
|
| 93 |
+
'ankle_l': JointPosition('ankle_l', 0.19, 0.36 + self.image_vertical_offset),
|
| 94 |
+
'toe_l': JointPosition('toe_l', 0.30, 0.27 + self.image_vertical_offset),
|
| 95 |
+
}
|
| 96 |
+
else:
|
| 97 |
+
# Fallback to level_walking settings for unknown mode values
|
| 98 |
+
self.image_vertical_offset = 0.09
|
| 99 |
+
self.joints_right = {
|
| 100 |
+
'hip_r': JointPosition('hip_r', 0.47, 0.825 + self.image_vertical_offset),
|
| 101 |
+
'knee_r': JointPosition('knee_r', 0.67, 0.46 + self.image_vertical_offset),
|
| 102 |
+
'ankle_r': JointPosition('ankle_r', 0.79, 0.075 + self.image_vertical_offset),
|
| 103 |
+
'toe_r': JointPosition('toe_r', 0.94, 0.057 + self.image_vertical_offset),
|
| 104 |
+
}
|
| 105 |
+
self.joints_left = {
|
| 106 |
+
'hip_l': JointPosition('hip_l', 0.37, 0.825 + self.image_vertical_offset),
|
| 107 |
+
'knee_l': JointPosition('knee_l', 0.255, 0.42 + self.image_vertical_offset),
|
| 108 |
+
'ankle_l': JointPosition('ankle_l', 0.08, 0.12 + self.image_vertical_offset),
|
| 109 |
+
'toe_l': JointPosition('toe_l', 0.19, 0.03 + self.image_vertical_offset),
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
# Define segments
|
| 113 |
+
self.segments = [
|
| 114 |
+
Segment('thigh', 'hip', 'knee'),
|
| 115 |
+
Segment('shank', 'knee', 'ankle'),
|
| 116 |
+
Segment('foot', 'ankle', 'toe')
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
# Add CoP and vGRF parameters
|
| 120 |
+
self.cop_offset = -0.07 # Offset from foot line
|
| 121 |
+
self.vgrf_arrow_height = 0.15 # Height of the force arrow
|
| 122 |
+
self.vgrf_arrow_offset = 0.07 # Horizontal offset from foot
|
| 123 |
+
|
| 124 |
+
# Add color mapping for different sensor types
|
| 125 |
+
self.color_map = {
|
| 126 |
+
'angle': 'yellow',
|
| 127 |
+
'velocity': 'blue',
|
| 128 |
+
'angle_velocity': 'purple',
|
| 129 |
+
'torque': 'orange'
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
def _get_joint_name(self, base_name: str, joints: dict) -> str:
|
| 133 |
+
"""
|
| 134 |
+
Helper method to get the correct joint name with suffix
|
| 135 |
+
based on the joints dictionary
|
| 136 |
+
"""
|
| 137 |
+
suffix = '_r' if 'hip_r' in joints else '_l'
|
| 138 |
+
return f"{base_name}{suffix}"
|
| 139 |
+
|
| 140 |
+
def draw_illustration(self,
|
| 141 |
+
highlighted_elements: List[str] = None,
|
| 142 |
+
title: str = None,
|
| 143 |
+
gait_cycle_sections: List[int] | int = None,
|
| 144 |
+
ax: plt.Axes = None,
|
| 145 |
+
scale_factor: float = 1.0) -> plt.Figure:
|
| 146 |
+
"""Draw the leg illustration with highlighted elements
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
highlighted_elements: List of elements to highlight
|
| 150 |
+
title: Title for the figure
|
| 151 |
+
gait_cycle_sections: List of gait cycle sections to highlight
|
| 152 |
+
ax: Matplotlib axes to draw on. If None, creates new figure.
|
| 153 |
+
"""
|
| 154 |
+
# Track whether we created a new figure
|
| 155 |
+
created_new_figure = False
|
| 156 |
+
|
| 157 |
+
if ax is None:
|
| 158 |
+
# Use consistent figure size and make background transparent
|
| 159 |
+
fig = plt.figure(figsize=(3 * scale_factor, 6 * scale_factor), facecolor='none')
|
| 160 |
+
ax = fig.add_subplot(111)
|
| 161 |
+
created_new_figure = True
|
| 162 |
+
else:
|
| 163 |
+
fig = ax.figure
|
| 164 |
+
|
| 165 |
+
ax.set_facecolor('none') # Make axis background transparent
|
| 166 |
+
|
| 167 |
+
# Draw leg illustration with adjusted vertical extent
|
| 168 |
+
# Calculate padding based on scale_factor
|
| 169 |
+
padding = 0 # 0.001 * scale_factor # Increased padding
|
| 170 |
+
try:
|
| 171 |
+
img = plt.imread(self.background_image_path)
|
| 172 |
+
|
| 173 |
+
# Keep original image positioning to maintain alignment with joint markers
|
| 174 |
+
ax.imshow(
|
| 175 |
+
img,
|
| 176 |
+
extent=[-padding, 1 + padding, # expanded x-axis
|
| 177 |
+
self.image_vertical_offset - padding*2, # original y-axis bottom
|
| 178 |
+
1 + padding*2 + self.image_vertical_offset] # expanded y-axis top
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Only extend the bottom of the axis limits to make room for phase window % if needed
|
| 182 |
+
bottom_limit = self.image_vertical_offset - padding*2
|
| 183 |
+
if gait_cycle_sections is not None and (
|
| 184 |
+
(isinstance(gait_cycle_sections, list) and len(gait_cycle_sections) > 1) or
|
| 185 |
+
(isinstance(gait_cycle_sections, np.ndarray) and len(gait_cycle_sections) > 1)
|
| 186 |
+
):
|
| 187 |
+
bottom_limit = -0.08 # Lower bottom limit for phase window %
|
| 188 |
+
|
| 189 |
+
# Update the axis limits
|
| 190 |
+
ax.set_xlim(-padding, 1 + padding)
|
| 191 |
+
ax.set_ylim(bottom_limit, 1 + padding*2 + self.image_vertical_offset)
|
| 192 |
+
except Exception as e:
|
| 193 |
+
print(f"Warning: Could not load background image: {e}")
|
| 194 |
+
# Also adjust fallback y-limits if loading fails
|
| 195 |
+
if gait_cycle_sections is not None and (
|
| 196 |
+
(isinstance(gait_cycle_sections, list) and len(gait_cycle_sections) > 1) or
|
| 197 |
+
(isinstance(gait_cycle_sections, np.ndarray) and len(gait_cycle_sections) > 1)
|
| 198 |
+
):
|
| 199 |
+
bottom_limit = -0.08
|
| 200 |
+
else:
|
| 201 |
+
bottom_limit = -padding*2
|
| 202 |
+
ax.set_xlim(-padding, 1 + padding)
|
| 203 |
+
ax.set_ylim(bottom_limit, 1 + padding*2)
|
| 204 |
+
|
| 205 |
+
# Draw legs
|
| 206 |
+
self._draw_leg(ax, self.joints_right, self.segments,
|
| 207 |
+
highlighted_elements, scale_factor)
|
| 208 |
+
self._draw_leg(ax, self.joints_left, self.segments,
|
| 209 |
+
highlighted_elements, scale_factor)
|
| 210 |
+
|
| 211 |
+
if title:
|
| 212 |
+
ax.set_title(title, pad=20)
|
| 213 |
+
|
| 214 |
+
# Draw phase window % if more than one section is provided
|
| 215 |
+
if gait_cycle_sections is not None:
|
| 216 |
+
# Convert to list if int
|
| 217 |
+
if isinstance(gait_cycle_sections, int):
|
| 218 |
+
sections = [gait_cycle_sections]
|
| 219 |
+
else:
|
| 220 |
+
sections = gait_cycle_sections
|
| 221 |
+
if len(sections) > 1:
|
| 222 |
+
self._draw_phase_window_percent(ax, sections, scale_factor)
|
| 223 |
+
|
| 224 |
+
ax.axis('off')
|
| 225 |
+
|
| 226 |
+
# Only apply figure-level settings if we created the figure
|
| 227 |
+
if created_new_figure:
|
| 228 |
+
fig.patch.set_alpha(0.0) # Make figure background transparent
|
| 229 |
+
fig.tight_layout()
|
| 230 |
+
return fig
|
| 231 |
+
return None
|
| 232 |
+
|
| 233 |
+
def _draw_phase_window_percent(self, ax, gait_cycle_sections: List[int], scale_factor: float = 1.0):
|
| 234 |
+
"""Draw the phase window percent as a number under the legs"""
|
| 235 |
+
# The gait cycle is 150 points (0-149)
|
| 236 |
+
n_sections = len(gait_cycle_sections)
|
| 237 |
+
if n_sections > 1:
|
| 238 |
+
percent = n_sections / 150 * 100
|
| 239 |
+
# Place the text under the legs, centered
|
| 240 |
+
y_position = -0.04 # Slightly below the image
|
| 241 |
+
fontsize = 32 * scale_factor
|
| 242 |
+
ax.text(0.5, y_position, f"Phase window: {percent:.1f}%", ha='center', va='top', fontsize=fontsize, color='black')
|
| 243 |
+
|
| 244 |
+
def _determine_highlight_color(self, elements: List[str], location: str) -> str:
|
| 245 |
+
"""
|
| 246 |
+
Determine highlight color based on the types of sensors present for a specific location
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
elements: List of sensor names to analyze
|
| 250 |
+
location: The joint or segment name to check (e.g., 'hip', 'thigh')
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
str: Color name from color_map
|
| 254 |
+
"""
|
| 255 |
+
if not elements:
|
| 256 |
+
return 'gray'
|
| 257 |
+
|
| 258 |
+
# Filter elements to only those containing this location
|
| 259 |
+
location_elements = [e for e in elements if location in e]
|
| 260 |
+
if not location_elements:
|
| 261 |
+
return 'gray'
|
| 262 |
+
|
| 263 |
+
has_torque = any('torque' in e for e in location_elements)
|
| 264 |
+
has_angle = any('angle' in e and 'vel' not in e for e in location_elements)
|
| 265 |
+
has_velocity = any('vel' in e for e in location_elements)
|
| 266 |
+
|
| 267 |
+
if has_torque:
|
| 268 |
+
return self.color_map['torque']
|
| 269 |
+
elif has_angle and has_velocity:
|
| 270 |
+
return self.color_map['angle_velocity']
|
| 271 |
+
elif has_velocity:
|
| 272 |
+
return self.color_map['velocity']
|
| 273 |
+
elif has_angle:
|
| 274 |
+
return self.color_map['angle']
|
| 275 |
+
|
| 276 |
+
return 'gray'
|
| 277 |
+
|
| 278 |
+
def _draw_leg(self, ax, joints: dict, segments: List[Segment],
|
| 279 |
+
highlighted_elements: List[str],
|
| 280 |
+
scale_factor: float = 1.0):
|
| 281 |
+
"""Helper method to draw a single leg with highlights"""
|
| 282 |
+
suffix = '_r' if 'hip_r' in joints else '_l'
|
| 283 |
+
|
| 284 |
+
# Get all highlighted elements for this leg
|
| 285 |
+
leg_elements = [e for e in (highlighted_elements or [])
|
| 286 |
+
if e.endswith(suffix)]
|
| 287 |
+
|
| 288 |
+
# Calculate figure-relative sizes
|
| 289 |
+
fig_width, fig_height = ax.figure.get_size_inches()
|
| 290 |
+
# Scale marker size. Make it be a 12 marker size on a 3x6 figure
|
| 291 |
+
joint_size = int(np.sqrt(fig_width * fig_height) / 18 * 12 * 3.1 * scale_factor)
|
| 292 |
+
# Scale line width. Make it be a 3 line width on a 3x6 figure
|
| 293 |
+
segment_width = int(np.sqrt(fig_width * fig_height) / 18 * 3 * 3.1 * scale_factor)
|
| 294 |
+
|
| 295 |
+
# joint_size = 12
|
| 296 |
+
# segment_width = 3
|
| 297 |
+
|
| 298 |
+
# Draw segments
|
| 299 |
+
for segment in segments:
|
| 300 |
+
# Get the correct joint names with suffixes
|
| 301 |
+
joint1_name = self._get_joint_name(segment.joint1, joints)
|
| 302 |
+
joint2_name = self._get_joint_name(segment.joint2, joints)
|
| 303 |
+
|
| 304 |
+
joint1 = joints[joint1_name]
|
| 305 |
+
joint2 = joints[joint2_name]
|
| 306 |
+
x = [joint1.x, joint2.x]
|
| 307 |
+
y = [joint1.y, joint2.y]
|
| 308 |
+
|
| 309 |
+
# Get the segment name with suffix and check if it's highlighted
|
| 310 |
+
base_segment = segment.name # 'thigh', 'shank', or 'foot'
|
| 311 |
+
color = self._determine_highlight_color(leg_elements, base_segment)
|
| 312 |
+
ax.plot(x, y, '-', color=color, linewidth=segment_width, alpha=0.8)
|
| 313 |
+
|
| 314 |
+
# Draw joints (excluding toe)
|
| 315 |
+
for joint_name, joint in joints.items():
|
| 316 |
+
if not joint_name.startswith('toe'):
|
| 317 |
+
# Check if any sensor for this joint is highlighted
|
| 318 |
+
base_joint = joint_name.split('_')[0] # 'hip' from 'hip_r'
|
| 319 |
+
color = self._determine_highlight_color(leg_elements, base_joint)
|
| 320 |
+
ax.plot(joint.x, joint.y, 'o',
|
| 321 |
+
color=color if color != 'gray' else 'black',
|
| 322 |
+
markersize=joint_size)
|
| 323 |
+
|
| 324 |
+
# Draw CoP and vGRF only if mode doesn't have line under
|
| 325 |
+
if self.background_image_path.endswith('no_line_under.png'):
|
| 326 |
+
cop_name = f"cop_y{suffix}"
|
| 327 |
+
cop_h = highlighted_elements and cop_name in highlighted_elements
|
| 328 |
+
|
| 329 |
+
grf_name = f"grf_y{suffix}"
|
| 330 |
+
grf_h = highlighted_elements and grf_name in highlighted_elements
|
| 331 |
+
|
| 332 |
+
self._draw_cop_and_vgrf(
|
| 333 |
+
ax, joints, cop_highlighted=cop_h, grf_highlighted=grf_h
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
def _draw_cop_and_vgrf(self, ax, joints,
|
| 337 |
+
cop_highlighted: bool, grf_highlighted: bool):
|
| 338 |
+
"""
|
| 339 |
+
Draw the center of pressure line and
|
| 340 |
+
vertical ground reaction force
|
| 341 |
+
"""
|
| 342 |
+
ankle_name = self._get_joint_name('ankle', joints)
|
| 343 |
+
toe_name = self._get_joint_name('toe', joints)
|
| 344 |
+
ankle = joints[ankle_name]
|
| 345 |
+
toe = joints[toe_name]
|
| 346 |
+
|
| 347 |
+
# Calculate CoP line parallel to foot
|
| 348 |
+
dx = toe.x - ankle.x
|
| 349 |
+
dy = toe.y - ankle.y
|
| 350 |
+
length = np.sqrt(dx**2 + dy**2)
|
| 351 |
+
|
| 352 |
+
dx_norm = dx / length
|
| 353 |
+
dy_norm = dy / length
|
| 354 |
+
|
| 355 |
+
# Calculate perpendicular offset for CoP line
|
| 356 |
+
offset_x = -dy_norm * self.cop_offset
|
| 357 |
+
offset_y = dx_norm * self.cop_offset
|
| 358 |
+
|
| 359 |
+
# Draw CoP line
|
| 360 |
+
cop_x = [ankle.x + offset_x, toe.x + offset_x]
|
| 361 |
+
cop_y = [ankle.y + offset_y, toe.y + offset_y]
|
| 362 |
+
cop_color = 'red' if cop_highlighted else 'gray'
|
| 363 |
+
ax.plot(cop_x, cop_y, '--', color=cop_color, linewidth=2, alpha=0.8)
|
| 364 |
+
|
| 365 |
+
# Draw vGRF arrow
|
| 366 |
+
arrow_x = toe.x + self.vgrf_arrow_offset
|
| 367 |
+
arrow_base_y = toe.y - 0.02
|
| 368 |
+
grf_color = 'red' if grf_highlighted else 'gray'
|
| 369 |
+
ax.arrow(arrow_x, arrow_base_y,
|
| 370 |
+
0, self.vgrf_arrow_height,
|
| 371 |
+
head_width=0.02,
|
| 372 |
+
head_length=0.04,
|
| 373 |
+
fc=grf_color, ec=grf_color,
|
| 374 |
+
alpha=0.8)
|
| 375 |
+
|
| 376 |
+
def create_sensor_config_illustrations() -> None:
|
| 377 |
+
"""Create and display multiple sensor configurations with different background modes"""
|
| 378 |
+
sensor_configs = [
|
| 379 |
+
{
|
| 380 |
+
'name': 'Hip Angle Only (Level Walking)',
|
| 381 |
+
'highlighted_elements': ['hip_angle_s_r', 'grf_y_r'],
|
| 382 |
+
'gait_cycle_sections': [0],
|
| 383 |
+
'mode': 'level_walking'
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
'name': 'Thigh Angle Only (Level Walking with Line Under)',
|
| 387 |
+
'highlighted_elements': ['thigh_angle_s_r'],
|
| 388 |
+
'gait_cycle_sections': [25, 26, 27],
|
| 389 |
+
'mode': 'level_walking_line_under'
|
| 390 |
+
},
|
| 391 |
+
{
|
| 392 |
+
'name': 'Hip Joint + Thigh Segment (Level Walking)',
|
| 393 |
+
'highlighted_elements': ['hip_angle_s_r', 'thigh_vel_s_r', 'grf_y_r'],
|
| 394 |
+
'gait_cycle_sections': [50, 51, 52],
|
| 395 |
+
'mode': 'level_walking'
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
'name': 'Multi Sensor (Decline Walking)',
|
| 399 |
+
'highlighted_elements': [
|
| 400 |
+
'hip_angle_s_r', # Joint sensor
|
| 401 |
+
'thigh_vel_s_r', # Segment sensor
|
| 402 |
+
'knee_torque_s_r', # Joint sensor
|
| 403 |
+
'shank_angle_s_r', # Segment sensor
|
| 404 |
+
'hip_angle_s_l',
|
| 405 |
+
'hip_vel_l',
|
| 406 |
+
'cop_r',
|
| 407 |
+
'grf_y_r'
|
| 408 |
+
],
|
| 409 |
+
'gait_cycle_sections': [74, 75, 76],
|
| 410 |
+
'mode': 'decline_walking'
|
| 411 |
+
}
|
| 412 |
+
]
|
| 413 |
+
|
| 414 |
+
for config in sensor_configs:
|
| 415 |
+
# Instantiate a new LegIllustration for each sensor config with the specified mode
|
| 416 |
+
illustrator = LegIllustration(mode=config.get('mode', 'level_walking'))
|
| 417 |
+
fig = illustrator.draw_illustration(
|
| 418 |
+
highlighted_elements=config['highlighted_elements'],
|
| 419 |
+
title=config['name'],
|
| 420 |
+
gait_cycle_sections=config.get('gait_cycle_sections')
|
| 421 |
+
)
|
| 422 |
+
plt.show()
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
|
| 425 |
+
create_sensor_config_illustrations()
|
| 426 |
+
# %%
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|