Spaces:
Sleeping
Sleeping
| # ------------------ Import Libraries ------------------ | |
| import dash | |
| from dash import dcc, html, Input, Output, State, no_update | |
| import plotly.graph_objects as go | |
| import pandas as pd | |
| import numpy as np | |
| import cv2 | |
| import base64 | |
| from scipy.ndimage import gaussian_filter1d | |
| import requests | |
| import json | |
| import tempfile | |
| import os | |
| from urllib.parse import urljoin | |
| import subprocess | |
| # ------------------ Data Download and Processing ------------------ | |
class RemoteDatasetLoader:
    """Fetch one episode (videos + tabular data) of a Hugging Face dataset.

    Every file is addressed relative to
    ``https://huggingface.co/datasets/<repo_id>/resolve/main/``. The file
    layout (``video_path`` / ``data_path`` templates, ``chunks_size``,
    ``features``) is read from the dataset's own ``meta/info.json``.
    """

    # Files smaller than this are treated as failed/truncated downloads.
    _MIN_VIDEO_BYTES = 1024 * 100

    def __init__(self, repo_id: str, timeout: int = 30):
        """Store the repository id and the timeout passed to every HTTP request."""
        self.repo_id = repo_id
        self.timeout = timeout
        self.base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/"

    def _get_dataset_info(self) -> dict:
        """Download and parse ``meta/info.json``; raises on HTTP errors."""
        info_url = urljoin(self.base_url, "meta/info.json")
        response = requests.get(info_url, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def _get_episode_info(self, episode_id: int) -> dict:
        """Return the ``meta/episodes.jsonl`` record whose ``episode_index`` matches.

        Raises:
            ValueError: if no record has the requested ``episode_index``.
        """
        episodes_url = urljoin(self.base_url, "meta/episodes.jsonl")
        response = requests.get(episodes_url, timeout=self.timeout)
        response.raise_for_status()
        # Parse lazily and stop at the first match instead of decoding the
        # whole file before searching.
        for line in response.text.splitlines():
            if not line.strip():
                continue
            episode = json.loads(line)
            if episode.get("episode_index") == episode_id:
                return episode
        raise ValueError(f"Episode {episode_id} not found")

    def _is_valid_mp4(self, file_path):
        """Return True when *file_path* is a plausible, decodable video file.

        The file must exist, be at least ``_MIN_VIDEO_BYTES`` long, and
        ffprobe must report a codec for its first video stream. Any codec is
        accepted: non-H.264 sources (AV1, VP9, ...) are re-encoded later by
        the compatibility step, so rejecting them here would only force the
        same file to be downloaded again on every load.
        """
        if not os.path.exists(file_path) or os.path.getsize(file_path) < self._MIN_VIDEO_BYTES:
            return False
        try:
            result = subprocess.run([
                'ffprobe', '-v', 'error', '-select_streams', 'v:0',
                '-show_entries', 'stream=codec_name',
                '-of', 'default=noprint_wrappers=1:nokey=1', file_path
            ], capture_output=True, text=True, timeout=10)
            if result.returncode == 0 and result.stdout.strip():
                return True
        except Exception as e:
            # Best effort: without a working ffprobe the file is simply
            # treated as invalid and fetched again.
            print(f"ffprobe video check failed: {e}")
        return False

    def _download_video(self, video_url: str, save_path: str) -> str:
        """Stream *video_url* to *save_path* and return *save_path*.

        Raises:
            requests.HTTPError: on a non-success HTTP status.
            ValueError: if the response Content-Type does not contain "video".
        """
        # The context manager returns the streamed connection to the pool
        # even when an error is raised half-way through.
        with requests.get(video_url, timeout=self.timeout, stream=True) as response:
            response.raise_for_status()
            # Check Content-Type
            if 'video' not in response.headers.get('Content-Type', ''):
                raise ValueError(f"URL {video_url} does not return video content, Content-Type: {response.headers.get('Content-Type')}")
            target_dir = os.path.dirname(save_path)
            if target_dir:  # os.makedirs('') raises for a bare file name
                os.makedirs(target_dir, exist_ok=True)
            # Write to a sibling ".part" file and rename on success so an
            # interrupted download never leaves a truncated file at save_path.
            partial_path = f"{save_path}.part"
            try:
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                os.replace(partial_path, save_path)
            finally:
                if os.path.exists(partial_path):
                    os.remove(partial_path)
        return save_path

    def load_episode_data(self, episode_id: int,
                          video_keys=None,
                          download_dir=None):
        """Load the videos and the parquet table of one episode.

        Args:
            episode_id: ``episode_index`` of the episode to load.
            video_keys: feature keys of the videos to fetch. When ``None``,
                the first two features whose dtype is ``"video"`` are used.
            download_dir: directory for downloaded videos. When ``None`` a
                fresh temporary directory is created.

        Returns:
            ``(video_paths, df)``. Each entry of ``video_paths`` is a local
            file path, or the remote URL when the download failed. ``df`` is
            an empty DataFrame when the parquet file could not be read.

        Raises:
            ValueError: if the episode is not listed in the dataset metadata.
        """
        dataset_info = self._get_dataset_info()
        self._get_episode_info(episode_id)  # Check if episode exists
        if download_dir is None:
            download_dir = tempfile.mkdtemp(prefix="lerobot_videos_")
        if video_keys is None:
            video_keys = [key for key, feature in dataset_info["features"].items()
                          if feature["dtype"] == "video"]
            video_keys = video_keys[:2]
        chunks_size = dataset_info.get("chunks_size", 1000)
        episode_chunk = episode_id // chunks_size
        # Create repo-specific subdirectory ('/' replaced to avoid path issues)
        repo_dir = os.path.join(download_dir, self.repo_id.replace('/', '_'))
        os.makedirs(repo_dir, exist_ok=True)

        video_paths = []
        for video_key in video_keys:
            video_url = self.base_url + dataset_info["video_path"].format(
                episode_chunk=episode_chunk,
                video_key=video_key,
                episode_index=episode_id
            )
            local_path = os.path.join(repo_dir, f"episode_{episode_id}_{video_key}.mp4")
            # Prefer loading local valid mp4
            if self._is_valid_mp4(local_path):
                print(f"Local valid video found: {local_path}")
                video_paths.append(local_path)
                continue
            try:
                video_paths.append(self._download_video(video_url, local_path))
            except Exception as e:
                # Fall back to the remote URL so the caller still gets one
                # entry per requested video key.
                print(f"Failed to download video {video_key}: {e}")
                video_paths.append(video_url)

        data_url = self.base_url + dataset_info["data_path"].format(
            episode_chunk=episode_chunk,
            episode_index=episode_id
        )
        try:
            df = pd.read_parquet(data_url)
        except Exception as e:
            print(f"Failed to load data: {e}")
            df = pd.DataFrame()
        return video_paths, df
def check_ffmpeg_available():
    """Return True when ``ffmpeg -version`` runs and exits with status 0.

    A missing executable or a probe that takes longer than 5 seconds counts
    as "not available".
    """
    probe_cmd = ['ffmpeg', '-version']
    try:
        completed = subprocess.run(probe_cmd, capture_output=True,
                                   text=True, timeout=5)
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return False
    return completed.returncode == 0
def get_video_codec_info(video_path):
    """Return the codec name of the first video stream ffprobe reports.

    Returns ``'unknown'`` when ffprobe fails, cannot be run, reports no video
    stream, or the stream carries no ``codec_name``.
    """
    probe_cmd = [
        'ffprobe', '-v', 'quiet', '-print_format', 'json',
        '-show_streams', video_path
    ]
    try:
        probe = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=10)
        if probe.returncode == 0:
            streams = json.loads(probe.stdout).get('streams', [])
            video_streams = (s for s in streams if s.get('codec_type') == 'video')
            first_video = next(video_streams, None)
            if first_video is not None:
                return first_video.get('codec_name', 'unknown')
    except Exception as e:
        # Best effort: any probing problem degrades to 'unknown'.
        print(f"Failed to get video codec info: {e}")
    return 'unknown'
def reencode_video_to_h264(input_path, output_path=None, quality='medium'):
    """Re-encode *input_path* to H.264/AAC with ffmpeg.

    Args:
        input_path: source video file.
        output_path: destination; defaults to ``<input stem>_h264.mp4``.
        quality: ``'fast'``, ``'medium'`` or ``'high'``; any other value
            falls back to the ``'medium'`` settings.

    Returns:
        *output_path* when ffmpeg exits with status 0, otherwise *input_path*
        (non-zero exit, timeout after 300 s, or any other failure).
    """
    if output_path is None:
        stem, _ext = os.path.splitext(input_path)
        output_path = f"{stem}_h264.mp4"

    presets = {
        'fast': ['-preset', 'ultrafast', '-crf', '28'],
        'medium': ['-preset', 'medium', '-crf', '23'],
        'high': ['-preset', 'slow', '-crf', '18']
    }
    encoder_args = presets[quality] if quality in presets else presets['medium']

    command = [
        'ffmpeg', '-i', input_path,
        '-c:v', 'libx264',
        '-c:a', 'aac',
        '-movflags', '+faststart',
        '-y',
        *encoder_args,
        output_path,
    ]
    try:
        completed = subprocess.run(command, capture_output=True, text=True, timeout=300)
    except subprocess.TimeoutExpired:
        print("Re-encoding timeout")
        return input_path
    except Exception as e:
        print(f"Re-encoding exception: {e}")
        return input_path

    if completed.returncode != 0:
        print(f"Re-encoding failed: {completed.stderr}")
        return input_path
    return output_path
def process_video_for_compatibility(video_path):
    """Return a path to an H.264 version of *video_path* when one is needed.

    The original path is returned unchanged when the file does not exist,
    ffmpeg is unavailable, the codec is not one of av01/av1/vp9/vp8/unknown,
    or the re-encoded output is missing or no larger than 1 KiB.
    """
    if not os.path.exists(video_path):
        print(f"Video file does not exist: {video_path}")
        return video_path

    if not check_ffmpeg_available():
        print("ffmpeg not available, skipping re-encoding")
        return video_path

    codec = get_video_codec_info(video_path)
    needs_transcode = codec == 'unknown' or codec in ('av01', 'av1', 'vp9', 'vp8')
    if not needs_transcode:
        return video_path

    candidate = reencode_video_to_h264(video_path, quality='fast')
    usable = os.path.exists(candidate) and os.path.getsize(candidate) > 1024
    if usable:
        return candidate

    print("Re-encoding failed, using original file")
    return video_path
def load_remote_dataset(repo_id: str,
                        episode_id: int = 0,
                        video_keys=None,
                        download_dir=None):
    """Load one episode and pass each of its videos through the compatibility step.

    Returns:
        ``(processed_video_paths, df)`` where the paths are the results of
        ``process_video_for_compatibility`` for every path returned by
        ``RemoteDatasetLoader.load_episode_data`` and ``df`` is that loader's
        DataFrame, unchanged.
    """
    raw_paths, episode_df = RemoteDatasetLoader(repo_id).load_episode_data(
        episode_id, video_keys, download_dir
    )
    compatible_paths = [process_video_for_compatibility(path) for path in raw_paths]
    return compatible_paths, episode_df
| # ------------------ Dash Initialization ------------------ | |
# suppress_callback_exceptions is needed because the graph/image components
# are only created when the main content area is rendered, so their ids are
# absent from the initial layout.
app = dash.Dash(__name__, suppress_callback_exceptions=True)
server = app.server

# ------------------ Page Layout ------------------
# Styles shared by the two labels / two inputs of the control panel.
_LABEL_STYLE = {
    "fontWeight": "600",
    "color": "#333",
    "marginRight": "10px",
    "fontSize": "1rem",
}
_INPUT_BASE_STYLE = {
    "padding": "12px 15px",
    "border": "2px solid #e1e5e9",
    "borderRadius": "8px",
    "fontSize": "14px",
    "transition": "border-color 0.3s ease",
    "outline": "none",
}

# Header with gradient background
_header = html.Div(
    [
        html.H1(
            "Keyframe Identification",
            style={
                "textAlign": "center",
                "marginBottom": "10px",
                "color": "white",
                "fontSize": "2.5rem",
                "fontWeight": "300",
                "textShadow": "2px 2px 4px rgba(0,0,0,0.3)",
            },
        ),
        html.P(
            "Interactive Joint Analysis with Video Synchronization",
            style={
                "textAlign": "center",
                "color": "rgba(255,255,255,0.9)",
                "fontSize": "1.1rem",
                "marginBottom": "0",
            },
        ),
    ],
    style={
        "background": "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
        "padding": "30px 20px",
        "marginBottom": "30px",
        "borderRadius": "0 0 15px 15px",
        "boxShadow": "0 4px 20px rgba(0,0,0,0.1)",
    },
)

# Control Panel: repository row
_repo_row = html.Div(
    [
        html.Label("Repository ID:", style=dict(_LABEL_STYLE)),
        dcc.Input(
            id="input-repo-id",
            type="text",
            value="zijian2022/sortingtest",
            style={"width": "350px", **_INPUT_BASE_STYLE},
            placeholder="Enter HuggingFace dataset repository ID",
        ),
    ],
    style={"marginBottom": "15px"},
)

# Control Panel: episode row with the load button
_episode_row = html.Div(
    [
        html.Label("Episode ID:", style=dict(_LABEL_STYLE)),
        dcc.Input(
            id="input-episode-id",
            type="number",
            value=0,
            min=0,
            style={"width": "120px", **_INPUT_BASE_STYLE},
        ),
        html.Button(
            "Load Data",
            id="btn-load",
            n_clicks=0,
            style={
                "marginLeft": "20px",
                "padding": "12px 25px",
                "backgroundColor": "#667eea",
                "color": "white",
                "border": "none",
                "borderRadius": "8px",
                "fontSize": "14px",
                "fontWeight": "600",
                "cursor": "pointer",
                "transition": "all 0.3s ease",
                "boxShadow": "0 2px 10px rgba(102, 126, 234, 0.3)",
            },
        ),
    ]
)

_control_panel = html.Div(
    [_repo_row, _episode_row],
    style={
        "textAlign": "center",
        "marginBottom": "40px",
        "padding": "25px",
        "backgroundColor": "white",
        "borderRadius": "12px",
        "boxShadow": "0 4px 20px rgba(0,0,0,0.08)",
        "border": "1px solid #f0f0f0",
    },
)

# Loading and Data Store
_data_store = dcc.Loading(
    id="loading",
    type="circle",
    style={"margin": "20px auto"},
    children=dcc.Store(id="store-data"),
)

# Main Content Area
_main_content = html.Div(
    id="main-content",
    style={
        "backgroundColor": "#f8f9fa",
        "minHeight": "400px",
        "borderRadius": "12px",
        "padding": "20px",
    },
)

app.layout = html.Div(
    [_header, _control_panel, _data_store, _main_content],
    style={
        "fontFamily": "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif",
        "backgroundColor": "#f5f7fa",
        "minHeight": "100vh",
        "padding": "0",
    },
)
| # ------------------ Data Loading Callback ------------------ | |
# NOTE(review): the function was defined but never registered with Dash, so
# clicking "Load Data" did nothing. The wiring below is inferred from the
# layout ids and the parameter order - confirm against the intended design.
# prevent_initial_call keeps the page from downloading a dataset before the
# user has pressed the button.
@app.callback(
    Output("store-data", "data"),
    Input("btn-load", "n_clicks"),
    State("input-repo-id", "value"),
    State("input-episode-id", "value"),
    prevent_initial_call=True,
)
def load_data_callback(n_clicks, repo_id, episode_id):
    """Load the requested episode and return it in a JSON-friendly dict.

    Args:
        n_clicks: click counter of the load button (only used as trigger).
        repo_id: dataset repository id typed by the user.
        episode_id: episode index typed by the user.

    Returns:
        ``{}`` when the inputs are missing, loading fails or the table is
        empty; otherwise a dict with ``video_paths``, ``data_df`` (list of
        row records), ``columns`` (the six joint names) and ``timestamps``.
    """
    if not repo_id or episode_id is None:
        # Cleared input fields arrive as ""/None; nothing to load.
        return {}
    try:
        video_paths, data_df = load_remote_dataset(
            repo_id=repo_id,
            episode_id=int(episode_id),
            download_dir="./downloaded_videos"
        )
        if data_df is None or data_df.empty:
            return {}
        return {
            "video_paths": video_paths,
            "data_df": data_df.to_dict("records"),
            "columns": ["shoulder_pan", "shoulder_pitch", "elbow", "wrist_pitch", "wrist_roll", "gripper"],
            "timestamps": data_df["timestamp"].tolist()
        }
    except Exception as e:
        # Top-level UI boundary: report and fall back to the empty state
        # rather than surfacing a callback error in the browser.
        print(f"Data loading error: {e}")
        return {}
| # ------------------ Main Content Rendering Callback ------------------ | |
def _empty_state_panel():
    """Build the placeholder shown while no episode data is available."""
    return html.Div([
        html.Div("📊", style={"fontSize": "3rem", "marginBottom": "20px"}),
        html.H3("No Data Available", style={"color": "#666", "marginBottom": "10px"}),
        html.P("Please click the 'Load Data' button above to get data.",
               style={"color": "#888", "fontSize": "1rem"})
    ], style={
        "textAlign": "center",
        "padding": "60px 20px",
        "color": "#666"
    })


def _joint_row(i):
    """Build row *i*: the joint graph on the left, two video frames on the right."""
    frame_style = {
        "width": "49%",
        "height": "180px",
        "objectFit": "contain",
        "display": "inline-block",
        "borderRadius": "6px",
        "border": "2px solid #e9ecef"
    }
    # Joint Graph - Left 50%
    graph_cell = html.Div([
        dcc.Graph(id=f"graph-{i}")
    ], style={
        "flex": "0 0 50%",
        "backgroundColor": "white",
        "borderRadius": "8px",
        "padding": "8px",
        "boxShadow": "0 2px 10px rgba(0,0,0,0.05)",
        "border": "1px solid #e9ecef",
        "marginRight": "2%"
    })
    # Video Area - Right 48%
    video_cell = html.Div([
        html.Img(id=f"video1-{i}", style=dict(frame_style)),
        html.Img(id=f"video2-{i}", style=dict(frame_style))
    ], style={
        "flex": "0 0 48%"
    })
    return html.Div([graph_cell, video_cell], style={
        "marginBottom": "25px",
        "backgroundColor": "white",
        "borderRadius": "12px",
        "padding": "12px",
        "boxShadow": "0 4px 15px rgba(0,0,0,0.08)",
        "border": "1px solid #f0f0f0",
        "display": "flex",
        "alignItems": "flex-start",
        "minHeight": "250px"
    })


# NOTE(review): this function was defined but never registered with Dash, so
# the main content area stayed empty. The wiring is inferred from the layout
# ids ("store-data" -> "main-content") - confirm against the intended design.
@app.callback(
    Output("main-content", "children"),
    Input("store-data", "data"),
)
def update_main_content(data):
    """Render one graph/video row per joint column, or a placeholder.

    Args:
        data: content of the data store; expected to carry ``data_df`` and
            ``columns`` once an episode has been loaded.

    Returns:
        The placeholder panel when *data* is empty or has no rows, otherwise
        a ``html.Div`` holding one row per entry of ``data["columns"]`` with
        component ids ``graph-<i>``, ``video1-<i>`` and ``video2-<i>``.
    """
    if not data or "data_df" not in data or len(data["data_df"]) == 0:
        return _empty_state_panel()
    return html.Div([_joint_row(i) for i in range(len(data["columns"]))])
| # ------------------ Shadow and Highlight Utility Functions ------------------ | |
def find_intervals(mask):
    """Return the ``(first, last)`` index pairs of every run of truthy values.

    Both indices are inclusive; a run that reaches the end of *mask* is
    closed at ``len(mask) - 1``. An empty mask yields an empty list.
    """
    runs = []
    run_start = None
    for pos, flag in enumerate(mask):
        if flag:
            if run_start is None:
                run_start = pos  # a new run begins here
        elif run_start is not None:
            runs.append((run_start, pos - 1))
            run_start = None
    if run_start is not None:
        runs.append((run_start, len(mask) - 1))
    return runs
def get_shadow_info(joint_name, action_df, delta_t, time_for_plot,
                    vel_threshold=0.5, highlight_width=1, k=2):
    """Compute the time spans to highlight for one joint.

    Velocity is ``diff(angles) / delta_t``; velocity and angle (first sample
    dropped) are both smoothed with a Gaussian filter (sigma=1).

    Args:
        joint_name: column of *action_df* to analyse.
        action_df: DataFrame with one column per joint.
        delta_t: time step(s) between consecutive samples.
        time_for_plot: time axis for samples 1..n-1 (indexed by the returned
            ``start_idx`` / ``end_idx``).
        vel_threshold: |smoothed velocity| below this counts as low speed.
        highlight_width: samples added on each side of the angle extrema.
        k: low-speed runs of at most this many samples are kept.
            NOTE(review): longer low-speed runs are discarded - confirm that
            this is the intended direction of the comparison.

    Returns:
        A list of dicts with keys ``type`` (``'low_speed'``, ``'max_value'``
        or ``'min_value'``), ``start_time``, ``end_time``, ``start_idx`` and
        ``end_idx``. Empty when there are fewer than two samples.
    """
    angles = action_df[joint_name].values
    # With fewer than two samples there is no velocity and argmax/argmin
    # would raise on the empty smoothed series.
    if len(angles) < 2 or len(time_for_plot) == 0:
        return []

    velocity = np.diff(angles) / delta_t
    smoothed_velocity = gaussian_filter1d(velocity, sigma=1)
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)
    last_idx = len(time_for_plot) - 1

    def _span(kind, start, end):
        return {
            'type': kind,
            'start_time': time_for_plot[start],
            'end_time': time_for_plot[end],
            'start_idx': start,
            'end_idx': end
        }

    low_speed_mask = np.abs(smoothed_velocity) < vel_threshold
    shadows = [
        _span('low_speed', start, end)
        for start, end in find_intervals(low_speed_mask)
        if end - start + 1 <= k
    ]

    # Highlight a small window around the global maximum, then the minimum.
    extrema = (
        ('max_value', int(np.argmax(smoothed_angle))),
        ('min_value', int(np.argmin(smoothed_angle))),
    )
    for kind, centre in extrema:
        shadows.append(_span(
            kind,
            max(0, centre - highlight_width),
            min(last_idx, centre + highlight_width),
        ))
    return shadows
def generate_joint_graph(joint_name, idx, action_df, delta_t, time_for_plot, all_shadows):
    """Build the Plotly figure dict for one joint.

    Args:
        joint_name: column of *action_df* to plot.
        idx: position of the joint (unused, kept for the call signature).
        action_df: DataFrame with one column per joint.
        delta_t: sample spacing (unused here, kept for the call signature).
        time_for_plot: x values; one per angle sample after the first.
        all_shadows: mapping joint name -> list of dicts with ``start_time``
            and ``end_time``; each becomes a full-height red rectangle. A
            joint without an entry gets no rectangles.

    Returns:
        ``{"data": [...], "layout": go.Layout}`` with the smoothed angle
        trace (Gaussian filter, sigma=1, first sample dropped).
    """
    angles = action_df[joint_name].values
    smoothed_angle = gaussian_filter1d(angles[1:], sigma=1)

    shapes = [
        {
            "type": "rect",
            "xref": "x",
            "yref": "paper",  # span the full plot height regardless of y range
            "x0": shadow['start_time'],
            "x1": shadow['end_time'],
            "y0": 0,
            "y1": 1,
            "fillcolor": "#ef4444",  # Fixed red
            "opacity": 0.4,
            "line": {"width": 0}
        }
        for shadow in all_shadows.get(joint_name, [])
    ]

    def _axis(title_text):
        # The axis-level "titlefont" shortcut is deprecated and rejected by
        # newer Plotly releases; the nested title.font form is accepted by
        # both old and new versions.
        return {
            "title": {"text": title_text, "font": {"color": "#6b7280"}},
            "tickfont": {"color": "#6b7280"},
            "gridcolor": "#f3f4f6",
            "zerolinecolor": "#e5e7eb"
        }

    font_family = "'Segoe UI', Tahoma, Geneva, Verdana, sans-serif"
    return {
        "data": [
            go.Scatter(
                x=time_for_plot,
                y=smoothed_angle,
                name="Joint Angle",
                line=dict(color='#f59e0b', width=2),
                hovertemplate='<b>Time:</b> %{x:.2f}s<br><b>Angle:</b> %{y:.2f}°<extra></extra>'
            )
        ],
        "layout": go.Layout(
            title={
                'text': joint_name.replace('_', ' ').title(),
                'font': {'size': 16, 'color': '#374151'}
            },
            xaxis=_axis("Time (seconds)"),
            yaxis=_axis("Angle (degrees)"),
            shapes=shapes,
            hovermode="x unified",
            height=220,
            margin=dict(t=30, b=30, l=50, r=30),
            showlegend=False,
            plot_bgcolor='white',
            paper_bgcolor='white',
            font={'family': font_family},
            hoverlabel=dict(
                bgcolor="white",
                font_size=12,
                font_family=font_family
            )
        )
    }
| # ------------------ Chart Update Callback ------------------ | |
# NOTE(review): this function was defined but never registered with Dash, so
# the six graphs were never filled. The wiring is inferred from the
# "graph-<i>" ids created in the main content area and from the six-element
# return value - confirm against the intended design.
@app.callback(
    [Output(f"graph-{i}", "figure") for i in range(6)],
    Input("store-data", "data"),
)
def update_all_graphs(data):
    """Build the six joint figures from the stored episode data.

    Args:
        data: content of the data store with ``data_df`` (row records that
            include ``action`` and ``timestamp``) and ``columns``.

    Returns:
        A list of six figures in ``data["columns"]`` order, or six
        ``no_update`` markers when there is no data or fewer than two rows
        (velocity needs at least two samples).
    """
    if not data or "data_df" not in data or len(data["data_df"]) == 0:
        return [no_update] * 6
    columns = data["columns"]
    df = pd.DataFrame.from_records(data["data_df"])
    if len(df) < 2:
        return [no_update] * 6
    # Each "action" entry is expanded into one column per joint.
    action_df = pd.DataFrame(df["action"].tolist(), columns=columns)
    timestamps = df["timestamp"].values
    delta_t = np.diff(timestamps)
    time_for_plot = timestamps[1:]  # aligned with the diff-based series
    all_shadows = {
        joint: get_shadow_info(joint, action_df, delta_t, time_for_plot)
        for joint in columns
    }
    # Generate all charts, no highlight logic
    return [
        generate_joint_graph(joint, i, action_df, delta_t, time_for_plot, all_shadows)
        for i, joint in enumerate(columns)
    ]
| # ------------------ Video Frame Extraction Function ------------------ | |
def get_video_frame(video_path, time_in_seconds):
    """Return the frame at *time_in_seconds* as a JPEG data URI, or ``None``.

    The frame index is ``int(time_in_seconds * fps)``. Frames wider than
    640 px are scaled down to 640 px width (aspect ratio kept) before being
    encoded as JPEG (quality 85) and base64-wrapped.

    Returns:
        ``"data:image/jpeg;base64,..."`` on success; ``None`` when the video
        cannot be opened, reports no frame rate, the frame cannot be read or
        encoded, or any exception occurs.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"❌ Cannot open video: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            return None
        # Clamp so a slightly negative hover time still maps to frame 0.
        frame_num = max(0, int(time_in_seconds * fps))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        success, frame = cap.read()
        if not success or frame is None:
            return None

        height, width = frame.shape[:2]
        if width > 640:
            new_width = 640
            new_height = int(height * (new_width / width))
            frame = cv2.resize(frame, (new_width, new_height))

        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 85]
        encoded_ok, buffer = cv2.imencode('.jpg', frame, encode_param)
        if not encoded_ok:
            return None
        encoded = base64.b64encode(buffer).decode('utf-8')
        return f"data:image/jpeg;base64,{encoded}"
    except Exception as e:
        print(f"❌ Exception extracting video frame: {e}")
        return None
    finally:
        # Always release the capture handle, including on the early returns
        # and on exceptions.
        if cap is not None:
            cap.release()
| # ------------------ Video Frame Callback ------------------ | |
# One callback per joint row: hovering graph <i> updates the two frames next
# to it.
# NOTE(review): the function was defined in the loop but never registered
# with Dash; the wiring is inferred from the "video1-<i>"/"video2-<i>" and
# "graph-<i>" ids and from the two-element return value - confirm.
for i in range(6):
    @app.callback(
        Output(f"video1-{i}", "src"),
        Output(f"video2-{i}", "src"),
        Input("store-data", "data"),
        Input(f"graph-{i}", "hoverData"),
    )
    def update_video_frames(data, hover_data, idx=i):
        """Return the two video frames for the hovered (or initial) time.

        Args:
            data: content of the data store (``video_paths``, ``timestamps``
                and ``data_df``).
            hover_data: Plotly hover payload of this row's graph, or None.
            idx: row index bound at definition time (not read in the body).

        Returns:
            ``(src1, src2)``; each is a JPEG data URI, or ``no_update`` when
            that video is missing or its frame could not be extracted.
        """
        if not data or "data_df" not in data or len(data["data_df"]) == 0:
            return no_update, no_update

        # Use the stored timestamp list instead of rebuilding a DataFrame on
        # every hover event; fall back to the row records if it is absent.
        timestamps = data.get("timestamps")
        if not timestamps:
            timestamps = [row["timestamp"] for row in data["data_df"]]

        # Determine the time point to display
        display_time = 0.0  # Default to start time
        if hover_data and hover_data.get("points"):
            # If there is hover data, use hover time
            try:
                display_time = float(hover_data["points"][0]["x"])
            except (KeyError, TypeError, ValueError) as e:
                print(f"update_video_frames callback error: {e}")
                return no_update, no_update
        elif len(timestamps) > 1:
            # The plotted timeline starts at the second sample.
            display_time = timestamps[1]

        video_paths = data.get("video_paths") or []
        sources = []
        for slot in range(2):
            frame = None
            if slot < len(video_paths):
                frame = get_video_frame(video_paths[slot], display_time)
            # Update each image independently so a single-camera dataset (or
            # one unreadable video) still shows the frame that is available.
            sources.append(frame if frame else no_update)
        return sources[0], sources[1]
| # ------------------ Start Application ------------------ | |
if __name__ == "__main__":
    # The server listens on all interfaces, so the interactive debugger must
    # not be on by default: debug mode lets anyone who can reach the port
    # execute code. Opt in locally with DASH_DEBUG=1.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)