# NOTE: the lines below ("Spaces:", "Sleeping") were Hugging Face Spaces page
# chrome captured along with the source; they are not part of the program.
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import random | |
| import os | |
| import shutil | |
| from pathlib import Path | |
| from datasets import load_dataset | |
| from huggingface_hub import hf_hub_download | |
| from tqdm import tqdm | |
# Step 6: Add success/failure filtering
def sample_trajectories(dataset_repo, config_name, is_robot, quality_filter, num_samples, max_to_check=10000):
    """Sample random trajectories from HuggingFace dataset with quality filter."""
    try:
        load_args = (dataset_repo, config_name) if config_name else (dataset_repo,)
        dataset = load_dataset(*load_args, split="train", streaming=True)
        pool = []
        for idx, record in enumerate(dataset):
            # Streaming dataset: cap how far we scan.
            if idx >= max_to_check:
                break
            # Keep only the requested embodiment (human vs. robot).
            if record.get("is_robot", False) != is_robot:
                continue
            if quality_filter != "All":
                label = record.get("quality_label", "")
                partial = record.get("partial_success", None)
                if quality_filter == "Success":
                    # A non-"success" label is still accepted if the
                    # partial_success score vouches for it (>= 1).
                    if label and "success" not in label.lower():
                        if partial is None or partial < 1:
                            continue
                elif quality_filter == "Failure":
                    # Reject anything marked successful by label or score.
                    if label and "success" in label.lower():
                        continue
                    if partial is not None and partial >= 1:
                        continue
            pool.append(record)
        if not pool:
            return []
        # Fewer matches than requested: return them all, shuffled.
        if len(pool) <= num_samples:
            random.shuffle(pool)
            return pool
        return random.sample(pool, num_samples)
    except Exception as e:
        print(f"Error sampling: {e}")
        return []
def download_video(trajectory, dataset_repo, config_name=None):
    """Download the video referenced by a trajectory, caching it locally.

    Args:
        trajectory: Sample dict; its "frames" field holds the repo-relative
            video path.
        dataset_repo: HuggingFace dataset repo id (e.g. "user/name").
        config_name: Optional dataset config name, used to namespace the cache.

    Returns:
        Local filesystem path to the video as a string, or None when the
        trajectory has no video or the download fails.
    """
    video_path = trajectory.get("frames")
    if not video_path:
        return None
    # Cache per repo (and config) so identically named files from different
    # datasets cannot collide.
    cache_dir = Path("video_cache")
    repo_key = f"{dataset_repo}_{config_name}" if config_name else dataset_repo
    repo_key = repo_key.replace("/", "_").replace("\\", "_")
    dataset_cache_dir = cache_dir / repo_key
    dataset_cache_dir.mkdir(parents=True, exist_ok=True)
    # BUGFIX: hf_hub_download(local_dir=...) preserves the repo-relative
    # subdirectory structure, so the cached file lives at
    # dataset_cache_dir/<video_path>, not dataset_cache_dir/<basename>.
    # The old basename check never matched, re-downloading every time.
    local_video_path = dataset_cache_dir / video_path
    if local_video_path.exists():
        return str(local_video_path)
    try:
        # local_dir_use_symlinks was dropped: it is deprecated and ignored
        # by huggingface_hub >= 0.23 (local_dir always gets real files now).
        downloaded_path = hf_hub_download(
            repo_id=dataset_repo,
            repo_type="dataset",
            filename=video_path,
            local_dir=str(dataset_cache_dir),
        )
        if Path(downloaded_path).exists():
            return downloaded_path
        return None
    except Exception as e:
        print(f"Error downloading: {e}")
        return None
def extract_frame(video_path, frame_num):
    """Extract a specific frame from video.

    Args:
        video_path: Path to a local video file (or None/"" for "no video").
        frame_num: Zero-based frame index; clamped into the valid range.

    Returns:
        Tuple (rgb_frame_or_None, frame_info_text, percent_text).
    """
    if not video_path or not os.path.exists(video_path):
        return None, "No video loaded", "0.0%"
    cap = cv2.VideoCapture(video_path)
    # Guard against unreadable/corrupt files (original skipped this check).
    if not cap.isOpened():
        cap.release()
        return None, "Error reading frame", "0.0%"
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Clamp into [0, total_frames - 1]; the original could seek to frame -1
    # on an empty video and never clamped negative requests.
    if total_frames > 0:
        frame_num = max(0, min(frame_num, total_frames - 1))
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
    ret, frame = cap.read()
    cap.release()
    if ret:
        # OpenCV decodes BGR; the UI (gr.Image) expects RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        percent = (frame_num / total_frames * 100) if total_frames > 0 else 0
        return frame_rgb, f"Frame {frame_num}/{total_frames-1}", f"{percent:.1f}%"
    return None, "Error reading frame", "0.0%"
# Global mutable state shared by the Gradio callbacks below.
current_trajectories = []  # trajectories loaded for the current session
current_idx = 0  # index of the trajectory currently shown
# One row per saved label; persisted to / restored from labels.csv.
labels_df = pd.DataFrame(columns=[
    "dataset_repo", "config_name", "trajectory_id", "is_robot", "quality_label",
    "task", "manual_end_frame", "manual_end_percent", "notes"
])
def load_labels():
    """Load existing labels from labels.csv into the global DataFrame.

    Backfills any schema columns missing from older CSVs (the original only
    backfilled quality_label) so downstream code can index every expected
    column without KeyErrors.
    """
    global labels_df
    if Path("labels.csv").exists():
        labels_df = pd.read_csv("labels.csv")
        # Older label files may predate newer columns.
        for col in ("dataset_repo", "config_name", "trajectory_id", "is_robot",
                    "quality_label", "task", "manual_end_frame",
                    "manual_end_percent", "notes"):
            if col not in labels_df.columns:
                labels_df[col] = ''
def save_labels():
    """Persist the in-memory label table, overwriting labels.csv."""
    # Read-only access to the module-level DataFrame; no index column on disk.
    labels_df.to_csv("labels.csv", index=False)
def load_dataset_trajectories(dataset_repo, config_name, quality_filter, num_human, num_robot):
    """Load and download trajectories from dataset.

    Samples human and robot trajectories, downloads each video, and returns
    a 7-tuple matching the load_btn.click outputs:
    (status text, frame-slider update, video path, task text, end-percent
    text, first frame image, trajectory-info text).
    """
    global current_trajectories, current_idx
    # Treat a blank config textbox as "no config".
    config = config_name.strip() if config_name else None
    try:
        human_trajs = sample_trajectories(dataset_repo, config, is_robot=False, quality_filter=quality_filter, num_samples=int(num_human))
        robot_trajs = sample_trajectories(dataset_repo, config, is_robot=True, quality_filter=quality_filter, num_samples=int(num_robot))
        all_trajs = human_trajs + robot_trajs
        if not all_trajs:
            return f"No {quality_filter.lower()} trajectories found", None, "No video", "", "0.0%", None, ""
        # Keep only trajectories whose video downloaded successfully, and tag
        # each one with its source repo/config so labels can be attributed.
        current_trajectories = []
        for traj in all_trajs:
            local_path = download_video(traj, dataset_repo, config)
            if local_path:
                traj["local_video_path"] = local_path
                traj["dataset_repo"] = dataset_repo
                traj["config_name"] = config
                current_trajectories.append(traj)
        current_idx = 0
        if current_trajectories:
            first_traj = current_trajectories[0]
            video_path = first_traj.get("local_video_path")
            task = first_traj.get("task", "No task description")
            is_robot_str = "Robot" if first_traj.get("is_robot") else "Human"
            quality = first_traj.get("quality_label", "Unknown")
            # Size the frame slider to the first video's frame count.
            cap = cv2.VideoCapture(video_path)
            max_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
            cap.release()
            traj_info = f"Trajectory 1/{len(current_trajectories)} | Type: {is_robot_str} | Quality: {quality}"
            frame, frame_text, percent = extract_frame(video_path, 0)
            return (
                f"✅ Loaded {len(current_trajectories)} {quality_filter.lower()} trajectories ({len(human_trajs)} human, {len(robot_trajs)} robot)",
                gr.update(maximum=max_frames, value=0),
                video_path,
                task,
                percent,
                frame,
                traj_info
            )
        return "No videos downloaded", None, None, "", "0.0%", None, ""
    except Exception as e:
        return f"❌ Error: {str(e)}", None, None, "", "0.0%", None, ""
def save_label(dataset_repo, config_name, end_frame, notes):
    """Save (or update) the end-frame label for the current trajectory.

    Args:
        dataset_repo: Dataset repo id the label belongs to.
        config_name: Config name ("" / None when the dataset has no config).
        end_frame: Chosen end-frame index (numeric from the UI).
        notes: Free-form annotator notes.

    Returns:
        A status string for the UI.
    """
    global current_trajectories, current_idx, labels_df
    if not current_trajectories or current_idx >= len(current_trajectories):
        return "No trajectory loaded"
    traj = current_trajectories[current_idx]
    video_path = traj.get("local_video_path")
    if not video_path:
        return "No video path"
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    end_percent = (int(end_frame) / total_frames * 100) if total_frames > 0 else 0
    # BUGFIX: an empty config_name is stored as "" but round-trips through
    # CSV as NaN, so a plain equality never matched existing rows and every
    # save appended a duplicate. Normalize with fillna("") before comparing.
    mask = (
        (labels_df['dataset_repo'] == dataset_repo) &
        (labels_df['config_name'].fillna("") == (config_name or "")) &
        (labels_df['trajectory_id'] == traj.get('id'))
    )
    if mask.any():
        # Update the existing row in place.
        idx = labels_df[mask].index[0]
        labels_df.at[idx, 'manual_end_frame'] = int(end_frame)
        labels_df.at[idx, 'manual_end_percent'] = end_percent
        labels_df.at[idx, 'notes'] = notes
        save_labels()
        return f"✅ Updated: Frame {int(end_frame)} ({end_percent:.1f}%)"
    new_row = pd.DataFrame([{
        "dataset_repo": dataset_repo,
        "config_name": config_name or "",
        "trajectory_id": traj.get('id'),
        "is_robot": traj.get('is_robot', False),
        "quality_label": traj.get('quality_label', ''),
        "task": traj.get('task', ''),
        "manual_end_frame": int(end_frame),
        "manual_end_percent": end_percent,
        "notes": notes
    }])
    labels_df = pd.concat([labels_df, new_row], ignore_index=True)
    save_labels()
    return f"✅ Saved: Frame {int(end_frame)} ({end_percent:.1f}%)"
def navigate_next():
    """Advance to the next loaded trajectory and show its first frame."""
    global current_idx
    last_idx = len(current_trajectories) - 1
    if not current_trajectories or current_idx >= last_idx:
        return "No more trajectories", None, "", "0.0%", None, ""
    current_idx += 1
    traj = current_trajectories[current_idx]
    video_path = traj.get("local_video_path")
    task = traj.get("task", "No task description")
    kind = "Robot" if traj.get("is_robot") else "Human"
    quality = traj.get("quality_label", "Unknown")
    # Re-size the frame slider for the newly selected video.
    cap = cv2.VideoCapture(video_path)
    last_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
    cap.release()
    info = (
        f"Trajectory {current_idx+1}/{len(current_trajectories)}"
        f" | Type: {kind} | Quality: {quality}"
    )
    frame, _, percent = extract_frame(video_path, 0)
    return gr.update(maximum=last_frame, value=0), video_path, task, percent, frame, info
def navigate_prev():
    """Step back to the previous loaded trajectory and show its first frame."""
    global current_idx
    if not current_trajectories or current_idx <= 0:
        return "No previous trajectories", None, "", "0.0%", None, ""
    current_idx -= 1
    traj = current_trajectories[current_idx]
    video_path = traj.get("local_video_path")
    task = traj.get("task", "No task description")
    kind = "Robot" if traj.get("is_robot") else "Human"
    quality = traj.get("quality_label", "Unknown")
    # Re-size the frame slider for the newly selected video.
    cap = cv2.VideoCapture(video_path)
    last_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
    cap.release()
    info = (
        f"Trajectory {current_idx+1}/{len(current_trajectories)}"
        f" | Type: {kind} | Quality: {quality}"
    )
    frame, _, percent = extract_frame(video_path, 0)
    return gr.update(maximum=last_frame, value=0), video_path, task, percent, frame, info
def analyze_patterns(dataset_repo, config_name):
    """Analyze end-point patterns for the current trajectory's type.

    Requires every loaded trajectory of that type (human or robot) to be
    labeled. Returns a dict of summary statistics, or an error payload of
    the form {"error": True, "message": ...}.
    """
    global current_trajectories, labels_df
    if not current_trajectories:
        # Keep the error shape consistent with the payloads below
        # (the original returned {"error": "<msg>"} here).
        return {"error": True, "message": "No trajectories loaded"}
    traj = current_trajectories[current_idx]
    is_robot = traj.get('is_robot', False)
    expected_count = len([t for t in current_trajectories if t.get('is_robot', False) == is_robot])
    # fillna("") so empty config names that round-tripped through CSV as
    # NaN still match "" (same normalization as save_label).
    filtered = labels_df[
        (labels_df['dataset_repo'] == dataset_repo) &
        (labels_df['config_name'].fillna("") == (config_name or "")) &
        (labels_df['is_robot'] == is_robot)
    ]
    labeled_count = len(filtered)
    if labeled_count < expected_count:
        return {
            "error": True,
            "message": f"Only {labeled_count}/{expected_count} {'robot' if is_robot else 'human'} trajectories labeled. Label all before analyzing.",
            "labeled_count": labeled_count,
            "expected_count": expected_count
        }
    if labeled_count < 3:
        return {
            "error": True,
            "message": "Need at least 3 labels",
            "labeled_count": labeled_count
        }
    percents = filtered['manual_end_percent'].values
    # Population standard deviation (ddof=0), computed once.
    std_percent = float(np.std(percents))
    result = {
        # bool() so the payload is JSON-serializable for gr.JSON
        # (numpy.bool_ is not accepted by the json encoder).
        "pattern_found": bool(std_percent < 10),
        "mean_percent": round(float(np.mean(percents)), 2),
        "median_percent": round(float(np.median(percents)), 2),
        "std_percent": round(std_percent, 2),
        "min_percent": round(float(np.min(percents)), 2),
        "max_percent": round(float(np.max(percents)), 2),
        "quantile_90": round(float(np.percentile(percents, 90)), 2),
        "count": labeled_count,
        "suggested_label": round(float(np.mean(percents)))
    }
    if result["pattern_found"]:
        result["message"] = f"✅ Pattern detected! Low variance ({result['std_percent']}%)"
    else:
        result["message"] = f"⚠️ High variance ({result['std_percent']}%) - no clear pattern"
    return result
# Load existing labels on startup so prior sessions' work is preserved.
load_labels()
def _end_frame_percent(video_path, end_frame):
    """Return end_frame's position as a percent string of the video length.

    Replaces the original inline lambda on end_frame_input.change, which
    opened the same video with two separate cv2.VideoCapture calls and
    released neither (a handle leak on every edit of the End Frame field).
    """
    if not video_path or not os.path.exists(video_path):
        return "0.0%"
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    if total_frames <= 0:
        return "0.0%"
    return f"{(int(end_frame) / total_frames * 100):.1f}%"


with gr.Blocks(title="Trajectory End Point Labeler") as demo:
    gr.Markdown("# Trajectory End Point Labeler")
    gr.Markdown("Label trajectory end points and analyze patterns")
    with gr.Row():
        # Left column: dataset selection and sampling controls.
        with gr.Column(scale=1):
            dataset_input = gr.Textbox(
                label="Dataset Repository",
                value="jesbu1/epic_rfm",
                placeholder="jesbu1/epic_rfm"
            )
            config_input = gr.Textbox(
                label="Config Name (optional)",
                placeholder="Leave empty if no config"
            )
            quality_filter = gr.Radio(
                choices=["All", "Success", "Failure"],
                value="All",
                label="Trajectory Quality Filter"
            )
            num_human = gr.Number(label="Human Samples", value=10, precision=0)
            num_robot = gr.Number(label="Robot Samples", value=10, precision=0)
            load_btn = gr.Button("Load Dataset", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        # Right column: playback, frame scrubbing, and labeling.
        with gr.Column(scale=2):
            traj_info = gr.Textbox(label="Current Trajectory", interactive=False)
            task_display = gr.Textbox(label="Task Description", interactive=False)
            with gr.Row():
                prev_btn = gr.Button("← Previous")
                next_btn = gr.Button("Next →")
            video_player = gr.Video(label="Trajectory Video")
            # The maximum is replaced per-video via gr.update once loaded.
            frame_slider = gr.Slider(minimum=0, maximum=63, step=1, value=0, label="Frame Number")
            frame_display = gr.Image(label="Current Frame")
            frame_info = gr.Textbox(label="Frame Info", interactive=False)
            with gr.Row():
                end_frame_input = gr.Number(label="End Frame", value=0, precision=0)
                end_percent = gr.Textbox(label="End Percent", interactive=False)
            notes_input = gr.Textbox(label="Notes (optional)", placeholder="Add notes...")
            save_btn = gr.Button("Save Label", variant="primary")
            save_status = gr.Textbox(label="Save Status", interactive=False)
    # Pattern analysis section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Pattern Analysis")
            gr.Markdown("Requires all trajectories of same type to be labeled")
            analyze_btn = gr.Button("Analyze Pattern", variant="secondary")
            pattern_output = gr.JSON(label="Pattern Metrics")
    # Load dataset
    load_btn.click(
        load_dataset_trajectories,
        inputs=[dataset_input, config_input, quality_filter, num_human, num_robot],
        outputs=[status, frame_slider, video_player, task_display, end_percent, frame_display, traj_info]
    )
    # Navigate between trajectories
    next_btn.click(
        navigate_next,
        outputs=[frame_slider, video_player, task_display, end_percent, frame_display, traj_info]
    )
    prev_btn.click(
        navigate_prev,
        outputs=[frame_slider, video_player, task_display, end_percent, frame_display, traj_info]
    )
    # Frame navigation: scrubbing the slider re-extracts the shown frame.
    frame_slider.change(
        extract_frame,
        inputs=[video_player, frame_slider],
        outputs=[frame_display, frame_info, end_percent]
    )
    # Show the first frame whenever a new video lands in the player.
    video_player.change(
        lambda v: extract_frame(v, 0) if v else (None, "No video", "0.0%"),
        inputs=[video_player],
        outputs=[frame_display, frame_info, end_percent]
    )
    # Update percent when end frame changes
    end_frame_input.change(
        _end_frame_percent,
        inputs=[video_player, end_frame_input],
        outputs=[end_percent]
    )
    # Save label
    save_btn.click(
        save_label,
        inputs=[dataset_input, config_input, end_frame_input, notes_input],
        outputs=[save_status]
    )
    # Pattern analysis
    analyze_btn.click(
        analyze_patterns,
        inputs=[dataset_input, config_input],
        outputs=[pattern_output]
    )

demo.launch()