Spaces:
Running
Running
Colin Leong
committed on
Commit
·
84dfc7c
1
Parent(s):
869eec5
Add YouTube-ASL filtering, and ability to download points_dict and components list
Browse files
app.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
| 3 |
import numpy as np
|
| 4 |
from pose_format import Pose
|
|
|
|
| 5 |
from pose_format.pose_visualizer import PoseVisualizer
|
| 6 |
-
from pathlib import Path
|
| 7 |
from pyzstd import decompress
|
| 8 |
from PIL import Image
|
| 9 |
import mediapipe as mp
|
|
@@ -15,39 +19,47 @@ FACEMESH_CONTOURS_POINTS = [
|
|
| 15 |
set([p for p_tup in list(mp_holistic.FACEMESH_CONTOURS) for p in p_tup])
|
| 16 |
)
|
| 17 |
]
|
|
|
|
| 18 |
|
| 19 |
-
def
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
p2=("POSE_LANDMARKS", "LEFT_SHOULDER"),
|
| 24 |
-
)
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
p1=("pose_keypoints_2d", "RShoulder"), p2=("pose_keypoints_2d", "LShoulder")
|
| 34 |
-
)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
# @st.cache_data(hash_funcs={UploadedFile: lambda p: str(p.name)})
|
|
@@ -60,7 +72,7 @@ def load_pose(uploaded_file: UploadedFile) -> Pose:
|
|
| 60 |
return Pose.read(uploaded_file.read())
|
| 61 |
|
| 62 |
|
| 63 |
-
@st.cache_data(hash_funcs={Pose: lambda p: np.
|
| 64 |
def get_pose_frames(pose: Pose, transparency: bool = False):
|
| 65 |
v = PoseVisualizer(pose)
|
| 66 |
frames = [frame_data for frame_data in v.draw()]
|
|
@@ -73,7 +85,13 @@ def get_pose_frames(pose: Pose, transparency: bool = False):
|
|
| 73 |
return frames, images
|
| 74 |
|
| 75 |
|
| 76 |
-
def get_pose_gif(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
if fps is not None:
|
| 78 |
pose.body.fps = fps
|
| 79 |
v = PoseVisualizer(pose)
|
|
@@ -89,37 +107,42 @@ st.write(
|
|
| 89 |
st.write(
|
| 90 |
"I made this app to help me visualize and understand the format, including different 'components' and 'points', and what they are named."
|
| 91 |
)
|
| 92 |
-
st.write(
|
|
|
|
|
|
|
| 93 |
uploaded_file = st.file_uploader("Upload a .pose file", type=[".pose", ".pose.zst"])
|
| 94 |
|
| 95 |
|
| 96 |
if uploaded_file is not None:
|
| 97 |
with st.spinner(f"Loading {uploaded_file.name}"):
|
| 98 |
pose = load_pose(uploaded_file)
|
|
|
|
| 99 |
frames, images = get_pose_frames(pose=pose)
|
| 100 |
st.success("Done loading!")
|
| 101 |
-
|
| 102 |
st.write("### File Info")
|
| 103 |
with st.expander(f"Show full Pose-format header from {uploaded_file.name}"):
|
| 104 |
st.write(pose.header)
|
| 105 |
|
| 106 |
st.write(f"### Selection")
|
| 107 |
component_selection = st.radio(
|
| 108 |
-
"How to select components?", options=
|
| 109 |
)
|
| 110 |
|
| 111 |
component_names = [c.name for c in pose.header.components]
|
| 112 |
chosen_component_names = []
|
| 113 |
points_dict = {}
|
| 114 |
-
|
| 115 |
|
| 116 |
if component_selection == "manual":
|
| 117 |
-
|
| 118 |
|
| 119 |
chosen_component_names = st.pills(
|
| 120 |
-
"Select components to visualize",
|
|
|
|
|
|
|
|
|
|
| 121 |
)
|
| 122 |
-
|
| 123 |
for component in pose.header.components:
|
| 124 |
if component.name in chosen_component_names:
|
| 125 |
with st.expander(f"Points for {component.name}"):
|
|
@@ -128,32 +151,118 @@ if uploaded_file is not None:
|
|
| 128 |
options=component.points,
|
| 129 |
default=component.points,
|
| 130 |
)
|
| 131 |
-
if
|
|
|
|
|
|
|
| 132 |
points_dict[component.name] = selected_points
|
| 133 |
-
|
| 134 |
-
|
| 135 |
|
| 136 |
elif component_selection == "signclip":
|
| 137 |
st.write("Selected landmarks used for SignCLIP.")
|
| 138 |
-
chosen_component_names = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}
|
| 140 |
-
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
# Filter button logic
|
| 144 |
-
|
| 145 |
st.write("### Filter .pose File")
|
| 146 |
filtered = st.button("Apply Filter!")
|
| 147 |
if filtered:
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
st.session_state.filtered_pose = pose
|
| 152 |
|
| 153 |
-
filtered_pose = st.session_state.get(
|
| 154 |
if filtered_pose:
|
| 155 |
-
filtered_pose = st.session_state.get(
|
| 156 |
-
st.write(
|
| 157 |
st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
|
| 158 |
with st.expander("Show header"):
|
| 159 |
st.write(filtered_pose.header)
|
|
@@ -170,12 +279,20 @@ if uploaded_file is not None:
|
|
| 170 |
pose.write(f)
|
| 171 |
|
| 172 |
with pose_file_out.open("rb") as f:
|
| 173 |
-
st.download_button(
|
|
|
|
|
|
|
| 174 |
|
| 175 |
-
|
| 176 |
st.write("### Visualization")
|
| 177 |
-
step = st.select_slider(
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
start_frame, end_frame = st.slider(
|
| 180 |
"Select Frame Range",
|
| 181 |
0,
|
|
@@ -185,6 +302,13 @@ if uploaded_file is not None:
|
|
| 185 |
# Visualization button logic
|
| 186 |
if st.button("Visualize"):
|
| 187 |
# Load filtered pose if it exists; otherwise, use the unfiltered pose
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict, Optional, List, Tuple
|
| 4 |
+
from collections import defaultdict
|
| 5 |
import streamlit as st
|
| 6 |
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
| 7 |
import numpy as np
|
| 8 |
from pose_format import Pose
|
| 9 |
+
from pose_format.utils.generic import pose_hide_legs, reduce_holistic
|
| 10 |
from pose_format.pose_visualizer import PoseVisualizer
|
|
|
|
| 11 |
from pyzstd import decompress
|
| 12 |
from PIL import Image
|
| 13 |
import mediapipe as mp
|
|
|
|
| 19 |
set([p for p_tup in list(mp_holistic.FACEMESH_CONTOURS) for p in p_tup])
|
| 20 |
)
|
| 21 |
]
|
| 22 |
+
COMPONENT_SELECTION_METHODS = ["manual", "signclip", "youtube-asl", "reduce_holistic"]
|
| 23 |
|
| 24 |
+
def download_json(data):
    """Serialize *data* to JSON and return it as UTF-8 bytes suitable for st.download_button."""
    serialized = json.dumps(data)
    return serialized.encode('utf-8')
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
def get_points_dict_and_components_with_index_list(
    pose: "Pose", landmark_indices: List[int], components_to_include: Optional[List[str]]
) -> Tuple[List[str], Dict[str, List[str]]]:
    """Resolve flat global landmark indices to component names and per-component point names.

    Useful when landmarks are specified only by index, e.g. as listed in a
    research paper like YouTube-ASL. Whole components can additionally be
    kept via ``components_to_include``.

    For example, to get the two hands and the nose:

        c_names, points_dict = get_points_dict_and_components_with_index_list(
            pose,
            landmark_indices=[0],  # which is "NOSE" within POSE_LANDMARKS
            components_to_include=["LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"],
        )

    then you can just use get_components:

        filtered_pose = pose.get_components(c_names, points_dict)

    Args:
        pose: Pose whose header defines the components and point indexing.
        landmark_indices: global point indices (across all components) to keep.
        components_to_include: optional component names to include in full.

    Returns:
        Tuple of (component names, mapping of component name -> selected point
        names), in the shape expected by ``Pose.get_components``.
    """
    components_to_get: List[str] = []
    points_dict: Dict[str, List[str]] = defaultdict(list)

    for component in pose.header.components:
        for point_name in component.points:
            point_index = pose.header.get_point_index(component.name, point_name)
            if point_index in landmark_indices:
                components_to_get.append(component.name)
                points_dict[component.name].append(point_name)

    if components_to_include:
        components_to_get.extend(components_to_include)

    # De-duplicate while preserving first-seen order. The previous
    # list(set(...)) made the returned component order nondeterministic
    # across runs (hash randomization), which made downloads/usage strings
    # unstable for identical inputs.
    components_to_get = list(dict.fromkeys(components_to_get))
    return components_to_get, points_dict
|
| 63 |
|
| 64 |
|
| 65 |
# @st.cache_data(hash_funcs={UploadedFile: lambda p: str(p.name)})
|
|
|
|
| 72 |
return Pose.read(uploaded_file.read())
|
| 73 |
|
| 74 |
|
| 75 |
+
@st.cache_data(hash_funcs={Pose: lambda p: np.asarray(p.body.data.data)})
|
| 76 |
def get_pose_frames(pose: Pose, transparency: bool = False):
|
| 77 |
v = PoseVisualizer(pose)
|
| 78 |
frames = [frame_data for frame_data in v.draw()]
|
|
|
|
| 85 |
return frames, images
|
| 86 |
|
| 87 |
|
| 88 |
+
def get_pose_gif(
|
| 89 |
+
pose: Pose,
|
| 90 |
+
step: int = 1,
|
| 91 |
+
start_frame: Optional[int] = None,
|
| 92 |
+
end_frame: Optional[int] = None,
|
| 93 |
+
fps: Optional[float] = None,
|
| 94 |
+
):
|
| 95 |
if fps is not None:
|
| 96 |
pose.body.fps = fps
|
| 97 |
v = PoseVisualizer(pose)
|
|
|
|
| 107 |
st.write(
|
| 108 |
"I made this app to help me visualize and understand the format, including different 'components' and 'points', and what they are named."
|
| 109 |
)
|
| 110 |
+
st.write(
|
| 111 |
+
"If you need a .pose file, here's one of [me doing a self-introduction](https://drive.google.com/file/d/1_L5sYVhONDBABuTmQUvjsl94LbFqzEyP/view?usp=sharing), and one of [me signing ASL 'HOUSE'](https://drive.google.com/file/d/1uggYqLyTA4XdDWaWsS9w5hKaPwW86IF_/view?usp=sharing)"
|
| 112 |
+
)
|
| 113 |
uploaded_file = st.file_uploader("Upload a .pose file", type=[".pose", ".pose.zst"])
|
| 114 |
|
| 115 |
|
| 116 |
if uploaded_file is not None:
|
| 117 |
with st.spinner(f"Loading {uploaded_file.name}"):
|
| 118 |
pose = load_pose(uploaded_file)
|
| 119 |
+
# st.write(pose.body.data.shape)
|
| 120 |
frames, images = get_pose_frames(pose=pose)
|
| 121 |
st.success("Done loading!")
|
| 122 |
+
|
| 123 |
st.write("### File Info")
|
| 124 |
with st.expander(f"Show full Pose-format header from {uploaded_file.name}"):
|
| 125 |
st.write(pose.header)
|
| 126 |
|
| 127 |
st.write(f"### Selection")
|
| 128 |
component_selection = st.radio(
|
| 129 |
+
"How to select components?", options=COMPONENT_SELECTION_METHODS
|
| 130 |
)
|
| 131 |
|
| 132 |
component_names = [c.name for c in pose.header.components]
|
| 133 |
chosen_component_names = []
|
| 134 |
points_dict = {}
|
| 135 |
+
HIDE_LEGS = False
|
| 136 |
|
| 137 |
if component_selection == "manual":
|
|
|
|
| 138 |
|
| 139 |
chosen_component_names = st.pills(
|
| 140 |
+
"Select components to visualize",
|
| 141 |
+
options=component_names,
|
| 142 |
+
default=component_names,
|
| 143 |
+
selection_mode="multi",
|
| 144 |
)
|
| 145 |
+
|
| 146 |
for component in pose.header.components:
|
| 147 |
if component.name in chosen_component_names:
|
| 148 |
with st.expander(f"Points for {component.name}"):
|
|
|
|
| 151 |
options=component.points,
|
| 152 |
default=component.points,
|
| 153 |
)
|
| 154 |
+
if (
|
| 155 |
+
selected_points != component.points
|
| 156 |
+
): # Only add entry if not all points are selected
|
| 157 |
points_dict[component.name] = selected_points
|
|
|
|
|
|
|
| 158 |
|
| 159 |
elif component_selection == "signclip":
|
| 160 |
st.write("Selected landmarks used for SignCLIP.")
|
| 161 |
+
chosen_component_names = [
|
| 162 |
+
"POSE_LANDMARKS",
|
| 163 |
+
"FACE_LANDMARKS",
|
| 164 |
+
"LEFT_HAND_LANDMARKS",
|
| 165 |
+
"RIGHT_HAND_LANDMARKS",
|
| 166 |
+
]
|
| 167 |
points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}
|
|
|
|
| 168 |
|
| 169 |
+
elif component_selection == "youtube-asl":
|
| 170 |
+
st.write("Selected landmarks used for YouTube-ASL.")
|
| 171 |
+
# https://arxiv.org/pdf/2306.15162
|
| 172 |
+
# For each hand, we use all 21 landmark points.
|
| 173 |
+
# Colin: So that's
|
| 174 |
+
# For the pose, we use 6 landmark points, for the shoulders, elbows and hips
|
| 175 |
+
# These are indices 11, 12, 13, 14, 23, 24
|
| 176 |
+
# For the face, we use 37 landmark points, from the eyes, eyebrows, lips, and face outline.
|
| 177 |
+
# These are indices 0, 4, 13, 14, 17, 33, 37, 39, 46, 52, 55, 61, 64, 81, 82, 93, 133, 151, 152, 159, 172, 178,
|
| 178 |
+
# 181, 263, 269, 276, 282, 285, 291, 294, 311, 323, 362, 386, 397, 468, 473.
|
| 179 |
+
# Colin: note that these are with refine_face_landmarks on, and are relative to the component itself. Working it all out the result is:
|
| 180 |
+
components=['POSE_LANDMARKS', 'FACE_LANDMARKS', 'LEFT_HAND_LANDMARKS', 'RIGHT_HAND_LANDMARKS']
|
| 181 |
+
points_dict={
|
| 182 |
+
"POSE_LANDMARKS": [
|
| 183 |
+
"LEFT_SHOULDER",
|
| 184 |
+
"RIGHT_SHOULDER",
|
| 185 |
+
"LEFT_HIP",
|
| 186 |
+
"RIGHT_HIP",
|
| 187 |
+
"LEFT_ELBOW",
|
| 188 |
+
"RIGHT_ELBOW"
|
| 189 |
+
],
|
| 190 |
+
"FACE_LANDMARKS": [
|
| 191 |
+
"0",
|
| 192 |
+
"4",
|
| 193 |
+
"13",
|
| 194 |
+
"14",
|
| 195 |
+
"17",
|
| 196 |
+
"33",
|
| 197 |
+
"37",
|
| 198 |
+
"39",
|
| 199 |
+
"46",
|
| 200 |
+
"52",
|
| 201 |
+
"55",
|
| 202 |
+
"61",
|
| 203 |
+
"64",
|
| 204 |
+
"81",
|
| 205 |
+
"82",
|
| 206 |
+
"93",
|
| 207 |
+
"133",
|
| 208 |
+
"151",
|
| 209 |
+
"152",
|
| 210 |
+
"159",
|
| 211 |
+
"172",
|
| 212 |
+
"178",
|
| 213 |
+
"181",
|
| 214 |
+
"263",
|
| 215 |
+
"269",
|
| 216 |
+
"276",
|
| 217 |
+
"282",
|
| 218 |
+
"285",
|
| 219 |
+
"291",
|
| 220 |
+
"294",
|
| 221 |
+
"311",
|
| 222 |
+
"323",
|
| 223 |
+
"362",
|
| 224 |
+
"386",
|
| 225 |
+
"397",
|
| 226 |
+
"468", # 468 only exists with the refine_face_landmarks option on MediaPipe
|
| 227 |
+
"473", # 473 only exists with the refine_face_landmarks option on MediaPipe
|
| 228 |
+
]
|
| 229 |
+
}
|
| 230 |
|
| 231 |
# Filter button logic
|
| 232 |
+
# Filter section
|
| 233 |
st.write("### Filter .pose File")
|
| 234 |
filtered = st.button("Apply Filter!")
|
| 235 |
if filtered:
|
| 236 |
+
st.write(f"Filtering strategy: {component_selection}")
|
| 237 |
+
|
| 238 |
+
if component_selection == "reduce_holistic":
|
| 239 |
+
# st.write(f"reduce_holistic:")
|
| 240 |
+
pose = reduce_holistic(pose)
|
| 241 |
+
st.write("Used pose_format.reduce_holistic")
|
| 242 |
+
else:
|
| 243 |
+
pose = pose.get_components(components=chosen_component_names, points=points_dict if points_dict else None
|
| 244 |
+
)
|
| 245 |
+
with st.expander("Show component list and points dict used for get_components"):
|
| 246 |
+
st.write("##### Component names")
|
| 247 |
+
st.write(chosen_component_names)
|
| 248 |
+
st.write("##### Points dict")
|
| 249 |
+
st.write(points_dict)
|
| 250 |
+
|
| 251 |
+
with st.expander("How to replicate in pose-format"):
|
| 252 |
+
st.write("##### Usage:")
|
| 253 |
+
st.write("How to achieve the same result with pose-format library")
|
| 254 |
+
# points_dict_str = json.dumps(points_dict, indent=4)
|
| 255 |
+
usage_string = f"components={chosen_component_names}\npoints_dict={points_dict}\npose = pose.get_components(components=components, points=points_dict)"
|
| 256 |
+
st.code(usage_string)
|
| 257 |
+
|
| 258 |
+
if HIDE_LEGS:
|
| 259 |
+
pose = pose_hide_legs(pose, remove=True)
|
| 260 |
st.session_state.filtered_pose = pose
|
| 261 |
|
| 262 |
+
filtered_pose = st.session_state.get("filtered_pose", pose)
|
| 263 |
if filtered_pose:
|
| 264 |
+
filtered_pose = st.session_state.get("filtered_pose", pose)
|
| 265 |
+
st.write("#### Filtered .pose file")
|
| 266 |
st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
|
| 267 |
with st.expander("Show header"):
|
| 268 |
st.write(filtered_pose.header)
|
|
|
|
| 279 |
pose.write(f)
|
| 280 |
|
| 281 |
with pose_file_out.open("rb") as f:
|
| 282 |
+
st.download_button(
|
| 283 |
+
"Download Filtered Pose", f, file_name=pose_file_out.name
|
| 284 |
+
)
|
| 285 |
|
|
|
|
| 286 |
st.write("### Visualization")
|
| 287 |
+
step = st.select_slider(
|
| 288 |
+
"Step value to select every nth image", list(range(1, len(frames))), value=1
|
| 289 |
+
)
|
| 290 |
+
fps = st.slider(
|
| 291 |
+
"FPS for visualization",
|
| 292 |
+
min_value=1.0,
|
| 293 |
+
max_value=filtered_pose.body.fps,
|
| 294 |
+
value=filtered_pose.body.fps,
|
| 295 |
+
)
|
| 296 |
start_frame, end_frame = st.slider(
|
| 297 |
"Select Frame Range",
|
| 298 |
0,
|
|
|
|
| 302 |
# Visualization button logic
|
| 303 |
if st.button("Visualize"):
|
| 304 |
# Load filtered pose if it exists; otherwise, use the unfiltered pose
|
| 305 |
+
|
| 306 |
+
pose_bytes = get_pose_gif(
|
| 307 |
+
pose=filtered_pose,
|
| 308 |
+
step=step,
|
| 309 |
+
start_frame=start_frame,
|
| 310 |
+
end_frame=end_frame,
|
| 311 |
+
fps=fps,
|
| 312 |
+
)
|
| 313 |
+
if pose_bytes is not None:
|
| 314 |
+
st.image(pose_bytes)
|