Spaces:
Configuration error
Configuration error
| """Gradio Space for browsing Ego2Robot episodes.""" | |
import json
import logging
from io import BytesIO
from pathlib import Path

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image
# Hugging Face Hub dataset repository that stores the episode archives.
REPO_ID = "msunbot1/ego2robot-factory-episodes"


def load_episode(episode_idx):
    """Load one episode archive from the Hugging Face Hub.

    Args:
        episode_idx: Zero-based episode number. Slider callbacks may deliver
            it as a float, so it is coerced to ``int`` before being formatted
            into the file name.

    Returns:
        Whatever ``np.load`` returns for ``data/episode_<idx:06d>.npz``, or
        ``None`` when the index is not a number or the file cannot be
        downloaded or opened.
    """
    try:
        # The ``:06d`` format spec rejects floats, so normalise the index
        # first; doing it inside the ``try`` keeps bad input on the
        # "return None" path instead of raising into the UI callback.
        filename = f"data/episode_{int(episode_idx):06d}.npz"
        file_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            repo_type="dataset",
        )
        return np.load(file_path)
    except Exception:
        # Best effort: callers treat ``None`` as "episode not found", but the
        # cause is logged so failures can still be diagnosed.
        logging.getLogger(__name__).warning(
            "Could not load episode %r", episode_idx, exc_info=True
        )
        return None
def load_metadata():
    """Load dataset-level metadata from ``meta/info.json`` on the Hub.

    Returns:
        A dict that always contains ``total_episodes``, ``total_frames`` and
        ``fps``. Values from the downloaded file take precedence; built-in
        defaults fill in any key that is missing, and are returned on their
        own when the file cannot be fetched or parsed.
    """
    defaults = {"total_episodes": 50, "total_frames": 1800, "fps": 6}
    try:
        info_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="meta/info.json",
            repo_type="dataset",
        )
        with open(info_path, encoding="utf-8") as f:
            info = json.load(f)
    except Exception:
        # ``except Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        logging.getLogger(__name__).warning(
            "Could not load dataset metadata; using defaults", exc_info=True
        )
        return defaults

    if not isinstance(info, dict):
        return defaults
    # Module-level code indexes all three keys, so guarantee they exist even
    # if the downloaded file omits some of them.
    return {**defaults, **info}


# Fetched once at import time; the interface reads its stats and the episode
# slider bound from this dict.
metadata = load_metadata()
def visualize_episode(episode_idx, frame_idx):
    """Render one frame of an episode with its hand box and action arrow.

    Args:
        episode_idx: Episode number; coerced to ``int`` because slider
            callbacks may deliver floats.
        frame_idx: Frame number within the episode; coerced to ``int`` and
            clamped to ``[0, num_frames - 1]``.

    Returns:
        ``(image, info_markdown)`` where ``image`` is a PIL image of the
        rendered figure, or ``(None, message)`` when the episode cannot be
        loaded or contains no frames.
    """
    episode_idx = int(episode_idx)
    ep = load_episode(episode_idx)
    if ep is None:
        return None, "Episode not found"

    num_frames = len(ep['frame_index'])
    if num_frames == 0:
        # Without this guard the clamp below would produce index -1.
        return None, "Episode has no frames"
    frame_idx = max(0, min(int(frame_idx), num_frames - 1))

    # Get frame data
    img = ep['observation.images.top'][frame_idx]
    bbox = ep['observation.state'][frame_idx]
    action = ep['action'][frame_idx]
    # A positive third bbox component is this file's "hand visible" test.
    hand_visible = bbox[2] > 0

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.imshow(img)

        if hand_visible:
            # The code multiplies bbox values up to pixels, so scale by the
            # actual frame size rather than a fixed 640x360.
            height, width = img.shape[:2]
            x_min, y_min, x_max, y_max = bbox
            x_min, x_max = x_min * width, x_max * width
            y_min, y_max = y_min * height, y_max * height

            ax.add_patch(patches.Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                linewidth=3,
                edgecolor='red',
                facecolor='none',
            ))

            # Action arrow starts at the box centre.
            center_x = (x_min + x_max) / 2
            center_y = (y_min + y_max) / 2
            dx = action[0] * 100  # Scale for visibility
            dy = action[1] * 100
            ax.arrow(center_x, center_y, dx, dy,
                     head_width=20, head_length=20,
                     fc='yellow', ec='yellow', linewidth=2)

        ax.set_title(f"Episode {episode_idx} | Frame {frame_idx}/{num_frames-1}\n"
                     f"Action: [{action[0]:.3f}, {action[1]:.3f}]",
                     fontsize=12, pad=10)
        ax.axis('off')

        # Save via the figure object: pyplot's implicit "current figure" is
        # global state and may belong to another concurrent request.
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)
    buf.seek(0)

    # Episode info
    info_text = (
        "\n"
        f"**Episode {episode_idx} Information:**\n"
        f"- Total Frames: {num_frames}\n"
        f"- Current Frame: {frame_idx}\n"
        f"- Timestamp: {ep['timestamp'][frame_idx]:.2f}s\n"
        f"- Hand Visible: {'Yes' if hand_visible else 'No'}\n"
        f"- Action Magnitude: {np.linalg.norm(action):.3f}\n"
    )
    return Image.open(buf), info_text
def get_episode_overview(episode_idx):
    """Render a 2x4 contact sheet of frames sampled across an episode.

    Args:
        episode_idx: Episode number to summarise.

    Returns:
        A PIL image of the grid, or ``None`` when the episode cannot be
        loaded or contains no frames.
    """
    ep = load_episode(episode_idx)
    if ep is None:
        return None

    num_frames = len(ep['frame_index'])
    if num_frames == 0:
        # Nothing to sample; indexing below would fail on an empty episode.
        return None

    rows, cols = 2, 4
    # Evenly spaced frame indices; short episodes simply repeat frames.
    indices = np.linspace(0, num_frames - 1, rows * cols, dtype=int)

    fig, axes = plt.subplots(rows, cols, figsize=(16, 8))
    try:
        for ax, idx in zip(axes.flatten(), indices):
            img = ep['observation.images.top'][idx]
            bbox = ep['observation.state'][idx]
            ax.imshow(img)

            # Draw bbox when the hand is visible (positive third component).
            if bbox[2] > 0:
                # Scale by the actual frame size rather than a fixed 640x360.
                height, width = img.shape[:2]
                x_min, y_min, x_max, y_max = bbox * [width, height, width, height]
                ax.add_patch(patches.Rectangle(
                    (x_min, y_min), x_max - x_min, y_max - y_min,
                    linewidth=2, edgecolor='red', facecolor='none',
                ))

            ax.set_title(f"Frame {idx}", fontsize=9)
            ax.axis('off')

        # Operate on the figure object: pyplot's implicit "current figure" is
        # global state and may belong to another concurrent request.
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Create Gradio interface
# NOTE(review): the emoji in this block were garbled in the captured source.
# The robot and target glyphs were recoverable; the magnifier, chart and
# books glyphs are best guesses - confirm against the intended design.
with gr.Blocks(title="Ego2Robot Episode Browser") as demo:
    gr.Markdown("# 🤖 Ego2Robot: Factory Episode Browser")
    # The episode count comes from the metadata so the sentence cannot
    # disagree with the stats listed right below it.
    gr.Markdown(f"""
Browse {metadata['total_episodes']} episodes of factory manipulation tasks from the [Ego2Robot dataset](https://huggingface.co/datasets/{REPO_ID}).

**Dataset Stats:**
- Episodes: {metadata['total_episodes']}
- Total Frames: {metadata['total_frames']}
- FPS: {metadata['fps']}
""")

    with gr.Row():
        with gr.Column():
            episode_slider = gr.Slider(
                minimum=0,
                # Never let the upper bound fall below the lower bound.
                maximum=max(metadata['total_episodes'] - 1, 0),
                step=1,
                value=0,
                label="Episode",
            )
            frame_slider = gr.Slider(
                minimum=0,
                # Fixed upper bound; visualize_episode clamps the value to
                # the length of the selected episode.
                maximum=35,
                step=1,
                value=0,
                label="Frame",
            )
            visualize_btn = gr.Button("🔍 Visualize Frame", variant="primary")
            overview_btn = gr.Button("📊 Episode Overview")
        with gr.Column():
            output_image = gr.Image(label="Visualization")
            info_text = gr.Markdown()

    visualize_btn.click(
        fn=visualize_episode,
        inputs=[episode_slider, frame_slider],
        outputs=[output_image, info_text],
    )
    overview_btn.click(
        fn=get_episode_overview,
        inputs=[episode_slider],
        outputs=[output_image],
    )

    # TODO: replace the YOUR_USERNAME and YOUR_BLOG_URL placeholders with
    # real links before publishing; as written they render as broken links.
    gr.Markdown(f"""
### 🎯 Features
- **Red Box:** Hand bounding box detection
- **Yellow Arrow:** Hand motion direction (action)
- **Browse:** Use sliders to explore different episodes and frames

### 📚 About
Ego2Robot converts egocentric factory video into robot-ready training data.
- [GitHub](https://github.com/YOUR_USERNAME/ego2robot)
- [Dataset](https://huggingface.co/datasets/{REPO_ID})
- [Blog Post](YOUR_BLOG_URL)
""")

if __name__ == "__main__":
    demo.launch()