|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import os |
|
|
from pathlib import Path |
|
|
from datasets import load_dataset, get_dataset_config_names |
|
|
|
|
|
|
|
|
# Optional: mirror saved labels to a shared HF dataset. The app degrades to
# local-CSV-only persistence when the helper module or its config is absent.
try:
    from hf_dataset_sync import init_dataset_sync
    dataset_sync_enabled = init_dataset_sync()
except Exception as e:
    dataset_sync_enabled = False
    print(f"⚠️ Dataset sync disabled: {e}")
|
|
|
|
|
|
|
|
# HF dataset repos offered in the dataset dropdown. The dropdown also
# accepts custom values, so this list is a convenience, not a restriction.
PREDEFINED_DATASETS = [
    "abraranwar/agibotworld_alpha_rfm",
    "abraranwar/libero_rfm",
    "abraranwar/usc_koch_rewind_rfm",
    "aliangdw/metaworld",
    "anqil/rh20t_rfm",
    "anqil/rh20t_subset_rfm",
    "jesbu1/auto_eval_rfm",
    "jesbu1/egodex_rfm",
    "jesbu1/epic_rfm",
    "jesbu1/fino_net_rfm",
    "jesbu1/failsafe_rfm",
    "jesbu1/hand_paired_rfm",
    "jesbu1/galaxea_rfm",
    "jesbu1/h2r_rfm",
    "jesbu1/humanoid_everyday_rfm",
    "jesbu1/molmoact_rfm",
    "jesbu1/motif_rfm",
    "jesbu1/oxe_rfm",
    "jesbu1/oxe_rfm_eval",
    "jesbu1/ph2d_rfm",
    "jesbu1/racer_rfm",
    "jesbu1/roboarena_0825_rfm",
    "jesbu1/soar_rfm",
    "ykorkmaz/libero_failure_rfm",
    "aliangdw/usc_xarm_policy_ranking",
    "aliangdw/usc_franka_policy_ranking",
    "aliangdw/utd_so101_policy_ranking",
    "aliangdw/utd_so101_human",
]
|
|
|
|
|
|
|
|
# In-memory session state: the currently loaded window of trajectory samples
# and the cursor into it (mutated by the navigation/save handlers).
current_trajectories = []
current_idx = 0
# All labels saved so far; persisted via save_evaluations()/load_evaluations().
evaluations_df = pd.DataFrame(columns=[
    "dataset_repo", "config_name", "trajectory_id", "task",
    "decision", "issue_type", "notes", "timestamp"
])
|
|
|
|
|
|
|
|
def load_evaluations():
    """Load existing evaluations into the global evaluations_df.

    Tries the shared HF dataset first (when sync is enabled); on any failure
    falls back to the local CSV. Leaves evaluations_df untouched when neither
    source is available.
    """
    global evaluations_df

    if dataset_sync_enabled:
        try:
            from huggingface_hub import hf_hub_download
            DATASET_REPO = os.getenv("EVAL_DATASET_REPO")
            HF_TOKEN = os.getenv("HF_TOKEN")
            # force_download so a stale cached copy never shadows newer labels
            csv_file = hf_hub_download(
                DATASET_REPO,
                "traj_evaluations.csv",
                repo_type="dataset",
                token=HF_TOKEN,
                force_download=True
            )
            # keep_default_na=False keeps empty cells as '' rather than NaN
            evaluations_df = pd.read_csv(csv_file, keep_default_na=False, na_values=[''])
            # Normalize stringified NaN/None leftovers to empty string
            evaluations_df = evaluations_df.replace(['nan', 'NaN', 'None'], '')
            unique_issues = evaluations_df['issue_type'].unique()
            print(f"📊 Loaded {len(evaluations_df)} evaluations from shared dataset")
            print(f"🔍 Unique issue_type values: {unique_issues}")
            return
        except Exception as e:
            print(f"⚠️ Could not load from shared dataset: {e}")

    # Fallback: local CSV. On HF Spaces (SPACE_ID set) use the data/ dir,
    # which is the writable/persistent location used by save_evaluations().
    csv_path = Path("data/evaluations.csv") if os.getenv("SPACE_ID") else Path("evaluations.csv")
    if csv_path.exists():
        evaluations_df = pd.read_csv(csv_path, keep_default_na=False, na_values=[''])
        evaluations_df = evaluations_df.replace(['nan', 'NaN', 'None'], '')
        print(f"📊 Loaded {len(evaluations_df)} evaluations from local CSV")
|
|
|
|
|
|
|
|
def save_evaluations():
    """Persist evaluations_df to CSV (data/ dir when running on a HF Space)."""
    # Blank out NaNs so the CSV round-trips cleanly with load_evaluations().
    snapshot = evaluations_df.fillna("")
    if os.getenv("SPACE_ID"):
        os.makedirs("data", exist_ok=True)
        target = "data/evaluations.csv"
    else:
        target = "evaluations.csv"
    snapshot.to_csv(target, index=False)
|
|
|
|
|
|
|
|
def get_stats():
    """Return a one-line summary of label counts across all evaluations.

    Returns "No labels yet" when nothing has been labeled.
    """
    total = len(evaluations_df)
    if total == 0:
        return "No labels yet"
    # Single pass over the column instead of three separate boolean filters.
    counts = evaluations_df['decision'].value_counts()
    keeps = int(counts.get('keep', 0))
    removes = int(counts.get('remove', 0))
    reviews = int(counts.get('review', 0))
    return f"Total: {total} | ✅ {keeps} | ❌ {removes} | 🔍 {reviews}"
|
|
|
|
|
|
|
|
def fetch_configs(dataset_repo):
    """Fetch configs dynamically (fast API call)."""
    # Common tail of every return: cleared analysis, default 0-20 range.
    reset = ("", 0, 20)
    if not dataset_repo:
        return (gr.update(choices=[], value=""),) + reset

    configs = None
    try:
        configs = get_dataset_config_names(dataset_repo)
    except Exception as e:
        print(f"Config fetch error: {e}")

    if configs:
        dropdown = gr.update(choices=configs, value=configs[0])
    else:
        # No named configs (or the lookup failed): offer a "default" stub.
        dropdown = gr.update(choices=["default"], value="default")
    return (dropdown,) + reset
|
|
|
|
|
|
|
|
def analyze_dataset_progress(dataset_repo, config_name):
    """Analyze labeling progress for selected dataset and suggest range.

    Returns (analysis_markdown, suggested_start, suggested_end) for the
    setup view. Streams at most ~1000 samples when scanning for unlabeled
    gaps, so counts and suggestions may be approximate for large datasets.
    """
    if not dataset_repo:
        return "", 0, 20

    config = config_name if config_name and config_name != "default" else None

    # Labels already saved for this dataset/config pair. config_name is
    # stored as '' (not None) in the CSV, hence the fallback comparison.
    dataset_evals = evaluations_df[
        (evaluations_df['dataset_repo'] == dataset_repo) &
        (evaluations_df['config_name'] == (config if config else ''))
    ]

    # --- Determine dataset size for the "nothing labeled yet" message ---
    try:
        from datasets import get_dataset_infos
        try:
            infos = get_dataset_infos(dataset_repo)
            config_key = config if config else list(infos.keys())[0]
            if config_key in infos and infos[config_key].splits.get('train'):
                dataset_size = infos[config_key].splits['train'].num_examples
            else:
                raise Exception("No info available")
        except Exception:
            # No split metadata: fall back to streaming and counting,
            # capped at 10k samples (reported as approximate).
            ds = load_dataset(dataset_repo, config, split="train", streaming=True)
            dataset_size = 0
            for i, _ in enumerate(ds):
                dataset_size = i + 1
                if i >= 9999:
                    dataset_size = f"~{dataset_size}"
                    break
    except Exception as e:
        dataset_size = "Unknown"

    if len(dataset_evals) == 0:
        return f"📊 **No trajectories labeled yet for this dataset**\n\n**Dataset size:** {dataset_size} trajectories", 0, 20

    labeled_ids = set(dataset_evals['trajectory_id'].unique())

    # Per-decision counts for the progress breakdown (labeled_ids is
    # non-empty here, so the percentage divisions below are safe).
    keeps = len(dataset_evals[dataset_evals['decision'] == 'keep'])
    removes = len(dataset_evals[dataset_evals['decision'] == 'remove'])
    reviews = len(dataset_evals[dataset_evals['decision'] == 'review'])

    # --- Re-fetch the authoritative size; None means "unknown" below ---
    try:
        from datasets import get_dataset_infos
        try:
            infos = get_dataset_infos(dataset_repo)
            config_key = config if config else list(infos.keys())[0]
            if config_key in infos and infos[config_key].splits.get('train'):
                dataset_size = infos[config_key].splits['train'].num_examples
            else:
                dataset_size = None
        except Exception:
            dataset_size = None
    except Exception:
        dataset_size = None

    # --- Stream the dataset to locate unlabeled index gaps ---
    try:
        ds = load_dataset(dataset_repo, config, split="train", streaming=True)

        size_known = dataset_size is not None
        checked_count = 0
        unlabeled_ranges = []          # (start, end) inclusive index spans
        current_unlabeled_start = None  # open gap being tracked, if any

        for i, sample in enumerate(ds):
            traj_id = sample.get("id")
            if traj_id not in labeled_ids:
                if current_unlabeled_start is None:
                    current_unlabeled_start = i
            else:
                if current_unlabeled_start is not None:
                    unlabeled_ranges.append((current_unlabeled_start, i - 1))
                    current_unlabeled_start = None

            checked_count = i + 1
            # BUGFIX: the previous guard `if dataset_size is None` only ran
            # while the size was still None, freezing the running count at 1
            # after the first sample. Track the count for as long as the
            # true size is unknown, matching the sizing loop above.
            if not size_known:
                dataset_size = checked_count
            if checked_count >= 1000:
                if not size_known:
                    dataset_size = f"~{checked_count}"
                break

        # Close a gap that runs to the end of the scanned region.
        if current_unlabeled_start is not None:
            unlabeled_ranges.append((current_unlabeled_start, checked_count - 1))

        # Suggest the first gap of >= 10 trajectories; failing that, the
        # first gap of any size; failing that, continue past what was scanned.
        if unlabeled_ranges:
            for gap_start, gap_end in unlabeled_ranges:
                if gap_end - gap_start >= 10:
                    suggested_start = gap_start
                    suggested_end = min(gap_start + 20, gap_end)
                    break
            else:
                gap_start, gap_end = unlabeled_ranges[0]
                suggested_start = gap_start
                suggested_end = min(gap_start + 20, gap_end)
        else:
            suggested_start = checked_count
            suggested_end = checked_count + 20

        analysis = f"""📊 **Dataset Progress: {dataset_repo}** {'(' + config + ')' if config else ''}

**Labeled:** {len(labeled_ids)} trajectories
- ✅ Keep: {keeps} ({keeps/len(labeled_ids)*100:.1f}%)
- ❌ Remove: {removes} ({removes/len(labeled_ids)*100:.1f}%)
- 🔍 Review: {reviews} ({reviews/len(labeled_ids)*100:.1f}%)

**Dataset size:** {dataset_size} trajectories (checked: {checked_count})
"""

        if unlabeled_ranges:
            # Show at most the first three gaps.
            gaps = ", ".join([f"{s}-{e}" for s, e in unlabeled_ranges[:3]])
            analysis += f"\n🎯 **Unlabeled gaps:** {gaps}"

        return analysis, suggested_start, suggested_end

    except Exception as e:
        # Streaming failed entirely: suggest continuing after the labeled
        # count and report whatever size metadata is still reachable.
        suggested_start = len(labeled_ids)
        suggested_end = len(labeled_ids) + 20

        try:
            from datasets import get_dataset_infos
            infos = get_dataset_infos(dataset_repo)
            config_key = config if config else list(infos.keys())[0]
            if config_key in infos and infos[config_key].splits.get('train'):
                ds_size = infos[config_key].splits['train'].num_examples
                size_info = f"**Dataset size:** {ds_size} trajectories\n\n"
            else:
                size_info = ""
        except Exception:
            size_info = ""

        return f"""📊 **Dataset Progress: {dataset_repo}**

**Labeled:** {len(labeled_ids)} trajectories
- ✅ Keep: {keeps} ({keeps/len(labeled_ids)*100:.1f}%)
- ❌ Remove: {removes} ({removes/len(labeled_ids)*100:.1f}%)
- 🔍 Review: {reviews} ({reviews/len(labeled_ids)*100:.1f}%)

{size_info}⚠️ Could not analyze dataset structure: {str(e)[:50]}
""", suggested_start, suggested_end
|
|
|
|
|
|
|
|
def get_video_url(dataset_repo, video_path):
    """Get direct HuggingFace URL for video (no download needed)."""
    base = "https://huggingface.co/datasets"
    return f"{base}/{dataset_repo}/resolve/main/{video_path}"
|
|
|
|
|
|
|
|
def load_trajectories(dataset_repo, config_name, start_idx, end_idx, traj_id):
    """Load trajectories by range or specific ID.

    Streams the dataset and collects either the single sample whose "id"
    matches traj_id, or all samples in the inclusive index window
    [start_idx, end_idx]. Populates the global current_trajectories and
    returns the 11-element widget-update tuple wired to the load button.
    """
    global current_trajectories, current_idx

    # Refresh labels first so previously saved decisions pre-fill the UI.
    load_evaluations()

    if not dataset_repo:
        return (gr.update(visible=True), gr.update(visible=False),
                None, "Select a dataset",
                gr.update(value=None), gr.update(visible=False), gr.update(value=None), gr.update(value=""), "",
                evaluations_df.tail(10), "⚠️ Select a dataset")

    config = config_name if config_name and config_name != "default" else None
    start = int(start_idx) if start_idx else 0
    end = int(end_idx) if end_idx else start + 20
    target_id = traj_id.strip() if traj_id else None

    try:
        # Streaming avoids downloading the whole dataset up front.
        ds = load_dataset(dataset_repo, config, split="train", streaming=True)

        current_trajectories = []
        for i, sample in enumerate(ds):
            # ID mode: scan until the matching trajectory is found.
            if target_id:
                if sample.get("id") == target_id:
                    video_path = sample.get("frames")
                    if video_path:
                        # video_url is only attached when "frames" is present
                        # — TODO confirm samples without video should still load
                        sample["video_url"] = get_video_url(dataset_repo, video_path)
                    sample["dataset_repo"] = dataset_repo
                    sample["config_name"] = config
                    current_trajectories.append(sample)
                    break
                continue

            # Range mode: skip until start; stop after end (end is inclusive).
            if i < start:
                continue
            if i > end:
                break

            video_path = sample.get("frames")
            if video_path:
                sample["video_url"] = get_video_url(dataset_repo, video_path)
            sample["dataset_repo"] = dataset_repo
            sample["config_name"] = config
            current_trajectories.append(sample)

        current_idx = 0

        if not current_trajectories:
            return (gr.update(visible=True), gr.update(visible=False),
                    None, "❌ No trajectories found",
                    gr.update(value=None), gr.update(visible=False), gr.update(value=None), gr.update(value=""), "",
                    evaluations_df.tail(10), "❌ No trajectories found")

        # Success: switch to the labeling view showing the first trajectory.
        return show_labeling_view()

    except Exception as e:
        return (gr.update(visible=True), gr.update(visible=False),
                None, f"❌ {str(e)[:50]}",
                gr.update(value=None), gr.update(visible=False), gr.update(value=None), gr.update(value=""), "",
                evaluations_df.tail(10), f"❌ Error: {str(e)[:50]}")
|
|
|
|
|
|
|
|
def get_trajectory_metadata(traj):
    """Summarize a trajectory's metadata as a ' | '-joined tag string.

    Inspects outcome fields (quality_label / success / is_success),
    suboptimality flags, and embodiment hints (is_robot / source).
    Returns "" when nothing recognizable is present.
    """
    tags = []

    # --- Outcome: quality_label wins over the boolean success fields ---
    quality = traj.get('quality_label')
    if quality:
        lowered = str(quality).lower()
        if 'success' in lowered:
            tags.append("✅ Success")
        elif 'fail' in lowered:
            tags.append("❌ Failure")
        elif 'suboptimal' in lowered:
            tags.append("⚠️ Suboptimal")
        else:
            tags.append(f"Quality: {quality}")
    elif 'success' in traj:
        outcome = traj['success']
        if outcome in (True, 1, 1.0):
            tags.append("✅ Success")
        elif outcome in (False, 0, 0.0):
            tags.append("❌ Failure")
        else:
            tags.append(f"Status: {outcome}")
    elif 'is_success' in traj:
        tags.append("✅ Success" if traj['is_success'] else "❌ Failure")

    # --- Suboptimality flag (skip if quality_label already added it) ---
    if traj.get('suboptimal') or traj.get('is_suboptimal'):
        if "⚠️ Suboptimal" not in tags:
            tags.append("⚠️ Suboptimal")

    # --- Embodiment: explicit is_robot flag, else infer from source text ---
    if 'is_robot' in traj:
        tags.append("🤖 Robot" if traj['is_robot'] else "👤 Human")
    elif 'source' in traj:
        src = str(traj['source']).lower()
        if 'human' in src:
            tags.append("👤 Human")
        elif 'robot' in src or 'policy' in src:
            tags.append("🤖 Robot")
        else:
            tags.append(f"Source: {traj['source']}")

    return " | ".join(tags)
|
|
|
|
|
|
|
|
def show_labeling_view():
    """Switch from the setup view to the labeling view.

    Previously duplicated show_current()'s rendering logic line-for-line;
    now delegates to it and wraps the result with the view-visibility
    toggles and status outputs expected by the load-button handler.
    Assumes load_trajectories() has populated current_trajectories.
    """
    # (video_url, task_text, decision, review_visibility, issue, notes)
    widgets = show_current()
    return (
        gr.update(visible=False),  # hide setup view
        gr.update(visible=True),   # show labeling view
        *widgets,
        "",                        # clear the save-status banner
        evaluations_df.tail(10),   # refresh the recent-labels table
        f"✅ Loaded {len(current_trajectories)} trajectories",
    )
|
|
|
|
|
|
|
|
def show_current():
    """Render the trajectory at current_idx into the labeling widgets.

    Returns (video_url, task_text, decision_update, review_visibility_update,
    issue_update, notes_update); placeholders when nothing is loaded.
    """
    if not current_trajectories or current_idx >= len(current_trajectories):
        return (
            None,
            "No data",
            gr.update(value=None),
            gr.update(visible=False),
            gr.update(value=None),
            gr.update(value="")
        )

    traj = current_trajectories[current_idx]
    traj_id = traj.get("id", "Unknown")
    task = traj.get("task", "No task")
    video_url = traj.get("video_url")

    header = f"Progress: {current_idx + 1}/{len(current_trajectories)} | ID: {traj_id[:8]}..."

    # Pre-fill the widgets with the most recent previous label, if any.
    prev_decision, prev_issue, prev_notes = None, None, ""
    if traj_id in evaluations_df['trajectory_id'].values:
        last_row = evaluations_df[evaluations_df['trajectory_id'] == traj_id].iloc[-1]
        prev_decision = last_row['decision']

        # Only restore issue_type when it is one of the radio's choices.
        valid_choices = ['too_short', 'too_long', 'wrong_description', 'task_already_completed', 'mislabeled_success', 'mislabeled_failure', 'mislabeled_suboptimal', 'other']
        issue_val = last_row['issue_type']
        if issue_val and str(issue_val).strip() and str(issue_val) in valid_choices:
            prev_issue = issue_val
        prev_notes = last_row['notes'] if pd.notna(last_row['notes']) else ""
        header += f" (prev: {prev_decision})"

    metadata = get_trajectory_metadata(traj)
    if metadata:
        task_text = f"{header}\n{metadata}\n\n{task}"
    else:
        task_text = f"{header}\n\n{task}"

    return (
        video_url,
        task_text,
        gr.update(value=prev_decision),
        gr.update(visible=(prev_decision == "review")),
        gr.update(value=prev_issue),
        gr.update(value=prev_notes if prev_notes else "")
    )
|
|
|
|
|
|
|
|
def navigate(direction):
    """Step the trajectory cursor one position, clamped to valid indices."""
    global current_idx
    last = len(current_trajectories) - 1
    current_idx = (min(current_idx + 1, last)
                   if direction == "next"
                   else max(current_idx - 1, 0))
    return show_current()
|
|
|
|
|
|
|
|
def save_label(decision, issue_type="", notes=""):
    """Save label and advance. Updates existing if trajectory already labeled.

    Persists to CSV and, when enabled, appends to the shared HF dataset.
    Returns show_current()'s widget tuple plus (status_message,
    recent-labels table) to refresh the labeling view.
    """
    global evaluations_df, current_idx

    # Nothing loaded (or cursor past the end): just refresh the view.
    if not current_trajectories or current_idx >= len(current_trajectories):
        return show_current() + ("", evaluations_df.tail(10))

    traj = current_trajectories[current_idx]
    traj_id = traj.get("id", "")

    # Debug trace of the incoming widget values.
    print(f"💾 save_label called:")
    print(f" decision: {decision}")
    print(f" issue_type: '{issue_type}' (type: {type(issue_type)}, len: {len(str(issue_type))})")
    print(f" notes: {notes}")
    print(f" traj_id: {traj_id[:20]}...")

    row_data = {
        "dataset_repo": traj.get("dataset_repo", ""),
        "config_name": traj.get("config_name", ""),
        "trajectory_id": traj_id,
        "task": traj.get("task", ""),
        "decision": decision,
        "issue_type": issue_type,
        "notes": notes,
        "timestamp": pd.Timestamp.now().isoformat()
    }

    print(f" 📋 row_data issue_type: '{row_data['issue_type']}'")

    # Re-labeling an already-seen trajectory updates its latest row in place
    # instead of appending a duplicate.
    existing_mask = evaluations_df['trajectory_id'] == traj_id
    is_update = existing_mask.any()
    if is_update:
        idx = evaluations_df[existing_mask].index[-1]
        print(f" 🔄 Updating existing row at index {idx}")
        for col, val in row_data.items():
            evaluations_df.at[idx, col] = val
        print(f" ✅ After update, issue_type = '{evaluations_df.at[idx, 'issue_type']}'")
        status_msg = f"✅ Updated: {decision}"
    else:
        print(f" ➕ Adding new row")
        evaluations_df = pd.concat([evaluations_df, pd.DataFrame([row_data])], ignore_index=True)
        new_idx = evaluations_df.index[-1]
        print(f" ✅ After add, issue_type = '{evaluations_df.at[new_idx, 'issue_type']}'")
        status_msg = f"✅ Added: {decision}"

    # Persist locally first, then mirror to the shared dataset if enabled.
    save_evaluations()
    print(f" 💾 Saved to CSV")

    if dataset_sync_enabled:
        from hf_dataset_sync import append_to_dataset
        append_to_dataset(row_data)

    # Auto-advance to the next trajectory (clamped at the last one).
    current_idx = min(current_idx + 1, len(current_trajectories) - 1)

    return show_current() + (status_msg, evaluations_df.tail(10))
|
|
|
|
|
|
|
|
def save_with_decision(decision, review_reason, notes):
    """Derive issue_type from the review reason, then delegate to save_label."""
    valid_choices = ['too_short', 'too_long', 'wrong_description', 'task_already_completed', 'mislabeled_success', 'mislabeled_failure', 'mislabeled_suboptimal', 'other']

    # Debug trace of the incoming widget values.
    print(f"🔍 save_with_decision called:")
    print(f" decision: {decision} (type: {type(decision)})")
    print(f" review_reason: {review_reason} (type: {type(review_reason)})")
    print(f" notes: {notes}")

    issue = ""
    is_valid_reason = bool(review_reason) and str(review_reason) in valid_choices
    if decision == "review" and is_valid_reason:
        issue = str(review_reason)
        print(f" ✅ Setting issue_type to: {issue}")
    else:
        # Explain which validation check failed.
        print(f" ❌ Issue NOT set. Checks:")
        print(f" decision == 'review': {decision == 'review'}")
        print(f" review_reason truthy: {bool(review_reason)}")
        if review_reason:
            print(f" review_reason in valid_choices: {str(review_reason) in valid_choices}")

    return save_label(decision, issue, notes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def back_to_setup():
    """Return to the setup view, hiding the labeling view."""
    show_setup = gr.update(visible=True)
    hide_labeling = gr.update(visible=False)
    return show_setup, hide_labeling
|
|
|
|
|
|
|
|
def update_review_visibility(decision):
    """Show the review-reason options only when the decision is 'review'."""
    is_review = decision == "review"
    return gr.update(visible=is_review)
|
|
|
|
|
|
|
|
|
|
|
# Populate evaluations_df from the shared dataset / local CSV at startup.
load_evaluations()
|
|
|
|
|
|
|
|
# Custom CSS: layout width, button sizing, status banners, and the
# active-state highlight for the playback-speed buttons.
css = """
.container { max-width: 1000px; margin: 0 auto; }
.decision-btn { min-height: 50px !important; font-size: 16px !important; }
.task-box { background: #f8f9fa; padding: 12px; border-radius: 6px; border-left: 4px solid #667eea; }
.thin-back-btn button {
    min-height: 35px !important;
    font-size: 13px !important;
    margin-bottom: 8px !important;
}
#dataset_analysis {
    background: #f0f9ff;
    padding: 16px;
    border-radius: 8px;
    border-left: 4px solid #3b82f6;
    margin: 12px 0;
}
#save_status, #load_status {
    font-weight: 600;
    padding: 8px;
    border-radius: 6px;
    text-align: center;
    margin-top: 8px;
}
#save_status {
    color: #10b981;
    background: #d1fae5;
}
#load_status {
    color: #667eea;
    background: #e0e7ff;
}
#speed_1x, #speed_2x, #speed_4x {
    border: 2px solid #e5e7eb !important;
    transition: all 0.2s;
}
#speed_1x.speed-active, #speed_2x.speed-active, #speed_4x.speed-active {
    background: #667eea !important;
    color: white !important;
    border-color: #667eea !important;
}
"""
|
|
|
|
|
# ---------------------------------------------------------------------------
# UI definition and event wiring
# ---------------------------------------------------------------------------
with gr.Blocks(title="Trajectory Reviewer", css=css) as demo:

    gr.Markdown("# 🎯 Trajectory Reviewer")

    # --- Setup view: pick dataset/config plus an index range or a single ID ---
    with gr.Column(visible=True) as setup_view:
        gr.Markdown("### Dataset")

        with gr.Row():
            dataset_dropdown = gr.Dropdown(
                choices=PREDEFINED_DATASETS,
                value="jesbu1/epic_rfm",
                label="Dataset",
                allow_custom_value=True,
                scale=3
            )
            refresh_btn = gr.Button("🔄", scale=0)

        config_dropdown = gr.Dropdown(
            choices=[],
            value="",
            label="Config",
            allow_custom_value=True
        )

        # Filled by analyze_dataset_progress() with progress/gap markdown.
        dataset_analysis = gr.Markdown("", elem_id="dataset_analysis")

        gr.Markdown("### Selection")
        with gr.Row():
            with gr.Column():
                start_idx = gr.Number(label="Start Index", value=0, precision=0)
                end_idx = gr.Number(label="End Index", value=20, precision=0)
            with gr.Column():
                traj_id_input = gr.Textbox(label="Or Specific ID", placeholder="Leave empty for range")

        load_btn = gr.Button("🚀 Load & Start", variant="primary", size="lg")
        load_status = gr.Markdown("", elem_id="load_status")

    # --- Labeling view: video on the left, decision controls on the right ---
    with gr.Column(visible=False) as labeling_view:

        with gr.Row():
            with gr.Column(scale=3):
                back_btn = gr.Button("← Back to Setup", variant="secondary", size="sm", elem_classes=["thin-back-btn"])
                video_player = gr.Video(label="Video", elem_id="traj_video", autoplay=True)

                gr.Markdown("**Playback Speed**")
                with gr.Row():
                    speed_1x = gr.Button("1x", size="sm", elem_id="speed_1x")
                    speed_2x = gr.Button("2x", size="sm", elem_id="speed_2x")
                    speed_4x = gr.Button("4x", size="sm", elem_id="speed_4x")

            with gr.Column(scale=2):
                task_display = gr.Textbox(label="📋 Task", interactive=False, lines=3, elem_classes=["task-box"])

                gr.Markdown("### Decision")
                decision_radio = gr.Radio(
                    choices=["keep", "remove", "review"],
                    label="Select",
                    value=None
                )

                # Extra reason picker, only shown when decision == "review".
                with gr.Column(visible=False) as review_options:
                    review_reason = gr.Radio(
                        choices=[
                            "too_short",
                            "too_long",
                            "wrong_description",
                            "task_already_completed",
                            "mislabeled_success",
                            "mislabeled_failure",
                            "mislabeled_suboptimal",
                            "other"
                        ],
                        label="Review Reason",
                        value=None
                    )

                notes_input = gr.Textbox(label="Notes", placeholder="Optional...", lines=2)
                save_btn = gr.Button("💾 Save & Next", variant="primary", size="lg", elem_classes=["decision-btn"])
                save_status = gr.Markdown("", elem_id="save_status")

                with gr.Row():
                    prev_btn = gr.Button("← Prev", size="sm")
                    next_btn = gr.Button("Next →", size="sm")

        gr.Markdown("### Recent Labels")
        evals_table = gr.Dataframe(
            value=evaluations_df.tail(10),
            max_height=150
        )

    def set_speed_js(rate, btn_id):
        # Build a JS click handler that sets the <video> playback rate and
        # highlights the active speed button. Retries via setTimeout because
        # the video element may not exist yet when the click fires.
        return f"""
        () => {{
            const setSpeed = () => {{
                const video = document.querySelector('#traj_video video');
                if (video) {{
                    video.playbackRate = {rate};
                    // Highlight active button
                    ['#speed_1x', '#speed_2x', '#speed_4x'].forEach(id => {{
                        const btn = document.querySelector(id);
                        if (btn) btn.classList.remove('speed-active');
                    }});
                    document.querySelector('{btn_id}')?.classList.add('speed-active');
                }}
            }};
            setSpeed();
            // Also set on video load events
            setTimeout(setSpeed, 100);
            setTimeout(setSpeed, 500);
        }}
        """

    # Playback-speed buttons run purely client-side (no Python callback).
    speed_1x.click(None, None, None, js=set_speed_js(1.0, '#speed_1x'))
    speed_2x.click(None, None, None, js=set_speed_js(2.0, '#speed_2x'))
    speed_4x.click(None, None, None, js=set_speed_js(4.0, '#speed_4x'))

    # Default every newly loaded video to 2x playback.
    video_player.change(
        None, None, None,
        js="() => { setTimeout(() => { const v = document.querySelector('#traj_video video'); if (v) v.playbackRate = 2.0; }, 500); }"
    )

    # Dataset/config selection refreshes configs and the progress analysis.
    dataset_dropdown.change(fetch_configs, [dataset_dropdown], [config_dropdown, dataset_analysis, start_idx, end_idx])
    refresh_btn.click(fetch_configs, [dataset_dropdown], [config_dropdown, dataset_analysis, start_idx, end_idx])
    config_dropdown.change(analyze_dataset_progress, [dataset_dropdown, config_dropdown], [dataset_analysis, start_idx, end_idx])

    # Show a "loading" notice immediately, then stream the trajectories.
    load_btn.click(
        lambda: "⏳ Loading trajectories...",
        None,
        load_status
    ).then(
        load_trajectories,
        [dataset_dropdown, config_dropdown, start_idx, end_idx, traj_id_input],
        [setup_view, labeling_view, video_player, task_display,
         decision_radio, review_options, review_reason, notes_input, save_status,
         evals_table, load_status]
    )

    back_btn.click(back_to_setup, outputs=[setup_view, labeling_view])
    decision_radio.change(update_review_visibility, [decision_radio], [review_options])

    # Save the label, then clear the status banner client-side after 3s.
    save_btn.click(
        save_with_decision,
        [decision_radio, review_reason, notes_input],
        [video_player, task_display, decision_radio, review_options,
         review_reason, notes_input, save_status, evals_table]
    ).then(
        None, None, save_status,
        js="() => { setTimeout(() => document.querySelector('#save_status').textContent = '', 3000); }"
    )

    prev_btn.click(
        lambda: navigate("prev"),
        outputs=[video_player, task_display, decision_radio,
                 review_options, review_reason, notes_input]
    )
    next_btn.click(
        lambda: navigate("next"),
        outputs=[video_player, task_display, decision_radio,
                 review_options, review_reason, notes_input]
    )

    # Fetch configs for the default dataset as soon as the page loads.
    demo.load(fetch_configs, [dataset_dropdown], [config_dropdown, dataset_analysis, start_idx, end_idx])
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Gradio app (blocking call).
    demo.launch()
|
|
|