# HF Space listing residue (kept as a comment so the file parses):
# msunbot1's picture — "Initial Space" — commit 218f462 (verified)
"""Gradio Space for browsing Ego2Robot episodes."""
import gradio as gr
import numpy as np
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from io import BytesIO
from PIL import Image
# Hugging Face dataset repository that holds the episode .npz files and metadata.
REPO_ID = "msunbot1/ego2robot-factory-episodes"
def load_episode(episode_idx):
    """Download and load one episode archive from the HF Hub.

    Parameters
    ----------
    episode_idx : int
        Zero-based episode number, mapped to ``data/episode_{idx:06d}.npz``.

    Returns
    -------
    numpy.lib.npyio.NpzFile or None
        The loaded archive, or ``None`` when the download or load fails
        (missing episode, network error, etc.). Callers treat ``None``
        as "episode not found".
    """
    filename = f"data/episode_{episode_idx:06d}.npz"
    try:
        file_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            repo_type="dataset",
        )
        return np.load(file_path)
    except Exception:
        # Deliberate best-effort: any failure is surfaced to the UI as
        # "Episode not found" rather than crashing the Space.
        return None
def load_metadata():
    """Load dataset-level metadata (``meta/info.json``) from the HF Hub.

    Returns
    -------
    dict
        Parsed metadata with at least ``total_episodes``, ``total_frames``
        and ``fps`` keys. On any failure (offline, missing file) a
        hard-coded fallback matching the published dataset is returned so
        the UI can still be built.
    """
    try:
        info_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="meta/info.json",
            repo_type="dataset",
        )
        with open(info_path) as f:
            return json.load(f)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; everything else falls back to known defaults.
        return {"total_episodes": 50, "total_frames": 1800, "fps": 6}
# Load dataset metadata once at import time; falls back to baked-in
# defaults when the Hub is unreachable, so the UI sliders can always be built.
metadata = load_metadata()
def visualize_episode(episode_idx, frame_idx):
    """Render a single frame of an episode with hand bbox and action arrow.

    Parameters
    ----------
    episode_idx : int
        Episode to load from the Hub.
    frame_idx : int
        Frame within the episode; clamped into ``[0, num_frames - 1]``.

    Returns
    -------
    tuple[PIL.Image.Image, str] or tuple[None, str]
        Rendered visualization and a markdown info string, or
        ``(None, "Episode not found")`` when the episode fails to load.
    """
    ep = load_episode(episode_idx)
    if ep is None:
        return None, "Episode not found"

    num_frames = len(ep['frame_index'])
    # Clamp at both ends (original only clamped the upper bound).
    frame_idx = max(0, min(frame_idx, num_frames - 1))

    # Per-frame data: image, normalized hand bbox, 2-D action vector.
    img = ep['observation.images.top'][frame_idx]
    bbox = ep['observation.state'][frame_idx]
    action = ep['action'][frame_idx]

    # Derive pixel dimensions from the frame itself instead of assuming
    # 640x360, so other resolutions render correctly.
    img_h, img_w = img.shape[:2]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.imshow(img)

    # bbox[2] (x_max) > 0 is the dataset's "hand visible" sentinel;
    # an all-zero state means no detection for this frame.
    if bbox[2] > 0:
        x_min, y_min, x_max, y_max = bbox
        # De-normalize from [0, 1] to pixel coordinates.
        x_min *= img_w
        y_min *= img_h
        x_max *= img_w
        y_max *= img_h
        rect = patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=3,
            edgecolor='red',
            facecolor='none'
        )
        ax.add_patch(rect)

        # Draw the action as an arrow from the bbox center.
        center_x = (x_min + x_max) / 2
        center_y = (y_min + y_max) / 2
        dx = action[0] * 100  # Scale for visibility
        dy = action[1] * 100
        ax.arrow(center_x, center_y, dx, dy,
                 head_width=20, head_length=20,
                 fc='yellow', ec='yellow', linewidth=2)

    ax.set_title(f"Episode {episode_idx} | Frame {frame_idx}/{num_frames-1}\n"
                 f"Action: [{action[0]:.3f}, {action[1]:.3f}]",
                 fontsize=12, pad=10)
    ax.axis('off')

    # Render to an in-memory PNG and close the figure explicitly to avoid
    # accumulating open figures across Gradio calls.
    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
    buf.seek(0)
    plt.close(fig)

    info_text = f"""
**Episode {episode_idx} Information:**
- Total Frames: {num_frames}
- Current Frame: {frame_idx}
- Timestamp: {ep['timestamp'][frame_idx]:.2f}s
- Hand Visible: {'Yes' if bbox[2] > 0 else 'No'}
- Action Magnitude: {np.linalg.norm(action):.3f}
"""
    return Image.open(buf), info_text
def get_episode_overview(episode_idx):
    """Render a 2x4 contact-sheet of frames sampled evenly across an episode.

    Parameters
    ----------
    episode_idx : int
        Episode to load from the Hub.

    Returns
    -------
    PIL.Image.Image or None
        The overview grid, or ``None`` when the episode fails to load.
    """
    ep = load_episode(episode_idx)
    if ep is None:
        return None

    num_frames = len(ep['frame_index'])
    # Eight evenly spaced frame indices spanning the whole episode.
    indices = np.linspace(0, num_frames - 1, 8, dtype=int)

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()

    for i, idx in enumerate(indices):
        ax = axes[i]
        img = ep['observation.images.top'][idx]
        bbox = ep['observation.state'][idx]
        action = ep['action'][idx]
        ax.imshow(img)

        # bbox[2] > 0 marks "hand visible"; de-normalize using the actual
        # frame dimensions instead of hard-coded 640x360.
        if bbox[2] > 0:
            img_h, img_w = img.shape[:2]
            x_min, y_min, x_max, y_max = bbox * [img_w, img_h, img_w, img_h]
            rect = patches.Rectangle(
                (x_min, y_min), x_max - x_min, y_max - y_min,
                linewidth=2, edgecolor='red', facecolor='none'
            )
            ax.add_patch(rect)

        ax.set_title(f"Frame {idx}", fontsize=9)
        ax.axis('off')

    plt.tight_layout()
    # Render to an in-memory PNG; close the figure explicitly so repeated
    # Gradio calls don't leak matplotlib figures.
    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
    buf.seek(0)
    plt.close(fig)
    return Image.open(buf)
# Create Gradio interface: left column holds the episode/frame controls,
# right column shows the rendered image plus a markdown info panel.
with gr.Blocks(title="Ego2Robot Episode Browser") as demo:
    gr.Markdown("# πŸ€– Ego2Robot: Factory Episode Browser")
    gr.Markdown(f"""
Browse 50 episodes of factory manipulation tasks from the [Ego2Robot dataset](https://huggingface.co/datasets/{REPO_ID}).
**Dataset Stats:**
- Episodes: {metadata['total_episodes']}
- Total Frames: {metadata['total_frames']}
- FPS: {metadata['fps']}
""")
    with gr.Row():
        with gr.Column():
            # Episode range comes from the metadata loaded at import time.
            episode_slider = gr.Slider(
                minimum=0,
                maximum=metadata['total_episodes']-1,
                step=1,
                value=0,
                label="Episode"
            )
            # NOTE(review): max=35 is hard-coded; episodes with fewer frames
            # are handled by visualize_episode's clamp, but longer episodes
            # can't be browsed past frame 35 — confirm against the dataset.
            frame_slider = gr.Slider(
                minimum=0,
                maximum=35,
                step=1,
                value=0,
                label="Frame"
            )
            visualize_btn = gr.Button("πŸ” Visualize Frame", variant="primary")
            overview_btn = gr.Button("πŸ“Š Episode Overview")
        with gr.Column():
            output_image = gr.Image(label="Visualization")
            info_text = gr.Markdown()
    # Wire buttons to the two render functions; both write to output_image.
    visualize_btn.click(
        fn=visualize_episode,
        inputs=[episode_slider, frame_slider],
        outputs=[output_image, info_text]
    )
    overview_btn.click(
        fn=get_episode_overview,
        inputs=[episode_slider],
        outputs=[output_image]
    )
    gr.Markdown("""
### 🎯 Features
- **Red Box:** Hand bounding box detection
- **Yellow Arrow:** Hand motion direction (action)
- **Browse:** Use sliders to explore different episodes and frames
### πŸ“š About
Ego2Robot converts egocentric factory video into robot-ready training data.
- [GitHub](https://github.com/YOUR_USERNAME/ego2robot)
- [Dataset](https://huggingface.co/datasets/msunbot1/ego2robot-factory-episodes)
- [Blog Post](YOUR_BLOG_URL)
""")

# Launch the Space only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()