from __future__ import annotations

import json
import os
import tempfile
import time
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any

import cv2
import fsspec
import gradio as gr
import numpy as np
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download

# Location of the gold benchmark on the Hugging Face Hub; each value can be overridden via env vars.
GOLD_REPO = os.environ.get("WASUP_GOLD_REPO", "macrodata/whats_going_on_bench")
GOLD_PARQUET = os.environ.get("WASUP_GOLD_PARQUET", "whats_going_on_bench.parquet")
GOLD_INDEX = os.environ.get("WASUP_GOLD_INDEX", "whats_going_on_bench_index.parquet")
GOLD_PARQUET_URL = f"https://huggingface.co/datasets/{GOLD_REPO}/resolve/main/{GOLD_PARQUET}"

# Scratch directory for decoded episode videos (and the fallback annotation file).
CACHE_DIR = Path(tempfile.gettempdir()) / "wasup_gold_annotator"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Append-only JSONL of saved annotations. Defaults to /data when that directory exists,
# otherwise to the scratch directory; WASUP_ANNOTATION_PATH overrides both.
ANNOTATION_PATH = Path(
    os.environ.get(
        "WASUP_ANNOTATION_PATH",
        "/data/wasup_gold_annotations.jsonl"
        if Path("/data").exists()
        else str(CACHE_DIR / "wasup_gold_annotations.jsonl"),
    )
)


@dataclass(frozen=True)
class Segment:
    """A labelled time interval (in seconds) within an episode."""

    start_sec: float
    end_sec: float
    label: str

    def valid(self) -> bool:
        """Return True when the label is non-blank and the interval has positive length."""
        return bool(self.label.strip()) and self.end_sec > self.start_sec

    def to_dict(self) -> dict[str, Any]:
        """Serialize for JSONL output: times rounded to ms, label stored under ``subtask``."""
        return {
            "start_sec": round(float(self.start_sec), 3),
            "end_sec": round(float(self.end_sec), 3),
            "subtask": self.label.strip(),
        }


@lru_cache(maxsize=1)
def load_gold_index() -> list[dict[str, Any]]:
    """Download the gold index parquet and return its rows sorted by (source_dataset, bench_id)."""
    path = hf_hub_download(repo_id=GOLD_REPO, repo_type="dataset", filename=GOLD_INDEX)
    rows = pq.read_table(path).to_pylist()
    return sorted(
        rows,
        key=lambda row: (str(row.get("source_dataset") or ""), str(row.get("bench_id") or "")),
    )


@lru_cache(maxsize=1)
def _gold_index_by_id() -> dict[str, dict[str, Any]]:
    """Map bench_id -> index row, built once rather than on every ``load_gold_row`` cache miss."""
    return {row["bench_id"]: row for row in load_gold_index()}


@lru_cache(maxsize=1)
def gold_parquet() -> pq.ParquetFile:
    """Open the remote gold parquet over HTTPS (1 MiB read blocks) so row groups load lazily."""
    fs = fsspec.filesystem("https", block_size=2**20)
    return pq.ParquetFile(fs.open(GOLD_PARQUET_URL, "rb"))


@lru_cache(maxsize=512)
def load_gold_row(bench_id: str) -> dict[str, Any]:
    """Return the full gold row (including video bytes) for ``bench_id``.

    Raises:
        KeyError: if ``bench_id`` is not listed in the gold index.
    """
    # NOTE(review): each cached row carries the whole primary_video payload; 512 entries may
    # be memory-heavy for long videos — confirm the bound is acceptable.
    index = _gold_index_by_id()
    if bench_id not in index:
        raise KeyError(f"{bench_id} is not in {GOLD_REPO}")
    row_group = int(index[bench_id]["row_group"])
    # NOTE(review): assumes every row group holds exactly one episode (only row 0 is read).
    return gold_parquet().read_row_group(row_group).to_pylist()[0]


def episode_choices() -> list[str]:
    """Build dropdown labels of the form ``"NNNN | source | bench_id | instruction"``."""
    choices = []
    for idx, row in enumerate(load_gold_index()):
        bench_id = row.get("bench_id") or ""
        source = row.get("source_dataset") or bench_id.split("_", 1)[0]
        instruction = row.get("instruction") or ""
        choices.append(f"{idx:04d} | {source} | {bench_id} | {instruction}")
    return choices


def bench_id_from_choice(choice: str | None) -> str:
    """Extract the bench_id field from a dropdown label, defaulting to the first episode."""
    if not choice:
        return str(load_gold_index()[0]["bench_id"])
    parts = choice.split("|")
    if len(parts) >= 3:
        return parts[2].strip()
    return str(load_gold_index()[0]["bench_id"])


def step_episode(choice: str | None, delta: int) -> str:
    """Return the dropdown label ``delta`` positions from ``choice``, clamped to the list ends."""
    items = episode_choices()
    if not items:
        return ""
    if choice in items:
        idx = items.index(choice)
    else:
        idx = 0
    return items[max(0, min(idx + delta, len(items) - 1))]


def write_video(row: dict[str, Any]) -> str:
    """Materialize the episode's ``primary_video`` bytes into the cache dir and return the path.

    The file name includes a digest prefix, so an existing file is reused.
    """
    digest = str(row.get("primary_video_sha256") or row["bench_id"])
    path = CACHE_DIR / f'{row["bench_id"]}_{digest[:12]}.mp4'
    if not path.exists():
        # Write to a unique sibling temp file, then atomically rename. Concurrent requests or a
        # crash mid-write therefore never leave a truncated .mp4 that later calls would reuse.
        fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=f"{path.stem}.", suffix=".tmp")
        try:
            with os.fdopen(fd, "wb") as handle:
                handle.write(bytes(row["primary_video"]))
            os.replace(tmp_name, path)
        except BaseException:
            Path(tmp_name).unlink(missing_ok=True)
            raise
    return str(path)


def segment_from_mapping(item: dict[str, Any], fps: float | None = None) -> Segment | None:
    """Build a Segment from a dict with second bounds, or frame bounds when ``fps`` is given.

    The label is taken from ``subtask``, ``label`` or ``text`` (first non-empty). Returns None
    when no usable bounds are present or they cannot be converted to numbers.
    """
    label = str(item.get("subtask") or item.get("label") or item.get("text") or "").strip()
    try:
        if "start_sec" in item and "end_sec" in item:
            return Segment(float(item["start_sec"]), float(item["end_sec"]), label)
        if fps and "start_frame" in item and "end_frame" in item:
            return Segment(float(item["start_frame"]) / fps, float(item["end_frame"]) / fps, label)
    except (TypeError, ValueError):
        return None
    return None


def collapse_frame_annotations(frames: list[dict[str, Any]], fps: float) -> list[Segment]:
    """Collapse per-frame ``subtask`` labels into contiguous segments.

    Frames are ordered by ``frame_index``. Each run of equal labels becomes one segment ending
    one frame after its last frame. Segments with a blank label are dropped.
    """
    if not frames:
        return []
    ordered = sorted(frames, key=lambda item: int(item.get("frame_index") or 0))
    segments: list[Segment] = []
    start = int(ordered[0].get("frame_index") or 0)
    previous = start
    label = str(ordered[0].get("subtask") or "")
    for item in ordered[1:]:
        frame = int(item.get("frame_index") or 0)
        next_label = str(item.get("subtask") or "")
        if next_label != label:
            # Close the current run just after the last frame that carried its label.
            segment = Segment(start / fps, (previous + 1) / fps, label)
            if segment.valid():
                segments.append(segment)
            start = frame
            label = next_label
        previous = frame
    final = Segment(start / fps, (previous + 1) / fps, label)
    if final.valid():
        segments.append(final)
    return segments


def gold_segments(row: dict[str, Any]) -> list[Segment]:
    """Return the row's existing gold segments.

    Explicit ``subtasks`` entries are preferred; otherwise segments are collapsed from
    ``frame_annotations``. fps defaults to 10.0 when missing.
    """
    fps = float(row.get("fps") or 10.0)
    explicit = []
    for item in row.get("subtasks") or []:
        segment = segment_from_mapping(item, fps=fps)
        if segment and segment.valid():
            explicit.append(segment)
    return explicit or collapse_frame_annotations(row.get("frame_annotations") or [], fps=fps)


def table_from_segments(segments: list[Segment]) -> list[list[Any]]:
    """Render segments as table rows ``[idx, start, end, duration, label]``, sorted by time."""
    return [
        [idx, round(seg.start_sec, 3), round(seg.end_sec, 3), round(seg.end_sec - seg.start_sec, 3), seg.label]
        for idx, seg in enumerate(sorted(segments, key=lambda item: (item.start_sec, item.end_sec)))
    ]


def normalize_table(table: Any) -> list[list[Any]]:
    """Coerce a Dataframe value (DataFrame, ``{"data": ...}`` dict, or row list) to clean rows.

    Rows with fewer than five cells, non-numeric bounds, ``end <= start`` or a blank label are
    dropped. Surviving rows are renumbered ``0..n-1`` by position, so column 0 always matches
    the positional index used by the load/update/delete callbacks.
    """
    if table is None:
        return []
    if hasattr(table, "values"):
        rows = table.values.tolist()
    elif isinstance(table, dict) and "data" in table:
        rows = table["data"]
    else:
        rows = table
    normalized: list[list[Any]] = []
    for row in rows or []:
        if not row or len(row) < 5:
            continue
        try:
            start = float(row[1])
            end = float(row[2])
        except (TypeError, ValueError):
            continue
        label = str(row[4] or "").strip()
        if end > start and label:
            # Number by position among kept rows (not by input row) so skipped rows leave no gaps.
            normalized.append([len(normalized), start, end, round(end - start, 3), label])
    return normalized


def segments_from_table(table: Any) -> list[Segment]:
    """Convert a Dataframe value into Segment objects (invalid rows are skipped)."""
    return [Segment(float(row[1]), float(row[2]), str(row[4])) for row in normalize_table(table)]


def render_timeline(table: Any, duration_sec: float) -> str:
    """Render the segment table as an HTML bar spanning ``duration_sec`` (at least 1 s)."""
    rows = normalize_table(table)
    duration = max(float(duration_sec or 1.0), 1.0)
    colors = ["#bfdbfe", "#bbf7d0", "#fde68a", "#fecaca", "#ddd6fe", "#bae6fd"]
    parts = [
        "<div style='width:100%;font-family:sans-serif;font-size:12px'>",
        "<div style='position:relative;width:100%;height:48px;border:1px solid #d1d5db;"
        "border-radius:6px;background:#f9fafb;overflow:hidden'>",
    ]
    for idx, row in enumerate(rows):
        # Clamp to [0, duration] and give every segment a minimum visible width.
        start = max(0.0, min(float(row[1]), duration))
        end = max(start, min(float(row[2]), duration))
        left = 100 * start / duration
        width = max(0.8, 100 * (end - start) / duration)
        label = html_escape(str(row[4]))
        color = colors[idx % len(colors)]
        parts.append(
            f"<div title='{label}' style='position:absolute;top:4px;bottom:4px;"
            f"left:{left:.3f}%;width:{width:.3f}%;background:{color};"
            "border:1px solid #6b7280;border-radius:4px;padding:2px 4px;box-sizing:border-box;"
            f"overflow:hidden;white-space:nowrap;text-overflow:ellipsis'>{idx}: {label}</div>"
        )
    parts.append("</div></div>")
    return "".join(parts)


def html_escape(value: str) -> str:
    """Escape text for safe use in HTML element content and quoted attribute values."""
    # "&" must be replaced first so the entities produced below are not double-escaped.
    return (
        value.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#x27;")
    )


def render_frame(video_path: str, frame_index: int) -> np.ndarray | None:
    """Decode one frame of ``video_path`` as an RGB array, or None if it cannot be read."""
    capture = cv2.VideoCapture(video_path)
    try:
        if not capture.isOpened():
            return None
        capture.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
        ok, frame = capture.read()
    finally:
        # Always free the decoder handle, even if seeking/reading raises.
        capture.release()
    if not ok or frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def active_segment(table: Any, time_sec: float) -> dict[str, Any]:
    """Return the first segment with ``start <= time_sec < end``, or a placeholder with no index."""
    for row in normalize_table(table):
        if float(row[1]) <= time_sec < float(row[2]):
            return {
                "segment_index": int(row[0]),
                "start_sec": float(row[1]),
                "end_sec": float(row[2]),
                "subtask": str(row[4]),
            }
    return {"segment_index": None, "time_sec": time_sec, "subtask": ""}


def load_episode(choice: str | None) -> tuple[Any, ...]:
    """Load the chosen episode and produce values for all episode-level UI outputs.

    Tuple order: row state, video path, metadata JSON, segment table, timeline HTML, slider
    update, first frame image, active segment, then editor segment index / start / end / label,
    and a status message.
    """
    bench_id = bench_id_from_choice(choice)
    row = load_gold_row(bench_id)
    video_path = write_video(row)
    fps = float(row.get("fps") or 10.0)
    num_frames = int(row.get("num_frames") or 1)
    duration = num_frames / fps if fps else float(row.get("duration_sec") or 1.0)
    table = table_from_segments(gold_segments(row))
    frame = 0
    summary = {
        "bench_id": bench_id,
        "source_dataset": row.get("source_dataset"),
        "instruction": row.get("instruction"),
        "fps": fps,
        "num_frames": num_frames,
        "duration_sec": duration,
        "old_segments": len(table),
    }
    return (
        row,
        video_path,
        json.dumps(summary, indent=2),
        table,
        render_timeline(table, duration),
        gr.Slider(value=frame, minimum=0, maximum=max(0, num_frames - 1), step=1),
        render_frame(video_path, frame),
        active_segment(table, 0.0),
        0,
        table[0][1] if table else 0.0,
        table[0][2] if table else 1.0,
        table[0][4] if table else "",
        f"Loaded {bench_id}",
    )


def refresh_frame(row: dict[str, Any], table: Any, frame_index: int) -> tuple[Any, dict[str, Any]]:
    """Return the image at ``frame_index`` (clamped to the episode) and the segment active there."""
    if not row:
        return None, {}
    video_path = write_video(row)
    fps = float(row.get("fps") or 10.0)
    frame_index = max(0, min(int(frame_index or 0), int(row.get("num_frames") or 1) - 1))
    time_sec = frame_index / fps if fps else 0.0
    return render_frame(video_path, frame_index), active_segment(table, time_sec)


def step_frame(row: dict[str, Any], table: Any, frame_index: int, delta: int) -> tuple[int, Any, dict[str, Any]]:
    """Move ``delta`` frames (clamped) and return the new index, its image and active segment."""
    if not row:
        return 0, None, {}
    max_frame = max(0, int(row.get("num_frames") or 1) - 1)
    next_frame = max(0, min(int(frame_index or 0) + delta, max_frame))
    image, active = refresh_frame(row, table, next_frame)
    return next_frame, image, active


def time_from_frame(row: dict[str, Any], frame_index: int) -> float:
    """Convert a frame index to seconds (rounded to ms) using the row's fps (default 10)."""
    fps = float((row or {}).get("fps") or 10.0)
    return round(float(frame_index or 0) / fps, 3) if fps else 0.0


def set_start(row: dict[str, Any], frame_index: int, end_sec: float, label: str) -> tuple[float, float, str]:
    """Set the editor's start time to the current frame; end and label pass through."""
    return time_from_frame(row, frame_index), float(end_sec or 0.0), label


def set_end(row: dict[str, Any], frame_index: int, start_sec: float, label: str) -> tuple[float, float, str]:
    """Set the editor's end time to the current frame; start and label pass through."""
    return float(start_sec or 0.0), time_from_frame(row, frame_index), label


def load_segment_to_editor(table: Any, segment_index: int) -> tuple[float, float, str, str]:
    """Copy segment ``segment_index`` into the editor fields, with a status message."""
    rows = normalize_table(table)
    idx = int(segment_index or 0)
    if idx < 0 or idx >= len(rows):
        return 0.0, 1.0, "", f"Segment {idx} is out of range."
    row = rows[idx]
    return float(row[1]), float(row[2]), str(row[4]), f"Loaded segment {idx}."


def upsert_segment(row_state: dict[str, Any], table: Any, segment_index: int, start_sec: float, end_sec: float, label: str) -> tuple[Any, str, str]:
    """Replace segment ``segment_index`` or append a new one when the index is out of range.

    Invalid input leaves the table unchanged and returns an explanatory status. On success the
    table is re-sorted by time and renumbered.
    """
    rows = normalize_table(table)
    idx = int(segment_index if segment_index is not None else len(rows))
    label = str(label or "").strip()
    try:
        start = float(start_sec)
        end = float(end_sec)
    except (TypeError, ValueError):
        return rows, timeline_for_state(row_state, rows), "Start/end must be numbers."
    if not label:
        return rows, timeline_for_state(row_state, rows), "Subtask label is required."
    if end <= start:
        return rows, timeline_for_state(row_state, rows), "end_sec must be greater than start_sec."
    new_row = [idx, start, end, round(end - start, 3), label]
    if 0 <= idx < len(rows):
        rows[idx] = new_row
        message = f"Updated segment {idx}."
    else:
        rows.append(new_row)
        message = "Added segment."
    rows = table_from_segments(segments_from_table(rows))
    return rows, timeline_for_state(row_state, rows), message


def delete_segment(row_state: dict[str, Any], table: Any, segment_index: int) -> tuple[Any, str, str]:
    """Remove segment ``segment_index`` (by position) and return the re-sorted table."""
    rows = normalize_table(table)
    idx = int(segment_index or 0)
    if not (0 <= idx < len(rows)):
        return rows, timeline_for_state(row_state, rows), f"Segment {idx} is out of range."
    del rows[idx]
    rows = table_from_segments(segments_from_table(rows))
    return rows, timeline_for_state(row_state, rows), f"Deleted segment {idx}."


def sort_segments(row_state: dict[str, Any], table: Any) -> tuple[Any, str, str]:
    """Drop invalid rows, sort segments by time and renumber them."""
    rows = table_from_segments(segments_from_table(table))
    return rows, timeline_for_state(row_state, rows), "Sorted segments."
def timeline_for_state(row_state: dict[str, Any], table: Any) -> str:
    """Render the timeline for ``table`` using the loaded episode's duration (num_frames / fps)."""
    if not row_state:
        return ""
    fps = float(row_state.get("fps") or 10.0)
    duration = int(row_state.get("num_frames") or 1) / fps if fps else 1.0
    return render_timeline(table, duration)


def save_episode(row_state: dict[str, Any], table: Any, annotator: str, notes: str) -> tuple[str, str]:
    """Append one JSONL record holding the original and revised segments for the loaded episode.

    Returns a status message and the annotation file path (shown in the download widget).
    """
    if not row_state:
        return "No episode loaded.", str(ANNOTATION_PATH)
    segments = [segment.to_dict() for segment in segments_from_table(table)]
    payload = {
        "saved_at_unix": time.time(),
        "annotator": str(annotator or "").strip(),
        "bench_id": row_state.get("bench_id"),
        "source_dataset": row_state.get("source_dataset"),
        "instruction": row_state.get("instruction"),
        "old_subtasks": [segment.to_dict() for segment in gold_segments(row_state)],
        "new_subtasks": segments,
        "notes": str(notes or "").strip(),
    }
    ANNOTATION_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Append mode: every save adds a new record; earlier saves are never rewritten.
    with ANNOTATION_PATH.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(payload, sort_keys=True) + "\n")
    return f"Saved {row_state.get('bench_id')} with {len(segments)} segments.", str(ANNOTATION_PATH)


def download_annotations() -> str:
    """Ensure the annotation file exists (possibly empty) and return its path for download."""
    ANNOTATION_PATH.parent.mkdir(parents=True, exist_ok=True)
    ANNOTATION_PATH.touch(exist_ok=True)
    return str(ANNOTATION_PATH)


with gr.Blocks(title="Wasup Gold Annotator") as demo:
    gr.Markdown("# Wasup Gold Annotator")
    gr.Markdown(
        "Edit temporal subtask segments. Existing gold is loaded as the starting point; "
        "saves append revised annotations to JSONL."
    )
    row_state = gr.State({})
    minus_one = gr.State(-1)
    plus_one = gr.State(1)
    # Build the label list once; each episode_choices() call walks the whole gold index.
    initial_choices = episode_choices()
    choices_box = gr.Dropdown(
        label="Episode",
        choices=initial_choices,
        value=initial_choices[0] if initial_choices else None,
    )
    with gr.Row():
        prev_episode = gr.Button("Previous episode")
        load_episode_btn = gr.Button("Load episode", variant="primary")
        next_episode = gr.Button("Next episode")
    with gr.Row():
        video = gr.Video(label="Episode video")
        with gr.Column():
            summary = gr.Code(label="Episode metadata", language="json")
            status = gr.Textbox(label="Status", interactive=False)
    timeline = gr.HTML(label="Timeline")
    segments = gr.Dataframe(
        label="Editable segments",
        headers=["idx", "start_sec", "end_sec", "duration", "subtask"],
        datatype=["number", "number", "number", "number", "str"],
        interactive=True,
    )
    with gr.Row():
        frame_prev = gr.Button("Previous frame")
        frame_next = gr.Button("Next frame")
    frame_slider = gr.Slider(label="Frame", minimum=0, maximum=1, value=0, step=1)
    with gr.Row():
        frame_image = gr.Image(label="Selected frame", interactive=False)
        active = gr.JSON(label="Current segment")
    with gr.Row():
        segment_index = gr.Number(label="Segment idx", value=0, precision=0)
        start_sec = gr.Number(label="start_sec", value=0.0)
        end_sec = gr.Number(label="end_sec", value=1.0)
        subtask = gr.Textbox(label="subtask")
    with gr.Row():
        load_segment = gr.Button("Load selected segment")
        set_start_btn = gr.Button("Set start = current frame")
        set_end_btn = gr.Button("Set end = current frame")
    with gr.Row():
        upsert_btn = gr.Button("Add/update segment", variant="primary")
        delete_btn = gr.Button("Delete segment")
        sort_btn = gr.Button("Sort segments")
    with gr.Row():
        annotator = gr.Textbox(label="Annotator")
        notes = gr.Textbox(label="Notes")
    with gr.Row():
        save_btn = gr.Button("Save episode annotation", variant="primary")
        download_btn = gr.Button("Prepare JSONL download")
    download_file = gr.File(label="Annotations JSONL")

    # Order must match the tuple returned by load_episode.
    episode_outputs = [
        row_state,
        video,
        summary,
        segments,
        timeline,
        frame_slider,
        frame_image,
        active,
        segment_index,
        start_sec,
        end_sec,
        subtask,
        status,
    ]
    load_episode_btn.click(load_episode, inputs=choices_box, outputs=episode_outputs)
    # `.change` fires for user selections AND for programmatic updates. It therefore already
    # reloads the episode after Previous/Next rewrite the dropdown value. Chaining another
    # load_episode via `.then` would load and render every stepped episode twice.
    choices_box.change(load_episode, inputs=choices_box, outputs=episode_outputs)
    prev_episode.click(step_episode, inputs=[choices_box, minus_one], outputs=choices_box)
    next_episode.click(step_episode, inputs=[choices_box, plus_one], outputs=choices_box)
    frame_slider.change(refresh_frame, inputs=[row_state, segments, frame_slider], outputs=[frame_image, active])
    frame_prev.click(step_frame, inputs=[row_state, segments, frame_slider, minus_one], outputs=[frame_slider, frame_image, active])
    frame_next.click(step_frame, inputs=[row_state, segments, frame_slider, plus_one], outputs=[frame_slider, frame_image, active])
    load_segment.click(load_segment_to_editor, inputs=[segments, segment_index], outputs=[start_sec, end_sec, subtask, status])
    set_start_btn.click(set_start, inputs=[row_state, frame_slider, end_sec, subtask], outputs=[start_sec, end_sec, subtask])
    set_end_btn.click(set_end, inputs=[row_state, frame_slider, start_sec, subtask], outputs=[start_sec, end_sec, subtask])
    upsert_btn.click(
        upsert_segment,
        inputs=[row_state, segments, segment_index, start_sec, end_sec, subtask],
        outputs=[segments, timeline, status],
    )
    delete_btn.click(delete_segment, inputs=[row_state, segments, segment_index], outputs=[segments, timeline, status])
    sort_btn.click(sort_segments, inputs=[row_state, segments], outputs=[segments, timeline, status])
    save_btn.click(save_episode, inputs=[row_state, segments, annotator, notes], outputs=[status, download_file])
    download_btn.click(download_annotations, outputs=download_file)
    demo.load(load_episode, inputs=choices_box, outputs=episode_outputs)


if __name__ == "__main__":
    demo.launch()