|
|
import json |
|
|
import html |
|
|
from pathlib import Path |
|
|
from typing import Dict, List |
|
|
from functools import lru_cache |
|
|
|
|
|
import gradio as gr |
|
|
import plotly.graph_objects as go |
|
|
import plotly.io as pio |
|
|
|
|
|
# Axis titles for each plotted wrist metric, keyed by metric name.
# The matching state series are looked up as "wrist_<metric>" when plotting,
# and the label is used as the y-axis title of that metric's figure.
METRIC_LABELS = {
    "x_cm": "X (cm)",
    "y_cm": "Y (cm)",
    "z_cm": "Z (cm)",
    "yaw_deg": "Yaw (°)",
    "pitch_deg": "Pitch (°)",
    "roll_deg": "Roll (°)",
}
|
|
|
|
|
# Layout of the trajectory plots: each inner list is one UI row and each
# entry is a metric name (a key of METRIC_LABELS) shown in that row.
PLOT_GRID = [
    ["x_cm", "y_cm", "z_cm"],
    ["yaw_deg", "pitch_deg", "roll_deg"],
]
|
|
|
|
|
# PLOT_GRID flattened row by row into a single list of metric names.
PLOT_ORDER = [metric for row in PLOT_GRID for metric in row]
|
|
|
|
|
# Stylesheet handed to gr.Blocks(css=...). The class names below are the ones
# attached to components through `elem_classes` (side-panel, main-panel,
# video-card, video-panel, download-button, plots-wrap, plot-html) or emitted
# in the hand-written HTML snippets (stats-card, instruction-*, video-title).
CUSTOM_CSS = """
:root, .gradio-container, body {
    background-color: #050a18 !important;
    color: #f8fafc !important;
    font-family: 'Inter', 'Segoe UI', system-ui, sans-serif;
}
.side-panel {
    background: #0f172a;
    padding: 20px;
    border-radius: 18px;
    border: 1px solid #1f2b47;
    min-height: 100%;
}
.stats-card ul {
    list-style: none;
    padding: 0;
    margin: 0;
    font-size: 0.92rem;
}
.stats-card li {
    margin-bottom: 10px;
    color: #e2e8f0;
}
.stats-card span {
    display: inline-block;
    margin-right: 6px;
    color: #7dd3fc;
}
.main-panel {
    padding-top: 8px;
}
.instruction-card {
    background: #0f172a;
    padding: 18px 20px;
    border-radius: 18px;
    border: 1px solid #1f2b47;
}
.instruction-label {
    font-size: 0.75rem;
    letter-spacing: 0.12em;
    text-transform: uppercase;
    color: #94a3b8;
    margin-bottom: 10px;
}
.instruction-text {
    font-size: 1.1rem;
    line-height: 1.5;
}
.video-card {
    background: #0f172a;
    border: 1px solid #1f2b47;
    border-radius: 18px;
    padding: 18px 20px;
    margin-top: 18px;
}
.video-title {
    font-size: 0.78rem;
    text-transform: uppercase;
    letter-spacing: 0.18em;
    color: #94a3b8;
    margin-bottom: 8px;
}
.video-panel video {
    border-radius: 12px;
    border: 1px solid #1f2b47;
    background: #030712;
}
.download-button button {
    border-radius: 999px;
    border: 1px solid #334155;
    background: #1e293b;
    color: #f8fafc;
    font-size: 0.85rem;
    padding: 8px 24px;
}
.download-button button:hover {
    border-color: #67e8f9;
    color: #67e8f9;
}
.plots-wrap {
    margin-top: 18px;
}
.plots-wrap .gr-row {
    gap: 16px;
}
.plot-html {
    background: #111a2c;
    border-radius: 12px;
    padding: 10px;
    border: 1px solid #1f2b47;
    min-height: 320px;
}
.plot-html iframe {
    width: 100%;
    height: 300px;
    border: none;
}
"""
|
|
|
|
|
|
|
|
def get_data_dir():
    """Return the directory that holds the app's data files.

    The directory is the ``data`` folder located next to this source file.
    When ``__file__`` is not defined (for example when the code is executed
    in an interactive session), the relative path ``data`` is returned
    instead.

    Returns:
        Path: Location of the data directory. Its existence is not checked.
    """
    try:
        return Path(__file__).parent / "data"
    except NameError:
        # ``__file__`` is only missing outside a regular module/script
        # context; anything else (e.g. KeyboardInterrupt) must propagate.
        return Path("data")
|
|
|
|
|
|
|
|
def _read_json_or_empty(path):
    """Return the parsed JSON document at *path*, or ``{}`` if the file is missing."""
    if not path.exists():
        return {}
    # JSON text is UTF-8 by specification; do not rely on the locale default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


@lru_cache(maxsize=1)
def load_data():
    """Load the three JSON data files from the data directory.

    Reads ``metadata.json``, ``end_effector.json`` and ``hands_2d.json`` from
    the directory returned by ``get_data_dir()``. A file that does not exist
    yields an empty dict; a file that exists but is not valid JSON raises the
    decoder's exception.

    The result is cached, so every call returns the very same objects;
    callers should treat them as read-only.

    Returns:
        tuple: ``(metadata, end_effector, hands_2d)``.
    """
    data_dir = get_data_dir()
    metadata = _read_json_or_empty(data_dir / "metadata.json")
    end_effector = _read_json_or_empty(data_dir / "end_effector.json")
    hands_2d = _read_json_or_empty(data_dir / "hands_2d.json")
    return metadata, end_effector, hands_2d
|
|
|
|
|
|
|
|
def build_state_dataframe(metadata: dict, end_effector: dict, hand: str = "left"):
    """Build per-frame wrist state series for one hand.

    Despite the name, no DataFrame is created: the result is a list of
    timestamps plus a dict of equally long value lists.

    Args:
        metadata: Recording metadata. Only ``fps`` is read; a missing, null
            or zero value falls back to 60.
        end_effector: Mapping of frame index (digit-only keys) to a per-frame
            dict that may contain ``"<hand>_hand"`` entries, each with a
            ``pose_6dof`` sequence. Keys that are not purely digits are
            ignored.
        hand: Hand prefix, e.g. ``"left"`` or ``"right"``.

    Returns:
        tuple: ``(timestamps, state_data)`` where ``timestamps[i]`` is
        ``frame_index / fps`` for the i-th frame in ascending numeric order
        and ``state_data`` maps ``wrist_{x,y,z}_cm`` and
        ``wrist_{yaw,pitch,roll}_deg`` to lists aligned with ``timestamps``.
        Frames without a usable pose contribute ``None`` to every series.

    NOTE(review): ``pose_6dof`` is assumed to be ``[x, y, z, roll, pitch,
    yaw]`` with positions in metres and angles in radians (the previous code
    scaled by 100 and by ~57.3 respectively) — confirm against the producer.
    """
    import math  # local import keeps this block independent of file-level imports

    # `or 60` also covers an explicit null/0 fps, which would otherwise break
    # the timestamp division below.
    fps = metadata.get("fps") or 60

    try:
        frame_keys = sorted(
            int(k) for k in end_effector.keys() if str(k).isdigit()
        )
    except (AttributeError, TypeError, ValueError):
        # Not a mapping, or a key that passes isdigit() but int() rejects.
        frame_keys = []

    timestamps = []
    state_data = {
        "wrist_x_cm": [],
        "wrist_y_cm": [],
        "wrist_z_cm": [],
        "wrist_yaw_deg": [],
        "wrist_pitch_deg": [],
        "wrist_roll_deg": [],
    }

    hand_key = hand + "_hand"
    for frame_idx in frame_keys:
        timestamps.append(frame_idx / fps)

        ee_data = end_effector.get(str(frame_idx), {}) or {}
        if not isinstance(ee_data, dict):
            ee_data = {}
        hand_data = ee_data.get(hand_key)
        pose = hand_data.get("pose_6dof") if isinstance(hand_data, dict) else None

        if pose and len(pose) >= 6:
            state_data["wrist_x_cm"].append(pose[0] * 100)
            state_data["wrist_y_cm"].append(pose[1] * 100)
            state_data["wrist_z_cm"].append(pose[2] * 100)
            state_data["wrist_roll_deg"].append(math.degrees(pose[3]))
            state_data["wrist_pitch_deg"].append(math.degrees(pose[4]))
            state_data["wrist_yaw_deg"].append(math.degrees(pose[5]))
        else:
            # Keep every series aligned with `timestamps`.
            for series in state_data.values():
                series.append(None)

    return timestamps, state_data
|
|
|
|
|
|
|
|
def build_plot_fig(timestamps: List[float], state_data: Dict, metric: str) -> go.Figure:
    """Create the line chart for a single wrist metric.

    Args:
        timestamps: X values, in seconds.
        state_data: Mapping that holds the series under ``"wrist_<metric>"``.
        metric: Metric name; also the key into ``METRIC_LABELS`` for the
            y-axis title.

    Returns:
        go.Figure: A dark-themed figure with one "Wrist" line, or an empty
        figure when ``state_data`` has no series for the metric.
    """
    figure = go.Figure()
    series_key = f"wrist_{metric}"
    if series_key not in state_data:
        return figure

    wrist_trace = go.Scatter(
        x=timestamps,
        y=state_data[series_key],
        mode="lines",
        name="Wrist",
    )
    figure.add_trace(wrist_trace)

    figure.update_layout(
        margin={"l": 20, "r": 20, "t": 30, "b": 20},
        height=250,
        template="plotly_dark",
        legend={
            "orientation": "h",
            "yanchor": "bottom",
            "y": 1.02,
            "xanchor": "right",
            "x": 1,
        },
        xaxis_title="Time (s)",
        yaxis_title=METRIC_LABELS[metric],
    )

    # Both axes share the same faint grid styling.
    grid_style = {
        "showgrid": True,
        "gridwidth": 0.5,
        "gridcolor": "rgba(255,255,255,0.1)",
    }
    figure.update_xaxes(**grid_style)
    figure.update_yaxes(**grid_style)
    return figure
|
|
|
|
|
|
|
|
def build_plot_html(timestamps: List[float], state_data: Dict, metric: str) -> str:
    """Render the figure for one metric as an embeddable HTML fragment.

    The figure comes from ``build_plot_fig``; the fragment is not a full HTML
    document and references plotly.js from the CDN.
    """
    return pio.to_html(
        build_plot_fig(timestamps, state_data, metric),
        include_plotlyjs="cdn",
        full_html=False,
    )
|
|
|
|
|
|
|
|
def format_instruction_html(text: str) -> str:
    """Wrap an instruction string in the "instruction-card" HTML markup.

    The text is HTML-escaped before being embedded, so markup characters in
    it are shown literally.
    """
    pieces = [
        '<div class="instruction-card">',
        '<p class="instruction-label">Language Instruction</p>',
        f'<p class="instruction-text">{html.escape(text)}</p>',
        "</div>",
    ]
    return "".join(pieces)
|
|
|
|
|
|
|
|
def _count_hand_frames(end_effector, hand_key):
    """Count frame entries that are dicts with a truthy value under *hand_key*."""
    return sum(
        1
        for frame in end_effector.values()
        if isinstance(frame, dict) and frame.get(hand_key)
    )


def _hand_plot_html(metadata, end_effector, hand):
    """Return ``{metric: plot HTML}`` for every metric of one hand."""
    timestamps, state = build_state_dataframe(metadata, end_effector, hand)
    return {
        metric: build_plot_html(timestamps, state, metric)
        for metric in METRIC_LABELS
    }


def _render_plot_grid(title, plot_html_by_metric):
    """Add a titled PLOT_GRID-shaped grid of plots to the active Blocks context."""
    gr.Markdown(title, elem_classes=["plots-title"])
    with gr.Column(elem_classes=["plots-wrap"]):
        for row in PLOT_GRID:
            with gr.Row():
                for metric in row:
                    gr.HTML(
                        value=plot_html_by_metric[metric],
                        elem_classes=["plot-html"],
                    )


def build_interface():
    """Build the Gradio Blocks app for the demo visualizer.

    Loads the data via ``load_data()``, derives summary statistics, renders
    the per-metric trajectory plots for both hands up front, and lays out the
    page: a statistics side panel, the instruction card, the video with a
    download button, and one plot grid per hand.

    Returns:
        gr.Blocks: The assembled (not yet launched) interface.
    """
    metadata, end_effector, hands_2d = load_data()

    total_frames = len(metadata.get('poses', []))
    fps = metadata.get('fps', 60)
    # Guard against division by zero when no frames are listed.
    hand_detection_rate = (
        len(hands_2d) / total_frames * 100 if total_frames > 0 else 0
    )

    left_poses = _count_hand_frames(end_effector, 'left_hand')
    right_poses = _count_hand_frames(end_effector, 'right_hand')

    video_path = get_data_dir() / "video.mp4"
    # Only hand a path to the components when the file is really there.
    video_value = str(video_path) if video_path.exists() else None

    left_figs = _hand_plot_html(metadata, end_effector, "left")
    right_figs = _hand_plot_html(metadata, end_effector, "right")

    stats_html = f"""
    <div class="stats-card">
        <ul>
            <li><span>Number of samples/frames:</span> {total_frames:,}</li>
            <li><span>Hand detection rate:</span> {hand_detection_rate:.1f}%</li>
            <li><span>Left hand poses:</span> {left_poses}</li>
            <li><span>Right hand poses:</span> {right_poses}</li>
            <li><span>Frames per second:</span> {fps:.1f}</li>
        </ul>
    </div>
    """

    instruction_text = "LiDAR-based egocentric hand tracking for robot training data"

    theme = gr.themes.Soft(
        primary_hue="cyan", secondary_hue="blue", neutral_hue="slate"
    ).set(
        body_background_fill="#0c1424",
        body_text_color="#f8fafc",
        block_background_fill="#111a2c",
        block_title_text_color="#f8fafc",
        input_background_fill="#151f33",
        border_color_primary="#1f2b47",
        shadow_drop="none",
    )

    with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
        gr.Markdown("# 🤖 Dynamic Intelligence - Human Demo Visualizer")
        gr.Markdown(
            "Egocentric hand tracking dataset for humanoid robot training. "
            "Pipeline: iPhone LiDAR → MediaPipe → 6DoF End-Effector → Robot Training Data"
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=260, elem_classes=["side-panel"]):
                gr.HTML(stats_html)
            with gr.Column(scale=2, min_width=640, elem_classes=["main-panel"]):
                gr.HTML(
                    format_instruction_html(instruction_text),
                    label="Language Instruction",
                )
                with gr.Column(elem_classes=["video-card"]):
                    gr.HTML('<div class="video-title">RGB Video</div>')
                    gr.Video(
                        height=360,
                        value=video_value,
                        elem_classes=["video-panel"],
                        show_label=False,
                        show_download_button=False,
                    )
                    gr.DownloadButton(
                        label="Download",
                        value=video_value,
                        elem_classes=["download-button"],
                    )

                _render_plot_grid("### Left Hand Trajectories", left_figs)
                _render_plot_grid("### Right Hand Trajectories", right_figs)

    return demo
|
|
|
|
|
|
|
|
def main():
    """Build the interface and serve it with queuing on and the API page hidden."""
    app = build_interface()
    queued_app = app.queue()
    queued_app.launch(show_api=False)
|
|
|
|
|
|
|
|
# Launch the app when this file is executed directly.
if __name__ == "__main__":
    main()

# Module-level interface object, built whenever this module is loaded.
# NOTE(review): the source's indentation was lost; this statement is assumed
# to be top-level (presumably so hosting/reload tooling can import `demo`) —
# confirm. If so, running the file as a script builds the interface twice:
# once inside main() and again here after launch() returns.
demo = build_interface()
|
|
|