3Dvisualizer / visualization.py
Mouaador's picture
Create visualization.py
e0d5801 verified
Raw
History Blame Contribute Delete
4.33 kB
import math

import plotly.graph_objects as go
# Plot colour for each tracked landmark. The insertion order of this mapping
# is also the order in which the legend entries are added to the figure.
LANDMARK_COLORS = {
    "SHOULDER": "red",
    "ELBOW": "green",
    "WRIST": "blue",
    "INDEX": "brown",
    "THUMB": "magenta",
}
# Pairs of landmark names that are joined by a straight segment when the
# segments are drawn for the current frame.
SEGMENT_CONNECTIONS = [
    ("SHOULDER", "ELBOW"),
    ("ELBOW", "WRIST"),
    ("WRIST", "INDEX"),
    ("WRIST", "THUMB"),
]
def _landmark_point(df, landmark, frame):
    """Return the (x, y, z) position of ``landmark`` at row ``frame``, or None.

    None is returned when ``frame`` is outside the rows of ``df`` (``.iloc``
    raises IndexError), so callers can simply skip that landmark.
    """
    try:
        return tuple(
            df[f"{landmark}_{axis}_mean"].iloc[frame] for axis in ("X", "Y", "Z")
        )
    except IndexError:
        return None


def _axis_range(coords):
    """Return one ``[low, high]`` range, with 10 % padding, covering ``coords``.

    Non-finite values (NaN, +/-inf) are ignored because ``min``/``max`` give
    order-dependent results when NaN is present. With no usable value the
    range is ``[-1, 1]``.
    """
    finite = [value for value in coords if math.isfinite(value)]
    if not finite:
        return [-1, 1]
    low, high = min(finite), max(finite)
    span = high - low
    # A zero span would collapse the axis to a single point; fall back to a
    # half-width of 1, the same extent as the empty-data default range.
    padding = span * 0.1 if span > 0 else 1.0
    return [low - padding, high + padding]


def create_3d_visualization(df, current_frame, show_trajectories=True, show_segments=True, landmark_visibility=None):
    """Build a 3D Plotly figure of the landmarks at ``current_frame``.

    Args:
        df: Table indexed by column name and by ``.iloc`` position (a pandas
            DataFrame, judging by that access pattern). For every visible
            landmark it must contain the columns ``<LANDMARK>_X_mean``,
            ``<LANDMARK>_Y_mean`` and ``<LANDMARK>_Z_mean``.
        current_frame: Positional row index of the frame to highlight. It is
            passed straight to ``.iloc``; a frame outside the table draws no
            markers or segments instead of raising.
        show_trajectories: When true, draw each visible landmark's full path
            over all rows as a semi-transparent line.
        show_segments: When true, draw black lines between the landmark pairs
            listed in ``SEGMENT_CONNECTIONS`` whose two ends are both visible.
        landmark_visibility: Optional mapping of landmark name to bool.
            Landmarks missing from the mapping, or all of them when it is
            None, are treated as visible.

    Returns:
        The ``plotly.graph_objects.Figure`` holding the legend entries,
        trajectories, current-frame markers and segments, with all three axes
        sharing one padded range.
    """
    fig = go.Figure()
    if landmark_visibility is None:
        landmark_visibility = {}

    def is_visible(landmark):
        return landmark_visibility.get(landmark, True)

    visible_landmarks = [lm for lm in LANDMARK_COLORS if is_visible(lm)]
    all_coords = []

    # Legend: one empty trace per landmark so that every colour is listed,
    # regardless of which landmarks are currently visible.
    for lm, color in LANDMARK_COLORS.items():
        fig.add_trace(go.Scatter3d(
            x=[None], y=[None], z=[None],
            mode='markers',
            marker=dict(color=color, size=8),
            name=lm,
            showlegend=True
        ))

    if show_trajectories:
        for lm in visible_landmarks:
            x = df[f"{lm}_X_mean"]
            y = df[f"{lm}_Y_mean"]
            z = df[f"{lm}_Z_mean"]
            fig.add_trace(go.Scatter3d(
                x=x, y=y, z=z,
                mode='lines',
                line=dict(color=LANDMARK_COLORS[lm], width=4),
                opacity=0.3,
                showlegend=False
            ))
            all_coords.extend(x)
            all_coords.extend(y)
            all_coords.extend(z)

    # Marker for each visible landmark at the current frame.
    for lm in visible_landmarks:
        point = _landmark_point(df, lm, current_frame)
        if point is None:
            continue
        x, y, z = point
        fig.add_trace(go.Scatter3d(
            x=[x], y=[y], z=[z],
            mode='markers',
            marker=dict(color=LANDMARK_COLORS[lm], size=8),
            showlegend=False
        ))
        all_coords.extend(point)

    if show_segments:
        for lm1, lm2 in SEGMENT_CONNECTIONS:
            if not (is_visible(lm1) and is_visible(lm2)):
                continue
            start = _landmark_point(df, lm1, current_frame)
            end = _landmark_point(df, lm2, current_frame)
            if start is None or end is None:
                continue
            fig.add_trace(go.Scatter3d(
                x=[start[0], end[0]],
                y=[start[1], end[1]],
                z=[start[2], end[2]],
                mode='lines',
                line=dict(color='black', width=4),
                showlegend=False
            ))

    # All three axes share one range so the cube aspect keeps true proportions.
    low, high = _axis_range(all_coords)
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=[low, high]),
            yaxis=dict(range=[low, high]),
            zaxis=dict(range=[low, high]),
            aspectmode='cube'
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        height=800,
        width=1200,
        legend=dict(
            x=0.05,
            y=0.95,
            bgcolor='rgba(0,0,0,0.5)',
            font=dict(color='white')
        )
    )
    return fig