# end_frame_vis / app.py
# Author: KaushikSid
# Step 6: Add success/failure trajectory filtering from metadata
# Commit: d28bd63
import gradio as gr
import cv2
import numpy as np
import pandas as pd
import random
import os
import shutil
from pathlib import Path
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from tqdm import tqdm
# Step 6: Add success/failure filtering
def sample_trajectories(dataset_repo, config_name, is_robot, quality_filter, num_samples, max_to_check=10000):
    """Stream a HuggingFace dataset and return up to num_samples random trajectories.

    Scans at most max_to_check streamed samples, keeping those whose
    "is_robot" flag matches and that pass the quality filter:
      - "All": no quality restriction.
      - "Success": drop samples whose quality_label lacks "success" unless
        partial_success is >= 1.
      - "Failure": drop samples whose quality_label contains "success" or
        whose partial_success is >= 1.
    Returns a shuffled list (at most num_samples long); returns [] on any
    error or when nothing matches.
    """
    def _keep(sample):
        # Predicate applied to every streamed sample.
        if sample.get("is_robot", False) != is_robot:
            return False
        if quality_filter == "All":
            return True
        label = sample.get("quality_label", "")
        partial = sample.get("partial_success", None)
        if quality_filter == "Success":
            if label and "success" not in label.lower():
                # Label says non-success; only keep if partial_success confirms.
                return partial is not None and partial >= 1
            return True
        if quality_filter == "Failure":
            if label and "success" in label.lower():
                return False
            return partial is None or partial < 1
        # Unknown filter value: keep everything (mirrors the original behavior).
        return True

    try:
        load_args = [dataset_repo, config_name] if config_name else [dataset_repo]
        stream = load_dataset(*load_args, split="train", streaming=True)
        pool = []
        for idx, sample in enumerate(stream):
            if idx >= max_to_check:
                break
            if _keep(sample):
                pool.append(sample)
        if not pool:
            return []
        if len(pool) > num_samples:
            return random.sample(pool, num_samples)
        random.shuffle(pool)
        return pool
    except Exception as e:
        print(f"Error sampling: {e}")
        return []
def download_video(trajectory, dataset_repo, config_name=None):
    """Fetch the trajectory's video from the HF Hub, caching it under video_cache/.

    Returns the local file path as a string, or None when the trajectory has
    no "frames" entry or the download fails.
    """
    remote_path = trajectory.get("frames")
    if not remote_path:
        return None
    # One cache subdirectory per (repo, config) pair; slashes sanitized so the
    # repo id is a safe directory name.
    safe_key = f"{dataset_repo}_{config_name}" if config_name else dataset_repo
    safe_key = safe_key.replace("/", "_").replace("\\", "_")
    target_dir = Path("video_cache") / safe_key
    target_dir.mkdir(parents=True, exist_ok=True)
    cached = target_dir / Path(remote_path).name
    if cached.exists():
        return str(cached)
    try:
        fetched = hf_hub_download(
            repo_id=dataset_repo,
            repo_type="dataset",
            filename=remote_path,
            local_dir=str(target_dir),
            local_dir_use_symlinks=False,
        )
        # NOTE(review): hf_hub_download with local_dir preserves the repo's
        # subdirectory layout, so `cached` (flat name) may not equal `fetched`
        # — the cache-hit check above can miss; confirm against the repo layout.
        return fetched if Path(fetched).exists() else None
    except Exception as e:
        print(f"Error downloading: {e}")
        return None
def extract_frame(video_path, frame_num):
    """Extract a specific frame from a video file.

    Args:
        video_path: path to a local video file (may be None/empty).
        frame_num: frame index to extract; floats (e.g. from a Gradio slider)
            are truncated to int and the value is clamped into range.

    Returns:
        (frame_rgb, info_text, percent_text) where frame_rgb is an RGB numpy
        array or None on failure.
    """
    if not video_path or not os.path.exists(video_path):
        return None, "No video loaded", "0.0%"
    cap = cv2.VideoCapture(video_path)
    try:
        # Guard against files OpenCV cannot open (original code would compute
        # frame_num = -1 here and silently fail on read).
        if not cap.isOpened():
            return None, "Error reading frame", "0.0%"
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return None, "Error reading frame", "0.0%"
        # Gradio sliders can deliver floats; clamp into [0, total_frames - 1].
        frame_num = max(0, min(int(frame_num), total_frames - 1))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
    finally:
        # Always release the capture handle, even if a cv2 call raises.
        cap.release()
    if ret:
        # OpenCV decodes as BGR; the UI expects RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        percent = frame_num / total_frames * 100
        return frame_rgb, f"Frame {frame_num}/{total_frames-1}", f"{percent:.1f}%"
    return None, "Error reading frame", "0.0%"
# ---- Global labeling-session state ----
current_trajectories = []  # trajectories loaded for the active dataset
current_idx = 0  # index of the trajectory currently displayed

# In-memory label table; mirrored to labels.csv on every save.
_LABEL_COLUMNS = [
    "dataset_repo", "config_name", "trajectory_id", "is_robot", "quality_label",
    "task", "manual_end_frame", "manual_end_percent", "notes",
]
labels_df = pd.DataFrame(columns=_LABEL_COLUMNS)
def load_labels():
    """Populate the global labels_df from labels.csv, if one exists."""
    global labels_df
    csv_path = Path("labels.csv")
    if not csv_path.exists():
        return
    labels_df = pd.read_csv(csv_path)
    # Older label files predate the quality_label column; backfill it so
    # downstream filters can rely on the column existing.
    if 'quality_label' not in labels_df.columns:
        labels_df['quality_label'] = ''
def save_labels():
    """Persist the in-memory label table to labels.csv (index column omitted)."""
    global labels_df
    labels_df.to_csv("labels.csv", index=False)
def load_dataset_trajectories(dataset_repo, config_name, quality_filter, num_human, num_robot):
    """Load and download trajectories from dataset.

    Returns a 7-tuple matching the Gradio outputs wired to load_btn.click:
    (status text, frame_slider update, video path, task text, percent text,
    frame image, trajectory info text).
    """
    global current_trajectories, current_idx
    config = config_name.strip() if config_name else None
    try:
        # Sample the two populations independently so each honors its own count.
        human_trajs = sample_trajectories(dataset_repo, config, is_robot=False, quality_filter=quality_filter, num_samples=int(num_human))
        robot_trajs = sample_trajectories(dataset_repo, config, is_robot=True, quality_filter=quality_filter, num_samples=int(num_robot))
        all_trajs = human_trajs + robot_trajs
        if not all_trajs:
            return f"No {quality_filter.lower()} trajectories found", None, "No video", "", "0.0%", None, ""
        # Keep only trajectories whose video actually downloaded; annotate each
        # with its local path and provenance for later label saving.
        current_trajectories = []
        for traj in all_trajs:
            local_path = download_video(traj, dataset_repo, config)
            if local_path:
                traj["local_video_path"] = local_path
                traj["dataset_repo"] = dataset_repo
                traj["config_name"] = config
                current_trajectories.append(traj)
        current_idx = 0
        if current_trajectories:
            # Prime the UI with the first trajectory's metadata and frame 0.
            first_traj = current_trajectories[0]
            video_path = first_traj.get("local_video_path")
            task = first_traj.get("task", "No task description")
            is_robot_str = "Robot" if first_traj.get("is_robot") else "Human"
            quality = first_traj.get("quality_label", "Unknown")
            cap = cv2.VideoCapture(video_path)
            # Slider maximum is the last valid frame index (count - 1).
            max_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
            cap.release()
            traj_info = f"Trajectory 1/{len(current_trajectories)} | Type: {is_robot_str} | Quality: {quality}"
            frame, frame_text, percent = extract_frame(video_path, 0)
            return (
                f"✅ Loaded {len(current_trajectories)} {quality_filter.lower()} trajectories ({len(human_trajs)} human, {len(robot_trajs)} robot)",
                gr.update(maximum=max_frames, value=0),
                video_path,
                task,
                percent,
                frame,
                traj_info
            )
        return "No videos downloaded", None, None, "", "0.0%", None, ""
    except Exception as e:
        # Broad catch keeps the UI responsive; the message surfaces in Status.
        return f"❌ Error: {str(e)}", None, None, "", "0.0%", None, ""
def save_label(dataset_repo, config_name, end_frame, notes):
    """Save (or update) the end-frame label for the current trajectory.

    Updates an existing row matching (dataset_repo, config_name, trajectory id)
    in labels_df, otherwise appends a new one, then writes labels.csv.
    Returns a status string for the UI.
    """
    global current_trajectories, current_idx, labels_df
    if not current_trajectories or current_idx >= len(current_trajectories):
        return "No trajectory loaded"
    # gr.Number can deliver None when the field is cleared; guard before int().
    if end_frame is None:
        return "No end frame set"
    traj = current_trajectories[current_idx]
    video_path = traj.get("local_video_path")
    if not video_path:
        return "No video path"
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    end_frame = int(end_frame)
    end_percent = (end_frame / total_frames * 100) if total_frames > 0 else 0
    # Check whether this trajectory is already labeled. fillna("") guards
    # against NaN: pandas reads empty config_name cells back from labels.csv
    # as NaN, which never compares equal to "" and would create duplicates.
    mask = (
        (labels_df['dataset_repo'] == dataset_repo) &
        (labels_df['config_name'].fillna("") == (config_name or "")) &
        (labels_df['trajectory_id'] == traj.get('id'))
    )
    if mask.any():
        idx = labels_df[mask].index[0]
        labels_df.at[idx, 'manual_end_frame'] = end_frame
        labels_df.at[idx, 'manual_end_percent'] = end_percent
        labels_df.at[idx, 'notes'] = notes
        save_labels()
        return f"✅ Updated: Frame {end_frame} ({end_percent:.1f}%)"
    new_row = pd.DataFrame([{
        "dataset_repo": dataset_repo,
        "config_name": config_name or "",
        "trajectory_id": traj.get('id'),
        "is_robot": traj.get('is_robot', False),
        "quality_label": traj.get('quality_label', ''),
        "task": traj.get('task', ''),
        "manual_end_frame": end_frame,
        "manual_end_percent": end_percent,
        "notes": notes
    }])
    labels_df = pd.concat([labels_df, new_row], ignore_index=True)
    save_labels()
    return f"✅ Saved: Frame {end_frame} ({end_percent:.1f}%)"
def navigate_next():
    """Advance to the next trajectory and return its display state.

    Returns a 6-tuple matching the Gradio outputs wired to next_btn.click:
    (frame_slider update, video path, task text, percent text, frame image,
    trajectory info text).
    """
    global current_idx
    if not current_trajectories or current_idx >= len(current_trajectories) - 1:
        # Leave the slider untouched and surface the message in the info box;
        # the original returned the message string as the slider value, which
        # hands a string to a gr.Slider component.
        return gr.update(), None, "", "0.0%", None, "No more trajectories"
    current_idx += 1
    traj = current_trajectories[current_idx]
    video_path = traj.get("local_video_path")
    task = traj.get("task", "No task description")
    is_robot_str = "Robot" if traj.get("is_robot") else "Human"
    quality = traj.get("quality_label", "Unknown")
    cap = cv2.VideoCapture(video_path)
    # Clamp so an unreadable video (frame count 0) cannot yield maximum=-1.
    max_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1, 0)
    cap.release()
    traj_info = f"Trajectory {current_idx+1}/{len(current_trajectories)} | Type: {is_robot_str} | Quality: {quality}"
    frame, frame_text, percent = extract_frame(video_path, 0)
    return gr.update(maximum=max_frames, value=0), video_path, task, percent, frame, traj_info
def navigate_prev():
    """Step back to the previous trajectory and return its display state.

    Returns a 6-tuple matching the Gradio outputs wired to prev_btn.click:
    (frame_slider update, video path, task text, percent text, frame image,
    trajectory info text).
    """
    global current_idx
    if not current_trajectories or current_idx <= 0:
        # Leave the slider untouched and surface the message in the info box;
        # the original returned the message string as the slider value, which
        # hands a string to a gr.Slider component.
        return gr.update(), None, "", "0.0%", None, "No previous trajectories"
    current_idx -= 1
    traj = current_trajectories[current_idx]
    video_path = traj.get("local_video_path")
    task = traj.get("task", "No task description")
    is_robot_str = "Robot" if traj.get("is_robot") else "Human"
    quality = traj.get("quality_label", "Unknown")
    cap = cv2.VideoCapture(video_path)
    # Clamp so an unreadable video (frame count 0) cannot yield maximum=-1.
    max_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1, 0)
    cap.release()
    traj_info = f"Trajectory {current_idx+1}/{len(current_trajectories)} | Type: {is_robot_str} | Quality: {quality}"
    frame, frame_text, percent = extract_frame(video_path, 0)
    return gr.update(maximum=max_frames, value=0), video_path, task, percent, frame, traj_info
def analyze_patterns(dataset_repo, config_name):
    """Analyze end-frame labeling patterns for the current trajectory type.

    Requires every loaded trajectory of the same type (robot/human) as the
    currently displayed one to be labeled. Returns a dict of summary stats,
    or {"error": True, "message": ...} describing what is missing.
    """
    global current_trajectories, labels_df
    if not current_trajectories:
        # Same error shape as the other failure returns so the JSON output
        # is consistent (the original returned {"error": "..."} here).
        return {"error": True, "message": "No trajectories loaded"}
    traj = current_trajectories[current_idx]
    is_robot = traj.get('is_robot', False)
    expected_count = sum(1 for t in current_trajectories if t.get('is_robot', False) == is_robot)
    # fillna(""): empty config_name cells come back as NaN after labels.csv
    # is reloaded, and NaN never compares equal to "".
    filtered = labels_df[
        (labels_df['dataset_repo'] == dataset_repo) &
        (labels_df['config_name'].fillna("") == (config_name or "")) &
        (labels_df['is_robot'] == is_robot)
    ]
    labeled_count = len(filtered)
    if labeled_count < expected_count:
        return {
            "error": True,
            "message": f"Only {labeled_count}/{expected_count} {'robot' if is_robot else 'human'} trajectories labeled. Label all before analyzing.",
            "labeled_count": labeled_count,
            "expected_count": expected_count
        }
    if labeled_count < 3:
        return {
            "error": True,
            "message": "Need at least 3 labels",
            "labeled_count": labeled_count
        }
    percents = filtered['manual_end_percent'].values
    result = {
        # "Pattern" means the end percents cluster tightly (population std < 10).
        "pattern_found": np.std(percents) < 10,
        "mean_percent": round(float(np.mean(percents)), 2),
        "median_percent": round(float(np.median(percents)), 2),
        "std_percent": round(float(np.std(percents)), 2),
        "min_percent": round(float(np.min(percents)), 2),
        "max_percent": round(float(np.max(percents)), 2),
        "quantile_90": round(float(np.percentile(percents, 90)), 2),
        "count": labeled_count,
        "suggested_label": round(float(np.mean(percents)))
    }
    if result["pattern_found"]:
        result["message"] = f"✅ Pattern detected! Low variance ({result['std_percent']}%)"
    else:
        result["message"] = f"⚠️ High variance ({result['std_percent']}%) - no clear pattern"
    return result
# Load existing labels on startup
load_labels()

# ---------------------------------------------------------------------------
# Gradio UI: dataset-loading controls on the left, trajectory viewer and
# labeling controls on the right, pattern analysis underneath.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Trajectory End Point Labeler") as demo:
    gr.Markdown("# Trajectory End Point Labeler")
    gr.Markdown("Label trajectory end points and analyze patterns")
    with gr.Row():
        with gr.Column(scale=1):
            # Dataset selection and sampling controls.
            dataset_input = gr.Textbox(
                label="Dataset Repository",
                value="jesbu1/epic_rfm",
                placeholder="jesbu1/epic_rfm"
            )
            config_input = gr.Textbox(
                label="Config Name (optional)",
                placeholder="Leave empty if no config"
            )
            # Step 6: quality filter applied while sampling trajectories.
            quality_filter = gr.Radio(
                choices=["All", "Success", "Failure"],
                value="All",
                label="Trajectory Quality Filter"
            )
            num_human = gr.Number(label="Human Samples", value=10, precision=0)
            num_robot = gr.Number(label="Robot Samples", value=10, precision=0)
            load_btn = gr.Button("Load Dataset", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=2):
            # Viewer for the currently selected trajectory.
            traj_info = gr.Textbox(label="Current Trajectory", interactive=False)
            task_display = gr.Textbox(label="Task Description", interactive=False)
            with gr.Row():
                prev_btn = gr.Button("← Previous")
                next_btn = gr.Button("Next →")
            video_player = gr.Video(label="Trajectory Video")
            # NOTE(review): maximum=63 is a placeholder; the load/navigate
            # handlers resize it per-video via gr.update(maximum=...).
            frame_slider = gr.Slider(minimum=0, maximum=63, step=1, value=0, label="Frame Number")
            frame_display = gr.Image(label="Current Frame")
            frame_info = gr.Textbox(label="Frame Info", interactive=False)
            with gr.Row():
                end_frame_input = gr.Number(label="End Frame", value=0, precision=0)
                end_percent = gr.Textbox(label="End Percent", interactive=False)
            notes_input = gr.Textbox(label="Notes (optional)", placeholder="Add notes...")
            save_btn = gr.Button("Save Label", variant="primary")
            save_status = gr.Textbox(label="Save Status", interactive=False)
    # Pattern analysis section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Pattern Analysis")
            gr.Markdown("Requires all trajectories of same type to be labeled")
            analyze_btn = gr.Button("Analyze Pattern", variant="secondary")
            pattern_output = gr.JSON(label="Pattern Metrics")
    # Load dataset
    load_btn.click(
        load_dataset_trajectories,
        inputs=[dataset_input, config_input, quality_filter, num_human, num_robot],
        outputs=[status, frame_slider, video_player, task_display, end_percent, frame_display, traj_info]
    )
    # Navigate
    next_btn.click(
        navigate_next,
        outputs=[frame_slider, video_player, task_display, end_percent, frame_display, traj_info]
    )
    prev_btn.click(
        navigate_prev,
        outputs=[frame_slider, video_player, task_display, end_percent, frame_display, traj_info]
    )
    # Frame navigation
    frame_slider.change(
        extract_frame,
        inputs=[video_player, frame_slider],
        outputs=[frame_display, frame_info, end_percent]
    )
    video_player.change(
        lambda v: extract_frame(v, 0) if v else (None, "No video", "0.0%"),
        inputs=[video_player],
        outputs=[frame_display, frame_info, end_percent]
    )
    # Update percent when end frame changes
    # NOTE(review): this lambda opens the video twice via cv2.VideoCapture and
    # never releases either handle; consider a small named helper that opens
    # the file once, reads the frame count, and calls cap.release().
    end_frame_input.change(
        lambda v, f: (None, "No video", "0.0%")[2] if not v else f"{(int(f) / int(cv2.VideoCapture(v).get(cv2.CAP_PROP_FRAME_COUNT)) * 100):.1f}%" if os.path.exists(v) and int(cv2.VideoCapture(v).get(cv2.CAP_PROP_FRAME_COUNT)) > 0 else "0.0%",
        inputs=[video_player, end_frame_input],
        outputs=[end_percent]
    )
    # Save label
    save_btn.click(
        save_label,
        inputs=[dataset_input, config_input, end_frame_input, notes_input],
        outputs=[save_status]
    )
    # Pattern analysis
    analyze_btn.click(
        analyze_patterns,
        inputs=[dataset_input, config_input],
        outputs=[pattern_output]
    )
demo.launch()