Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| import tempfile | |
| import time | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any | |
| import cv2 | |
| import fsspec | |
| import gradio as gr | |
| import numpy as np | |
| import pyarrow.parquet as pq | |
| from huggingface_hub import hf_hub_download | |
# Gold benchmark location on the Hugging Face Hub; each piece can be
# overridden through an environment variable.
GOLD_REPO = os.environ.get("WASUP_GOLD_REPO", "macrodata/whats_going_on_bench")
GOLD_PARQUET = os.environ.get("WASUP_GOLD_PARQUET", "whats_going_on_bench.parquet")
GOLD_INDEX = os.environ.get("WASUP_GOLD_INDEX", "whats_going_on_bench_index.parquet")
GOLD_PARQUET_URL = f"https://huggingface.co/datasets/{GOLD_REPO}/resolve/main/{GOLD_PARQUET}"

# Scratch directory for decoded episode videos (created at import time).
CACHE_DIR = Path(tempfile.gettempdir()) / "wasup_gold_annotator"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Annotations go under /data when that directory exists, otherwise into the
# scratch directory; WASUP_ANNOTATION_PATH overrides both.
if Path("/data").exists():
    _default_annotation_path = "/data/wasup_gold_annotations.jsonl"
else:
    _default_annotation_path = str(CACHE_DIR / "wasup_gold_annotations.jsonl")
ANNOTATION_PATH = Path(os.environ.get("WASUP_ANNOTATION_PATH", _default_annotation_path))
@dataclass
class Segment:
    """A labelled time span of an episode, from ``start_sec`` to ``end_sec``.

    Declared as a dataclass so that ``Segment(start, end, label)`` works;
    without the decorator the class has no three-argument constructor and
    every ``Segment(...)`` call in this module raises ``TypeError``.
    """

    start_sec: float
    end_sec: float
    label: str

    def valid(self) -> bool:
        """Return True when the label is non-blank and ``end_sec > start_sec``."""
        return bool(self.label.strip()) and self.end_sec > self.start_sec

    def to_dict(self) -> dict[str, Any]:
        """Serialise for JSON output: bounds rounded to milliseconds, label stripped.

        The label is emitted under the ``"subtask"`` key.
        """
        return {
            "start_sec": round(float(self.start_sec), 3),
            "end_sec": round(float(self.end_sec), 3),
            "subtask": self.label.strip(),
        }
@lru_cache(maxsize=1)
def load_gold_index() -> list[dict[str, Any]]:
    """Download and parse the gold episode index, sorted by (source_dataset, bench_id).

    The index is read through ``hf_hub_download`` + ``pq.read_table``. Nearly
    every UI callback calls this function, so the parsed result is cached
    for the life of the process rather than re-fetched per event. Failed
    calls raise and are not cached, so a transient error is retried next time.

    Returns:
        The index rows as dicts. The list is shared between callers through
        the cache and must not be mutated.
    """
    path = hf_hub_download(repo_id=GOLD_REPO, repo_type="dataset", filename=GOLD_INDEX)
    rows = pq.read_table(path).to_pylist()
    # Missing keys sort as "" so rows without metadata still order deterministically.
    return sorted(rows, key=lambda row: (str(row.get("source_dataset") or ""), str(row.get("bench_id") or "")))
def gold_parquet() -> pq.ParquetFile:
    """Open the remote gold parquet (``GOLD_PARQUET_URL``) as a ``ParquetFile``.

    The file is opened over HTTPS through fsspec with 1 MiB (2**20 byte)
    blocks, presumably so pyarrow only pulls the byte ranges it reads.
    """
    http = fsspec.filesystem("https", block_size=1 << 20)
    remote_handle = http.open(GOLD_PARQUET_URL, "rb")
    return pq.ParquetFile(remote_handle)
def load_gold_row(bench_id: str) -> dict[str, Any]:
    """Fetch the full gold-parquet row (including video bytes) for ``bench_id``.

    The index records which parquet row group holds each episode, so only
    that row group is read from the remote file.

    Raises:
        KeyError: if ``bench_id`` is not in the index, or its row group is empty.
    """
    entry = next((row for row in load_gold_index() if row.get("bench_id") == bench_id), None)
    if entry is None:
        raise KeyError(f"{bench_id} is not in {GOLD_REPO}")
    rows = gold_parquet().read_row_group(int(entry["row_group"])).to_pylist()
    # Pick the matching row instead of assuming the episode is the first row
    # of its group. Fall back to the first row (the previous behaviour) if no
    # bench_id matches.
    for row in rows:
        if row.get("bench_id") == bench_id:
            return row
    if rows:
        return rows[0]
    raise KeyError(f"{bench_id} has an empty row group in {GOLD_REPO}")
def episode_choices() -> list[str]:
    """Return one dropdown label per indexed episode.

    Each label has the form ``"NNNN | source | bench_id | instruction"``. The
    source falls back to the ``bench_id`` prefix before the first underscore.
    """

    def describe(position: int, entry: dict[str, Any]) -> str:
        bid = entry.get("bench_id") or ""
        origin = entry.get("source_dataset") or bid.split("_", 1)[0]
        task = entry.get("instruction") or ""
        return f"{position:04d} | {origin} | {bid} | {task}"

    return [describe(position, entry) for position, entry in enumerate(load_gold_index())]
def bench_id_from_choice(choice: str | None) -> str:
    """Extract the bench_id (third ``|``-separated field) from a dropdown label.

    Falls back to the first indexed episode when ``choice`` is empty or has
    fewer than three fields.
    """
    fields = choice.split("|") if choice else []
    if len(fields) >= 3:
        return fields[2].strip()
    return str(load_gold_index()[0]["bench_id"])
def step_episode(choice: str | None, delta: int) -> str:
    """Return the dropdown label ``delta`` positions away from ``choice``.

    The result is clamped to the first/last episode. An unknown ``choice``
    counts as position 0, and "" is returned when there are no episodes.
    """
    labels = episode_choices()
    if not labels:
        return ""
    current = labels.index(choice) if choice in labels else 0
    target = min(current + delta, len(labels) - 1)
    return labels[max(0, target)]
def write_video(row: dict[str, Any]) -> str:
    """Materialise an episode's video bytes as an ``.mp4`` file in ``CACHE_DIR``.

    The file name combines ``bench_id`` with the first 12 characters of
    ``primary_video_sha256`` (or of ``bench_id`` when no digest is stored).
    An existing file with that name is reused instead of being rewritten.

    Args:
        row: Gold-parquet row; reads ``bench_id``, ``primary_video`` and
            optionally ``primary_video_sha256``.

    Returns:
        The cached file path as a string.
    """
    digest = str(row.get("primary_video_sha256") or row["bench_id"])
    path = CACHE_DIR / f'{row["bench_id"]}_{digest[:12]}.mp4'
    if path.exists():
        return str(path)
    payload = bytes(row["primary_video"])
    # Write to a temp file in the same directory, then rename it into place
    # atomically. A crash mid-write (or two requests writing the same episode
    # at once) can then never leave a truncated .mp4 that the exists() check
    # above would keep reusing.
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=f".{path.name}.", suffix=".part")
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(payload)
        os.replace(tmp_name, path)
    except BaseException:
        Path(tmp_name).unlink(missing_ok=True)
        raise
    return str(path)
def segment_from_mapping(item: dict[str, Any], fps: float | None = None) -> Segment | None:
    """Build a ``Segment`` from a gold subtask mapping, or return None.

    The label comes from ``subtask``, ``label`` or ``text`` (first truthy one),
    stripped. Explicit ``start_sec``/``end_sec`` win. Otherwise
    ``start_frame``/``end_frame`` are converted using a truthy ``fps``.
    Non-numeric bounds yield None.
    """
    label = str(item.get("subtask") or item.get("label") or item.get("text") or "").strip()
    try:
        if "start_sec" in item and "end_sec" in item:
            lower, upper = float(item["start_sec"]), float(item["end_sec"])
        elif fps and "start_frame" in item and "end_frame" in item:
            lower, upper = float(item["start_frame"]) / fps, float(item["end_frame"]) / fps
        else:
            return None
        return Segment(lower, upper, label)
    except (TypeError, ValueError):
        return None
def collapse_frame_annotations(frames: list[dict[str, Any]], fps: float) -> list[Segment]:
    """Merge per-frame ``subtask`` labels into contiguous segments.

    Frames are ordered by ``frame_index`` (missing counts as 0). Each run of
    consecutive frames sharing a label becomes one segment, from
    ``first_frame / fps`` to ``(last_frame + 1) / fps``. Runs whose segment is
    not ``valid()`` (e.g. a blank label) are dropped.
    """
    if not frames:
        return []

    def frame_of(entry: dict[str, Any]) -> int:
        return int(entry.get("frame_index") or 0)

    # Each run is [label, first_frame, last_frame].
    runs: list[list[Any]] = []
    for entry in sorted(frames, key=frame_of):
        frame_no = frame_of(entry)
        text = str(entry.get("subtask") or "")
        if runs and runs[-1][0] == text:
            runs[-1][2] = frame_no
        else:
            runs.append([text, frame_no, frame_no])

    collapsed: list[Segment] = []
    for text, first_frame, last_frame in runs:
        candidate = Segment(first_frame / fps, (last_frame + 1) / fps, text)
        if candidate.valid():
            collapsed.append(candidate)
    return collapsed
def gold_segments(row: dict[str, Any]) -> list[Segment]:
    """Return the existing gold segments for an episode row.

    Explicit ``subtasks`` entries are preferred (invalid ones are skipped).
    When none survive, the per-frame ``frame_annotations`` are collapsed
    instead. A missing ``fps`` defaults to 10.0.
    """
    fps = float(row.get("fps") or 10.0)
    parsed = (segment_from_mapping(entry, fps=fps) for entry in row.get("subtasks") or [])
    explicit = [seg for seg in parsed if seg and seg.valid()]
    if explicit:
        return explicit
    return collapse_frame_annotations(row.get("frame_annotations") or [], fps=fps)
def table_from_segments(segments: list[Segment]) -> list[list[Any]]:
    """Convert segments into editor rows ``[idx, start, end, duration, label]``.

    Rows are sorted by (start, end), re-numbered from 0, and their times are
    rounded to 3 decimals.
    """

    def chronological(seg: Segment) -> tuple[float, float]:
        return seg.start_sec, seg.end_sec

    table: list[list[Any]] = []
    for position, seg in enumerate(sorted(segments, key=chronological)):
        span = round(seg.end_sec - seg.start_sec, 3)
        table.append([position, round(seg.start_sec, 3), round(seg.end_sec, 3), span, seg.label])
    return table
def normalize_table(table: Any) -> list[list[Any]]:
    """Coerce any Gradio Dataframe payload into clean ``[idx, start, end, duration, label]`` rows.

    Accepts ``None``, a dict with a ``"data"`` key, an object whose
    ``.values`` has ``tolist()`` (e.g. a pandas DataFrame), or a plain
    sequence of rows.

    Rows are dropped when they:
    - are shorter than 5 cells,
    - have non-numeric start/end,
    - have a blank label, or
    - have ``end <= start``.

    ``idx`` is the row's position in the input, so dropped rows leave gaps.
    ``duration`` is rounded to 3 decimals and the label is stripped.
    """
    if table is None:
        return []
    # Check dicts first: every dict has a bound ``.values`` method, which the
    # DataFrame branch would otherwise call ``.tolist()`` on and crash.
    if isinstance(table, dict):
        rows = table.get("data") or []
    else:
        values = getattr(table, "values", None)
        rows = values.tolist() if hasattr(values, "tolist") else table
    normalized = []
    for idx, row in enumerate(rows or []):
        if not row or len(row) < 5:
            continue
        try:
            start = float(row[1])
            end = float(row[2])
        except (TypeError, ValueError):
            continue
        label = str(row[4] or "").strip()
        if end > start and label:
            normalized.append([idx, start, end, round(end - start, 3), label])
    return normalized
def segments_from_table(table: Any) -> list[Segment]:
    """Turn every valid editor row (see ``normalize_table``) into a ``Segment``."""
    result: list[Segment] = []
    for _idx, start, end, _duration, label in normalize_table(table):
        result.append(Segment(float(start), float(end), str(label)))
    return result
def render_timeline(table: Any, duration_sec: float) -> str:
    """Render the segment table as an HTML strip of absolutely positioned boxes.

    Box offsets and widths are percentages of ``duration_sec``, floored at
    1 second. Bounds are clamped into ``[0, duration]`` and each box is at
    least 0.8% wide. Labels pass through ``html_escape`` before embedding.
    """
    rows = normalize_table(table)
    total = max(float(duration_sec or 1.0), 1.0)
    palette = ("#bfdbfe", "#bbf7d0", "#fde68a", "#fecaca", "#ddd6fe", "#bae6fd")
    html_chunks = [
        "<div style='font-family:system-ui,sans-serif'>",
        "<div style='position:relative;height:54px;background:#f8fafc;border:1px solid #cbd5e1;border-radius:6px;overflow:hidden'>",
    ]
    for position, row in enumerate(rows):
        clipped_start = max(0.0, min(float(row[1]), total))
        clipped_end = max(clipped_start, min(float(row[2]), total))
        left_pct = 100 * clipped_start / total
        # Minimum width keeps very short segments visible.
        width_pct = max(0.8, 100 * (clipped_end - clipped_start) / total)
        text = html_escape(str(row[4]))
        fill = palette[position % len(palette)]
        html_chunks.append(
            f"<div title='{text}' style='position:absolute;left:{left_pct:.2f}%;width:{width_pct:.2f}%;top:7px;height:40px;"
            f"background:{fill};border:1px solid #64748b;border-radius:5px;"
            f"font-size:12px;line-height:15px;padding:4px;box-sizing:border-box;overflow:hidden'>{position}: {text}</div>"
        )
    html_chunks.append("</div></div>")
    return "".join(html_chunks)
# Maps each HTML-significant character to its entity. str.translate applies
# the mapping in a single pass, so "&" never gets double-escaped.
_HTML_ESCAPE_TABLE = str.maketrans(
    {
        "&": "&amp;",
        "<": "&lt;",
        ">": "&gt;",
        '"': "&quot;",
        "'": "&#39;",
    }
)


def html_escape(value: str) -> str:
    """Escape ``& < > " '`` so arbitrary text is safe in HTML content and quoted attributes.

    Subtask labels are user-typed and are embedded in the timeline both as
    element text and inside a single-quoted ``title`` attribute.
    """
    return value.translate(_HTML_ESCAPE_TABLE)
def render_frame(video_path: str, frame_index: int) -> np.ndarray | None:
    """Decode frame ``frame_index`` of ``video_path`` and return it as an RGB array.

    Returns None when the frame cannot be read (e.g. unreadable file or index
    past the end). The OpenCV capture is always released, even if seeking or
    decoding raises.
    """
    capture = cv2.VideoCapture(video_path)
    try:
        capture.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
        ok, frame = capture.read()
    finally:
        capture.release()
    if not ok or frame is None:
        return None
    # OpenCV decodes to BGR; the UI image component expects RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
def active_segment(table: Any, time_sec: float) -> dict[str, Any]:
    """Describe the first segment whose ``[start, end)`` window contains ``time_sec``.

    When no segment matches, returns a placeholder with ``segment_index`` None,
    the queried ``time_sec`` and an empty ``subtask``.
    """
    hit = next(
        (row for row in normalize_table(table) if float(row[1]) <= time_sec < float(row[2])),
        None,
    )
    if hit is None:
        return {"segment_index": None, "time_sec": time_sec, "subtask": ""}
    return {
        "segment_index": int(hit[0]),
        "start_sec": float(hit[1]),
        "end_sec": float(hit[2]),
        "subtask": str(hit[4]),
    }
def load_episode(choice: str | None) -> tuple[Any, ...]:
    """Load the episode selected in the dropdown and build every UI output.

    The returned tuple is, in order:
    - row
    - video path
    - metadata JSON
    - segment table
    - timeline HTML
    - frame slider
    - frame-0 image
    - active segment at t=0
    - segment index 0
    - editor start / end / label (taken from the first segment)
    - status message
    """
    bench_id = bench_id_from_choice(choice)
    row = load_gold_row(bench_id)
    video_path = write_video(row)

    fps = float(row.get("fps") or 10.0)
    num_frames = int(row.get("num_frames") or 1)
    if fps:
        duration = num_frames / fps
    else:
        duration = float(row.get("duration_sec") or 1.0)

    table = table_from_segments(gold_segments(row))
    first_frame = 0
    summary = {
        "bench_id": bench_id,
        "source_dataset": row.get("source_dataset"),
        "instruction": row.get("instruction"),
        "fps": fps,
        "num_frames": num_frames,
        "duration_sec": duration,
        "old_segments": len(table),
    }

    metadata_json = json.dumps(summary, indent=2)
    timeline_html = render_timeline(table, duration)
    slider = gr.Slider(value=first_frame, minimum=0, maximum=max(0, num_frames - 1), step=1)
    preview = render_frame(video_path, first_frame)
    current = active_segment(table, 0.0)
    # Pre-fill the segment editor with the first segment, if there is one.
    if table:
        editor_start, editor_end, editor_label = table[0][1], table[0][2], table[0][4]
    else:
        editor_start, editor_end, editor_label = 0.0, 1.0, ""

    return (
        row,
        video_path,
        metadata_json,
        table,
        timeline_html,
        slider,
        preview,
        current,
        0,
        editor_start,
        editor_end,
        editor_label,
        f"Loaded {bench_id}",
    )
def refresh_frame(row: dict[str, Any], table: Any, frame_index: int) -> tuple[Any, dict[str, Any]]:
    """Return (frame image, active segment) for a frame index clamped into the episode.

    Returns ``(None, {})`` when no episode is loaded.
    """
    if not row:
        return None, {}
    video_path = write_video(row)
    fps = float(row.get("fps") or 10.0)
    requested = int(frame_index or 0)
    last_frame = int(row.get("num_frames") or 1) - 1
    clamped = max(0, min(requested, last_frame))
    seconds = clamped / fps if fps else 0.0
    return render_frame(video_path, clamped), active_segment(table, seconds)
def step_frame(row: dict[str, Any], table: Any, frame_index: int, delta: int) -> tuple[int, Any, dict[str, Any]]:
    """Move the frame cursor by ``delta`` (clamped) and return (frame, image, active segment).

    Returns ``(0, None, {})`` when no episode is loaded.
    """
    if not row:
        return 0, None, {}
    highest = max(0, int(row.get("num_frames") or 1) - 1)
    target = int(frame_index or 0) + delta
    target = max(0, min(target, highest))
    image, current = refresh_frame(row, table, target)
    return target, image, current
def time_from_frame(row: dict[str, Any], frame_index: int) -> float:
    """Convert a frame index to seconds (rounded to 3 decimals) using the row's fps.

    A missing row or fps defaults to 10 fps; a zero fps yields 0.0.
    """
    settings = row or {}
    fps = float(settings.get("fps") or 10.0)
    if not fps:
        return 0.0
    return round(float(frame_index or 0) / fps, 3)
def set_start(row: dict[str, Any], frame_index: int, end_sec: float, label: str) -> tuple[float, float, str]:
    """Set the editor's start to the current frame's time, keeping end and label."""
    start = time_from_frame(row, frame_index)
    return start, float(end_sec or 0.0), label
def set_end(row: dict[str, Any], frame_index: int, start_sec: float, label: str) -> tuple[float, float, str]:
    """Set the editor's end to the current frame's time, keeping start and label."""
    start = float(start_sec or 0.0)
    return start, time_from_frame(row, frame_index), label
def load_segment_to_editor(table: Any, segment_index: int) -> tuple[float, float, str, str]:
    """Copy normalized row ``segment_index`` into the editor fields, with a status message.

    Out-of-range indices return editor defaults ``(0.0, 1.0, "")`` and an
    error status.
    """
    rows = normalize_table(table)
    position = int(segment_index or 0)
    if not 0 <= position < len(rows):
        return 0.0, 1.0, "", f"Segment {position} is out of range."
    _idx, start, end, _duration, label = rows[position]
    return float(start), float(end), str(label), f"Loaded segment {position}."
def upsert_segment(row_state: dict[str, Any], table: Any, segment_index: int, start_sec: float, end_sec: float, label: str) -> tuple[Any, str, str]:
    """Replace row ``segment_index`` with the editor values, or append when out of range.

    Validates that the bounds are numeric, the label is non-blank and
    ``end > start``. On success the table is re-sorted and re-numbered.

    Returns:
        (table, timeline HTML, status message).
    """
    rows = normalize_table(table)
    idx = int(len(rows) if segment_index is None else segment_index)
    label = str(label or "").strip()

    def reject(reason: str) -> tuple[Any, str, str]:
        return rows, timeline_for_state(row_state, rows), reason

    try:
        start, end = float(start_sec), float(end_sec)
    except (TypeError, ValueError):
        return reject("Start/end must be numbers.")
    if not label:
        return reject("Subtask label is required.")
    if end <= start:
        return reject("end_sec must be greater than start_sec.")

    replacement = [idx, start, end, round(end - start, 3), label]
    if 0 <= idx < len(rows):
        rows[idx] = replacement
        message = f"Updated segment {idx}."
    else:
        rows.append(replacement)
        message = "Added segment."
    rebuilt = table_from_segments(segments_from_table(rows))
    return rebuilt, timeline_for_state(row_state, rebuilt), message
def delete_segment(row_state: dict[str, Any], table: Any, segment_index: int) -> tuple[Any, str, str]:
    """Remove normalized row ``segment_index`` and re-number the table.

    Returns:
        (table, timeline HTML, status message).
    """
    rows = normalize_table(table)
    idx = int(segment_index or 0)
    if idx < 0 or idx >= len(rows):
        message = f"Segment {idx} is out of range."
    else:
        rows.pop(idx)
        rows = table_from_segments(segments_from_table(rows))
        message = f"Deleted segment {idx}."
    return rows, timeline_for_state(row_state, rows), message
def sort_segments(row_state: dict[str, Any], table: Any) -> tuple[Any, str, str]:
    """Re-sort and re-number the valid rows; return (table, timeline HTML, status)."""
    ordered = table_from_segments(segments_from_table(table))
    timeline_html = timeline_for_state(row_state, ordered)
    return ordered, timeline_html, "Sorted segments."
def timeline_for_state(row_state: dict[str, Any], table: Any) -> str:
    """Render the timeline using the loaded episode's duration.

    Duration is ``num_frames / fps``, or 1.0 when fps is zero. Returns "" when
    no episode is loaded.
    """
    if not row_state:
        return ""
    fps = float(row_state.get("fps") or 10.0)
    duration = 1.0
    if fps:
        duration = int(row_state.get("num_frames") or 1) / fps
    return render_timeline(table, duration)
def save_episode(row_state: dict[str, Any], table: Any, annotator: str, notes: str) -> tuple[str, str]:
    """Append one JSON line with the old and revised segments to ``ANNOTATION_PATH``.

    Returns:
        (status message, annotation file path as a string).
    """
    if not row_state:
        return "No episode loaded.", str(ANNOTATION_PATH)
    revised = [seg.to_dict() for seg in segments_from_table(table)]
    record = {
        "saved_at_unix": time.time(),
        "annotator": str(annotator or "").strip(),
        "bench_id": row_state.get("bench_id"),
        "source_dataset": row_state.get("source_dataset"),
        "instruction": row_state.get("instruction"),
        "old_subtasks": [seg.to_dict() for seg in gold_segments(row_state)],
        "new_subtasks": revised,
        "notes": str(notes or "").strip(),
    }
    target = ANNOTATION_PATH
    target.parent.mkdir(parents=True, exist_ok=True)
    # Append mode: earlier saves are kept, giving a revision history.
    with target.open("a", encoding="utf-8") as sink:
        sink.write(json.dumps(record, sort_keys=True) + "\n")
    return f"Saved {row_state.get('bench_id')} with {len(revised)} segments.", str(target)
def download_annotations() -> str:
    """Ensure the annotation JSONL exists (possibly empty) and return its path."""
    target = ANNOTATION_PATH
    target.parent.mkdir(parents=True, exist_ok=True)
    target.touch(exist_ok=True)
    return str(target)
# NOTE(review): the source's indentation was lost; the nesting of Row/Column
# containers below is reconstructed from the statement order.
with gr.Blocks(title="Wasup Gold Annotator") as demo:
    gr.Markdown("# Wasup Gold Annotator")
    gr.Markdown("Edit temporal subtask segments. Existing gold is loaded as the starting point; saves append revised annotations to JSONL.")
    row_state = gr.State({})
    minus_one = gr.State(-1)
    plus_one = gr.State(1)
    # Build the episode list once; every episode_choices() call re-reads the index.
    initial_choices = episode_choices()
    choices_box = gr.Dropdown(
        label="Episode",
        choices=initial_choices,
        value=initial_choices[0] if initial_choices else None,
    )
    with gr.Row():
        prev_episode = gr.Button("Previous episode")
        load_episode_btn = gr.Button("Load episode", variant="primary")
        next_episode = gr.Button("Next episode")
    with gr.Row():
        video = gr.Video(label="Episode video")
        with gr.Column():
            summary = gr.Code(label="Episode metadata", language="json")
            status = gr.Textbox(label="Status", interactive=False)
    timeline = gr.HTML(label="Timeline")
    segments = gr.Dataframe(
        label="Editable segments",
        headers=["idx", "start_sec", "end_sec", "duration", "subtask"],
        datatype=["number", "number", "number", "number", "str"],
        interactive=True,
    )
    with gr.Row():
        frame_prev = gr.Button("Previous frame")
        frame_next = gr.Button("Next frame")
    frame_slider = gr.Slider(label="Frame", minimum=0, maximum=1, value=0, step=1)
    with gr.Row():
        frame_image = gr.Image(label="Selected frame", interactive=False)
        active = gr.JSON(label="Current segment")
    with gr.Row():
        segment_index = gr.Number(label="Segment idx", value=0, precision=0)
        start_sec = gr.Number(label="start_sec", value=0.0)
        end_sec = gr.Number(label="end_sec", value=1.0)
        subtask = gr.Textbox(label="subtask")
    with gr.Row():
        load_segment = gr.Button("Load selected segment")
        set_start_btn = gr.Button("Set start = current frame")
        set_end_btn = gr.Button("Set end = current frame")
    with gr.Row():
        upsert_btn = gr.Button("Add/update segment", variant="primary")
        delete_btn = gr.Button("Delete segment")
        sort_btn = gr.Button("Sort segments")
    with gr.Row():
        annotator = gr.Textbox(label="Annotator")
        notes = gr.Textbox(label="Notes")
    with gr.Row():
        save_btn = gr.Button("Save episode annotation", variant="primary")
        download_btn = gr.Button("Prepare JSONL download")
    download_file = gr.File(label="Annotations JSONL")

    episode_outputs = [
        row_state,
        video,
        summary,
        segments,
        timeline,
        frame_slider,
        frame_image,
        active,
        segment_index,
        start_sec,
        end_sec,
        subtask,
        status,
    ]
    load_episode_btn.click(load_episode, inputs=choices_box, outputs=episode_outputs)
    # Dropdown.change also fires when another event updates its value, so
    # prev/next only move the selection. A chained .then(load_episode) would
    # load (and download) every episode twice.
    choices_box.change(load_episode, inputs=choices_box, outputs=episode_outputs)
    prev_episode.click(step_episode, inputs=[choices_box, minus_one], outputs=choices_box)
    next_episode.click(step_episode, inputs=[choices_box, plus_one], outputs=choices_box)
    frame_slider.change(refresh_frame, inputs=[row_state, segments, frame_slider], outputs=[frame_image, active])
    frame_prev.click(step_frame, inputs=[row_state, segments, frame_slider, minus_one], outputs=[frame_slider, frame_image, active])
    frame_next.click(step_frame, inputs=[row_state, segments, frame_slider, plus_one], outputs=[frame_slider, frame_image, active])
    load_segment.click(load_segment_to_editor, inputs=[segments, segment_index], outputs=[start_sec, end_sec, subtask, status])
    set_start_btn.click(set_start, inputs=[row_state, frame_slider, end_sec, subtask], outputs=[start_sec, end_sec, subtask])
    set_end_btn.click(set_end, inputs=[row_state, frame_slider, start_sec, subtask], outputs=[start_sec, end_sec, subtask])
    upsert_btn.click(upsert_segment, inputs=[row_state, segments, segment_index, start_sec, end_sec, subtask], outputs=[segments, timeline, status])
    delete_btn.click(delete_segment, inputs=[row_state, segments, segment_index], outputs=[segments, timeline, status])
    sort_btn.click(sort_segments, inputs=[row_state, segments], outputs=[segments, timeline, status])
    save_btn.click(save_episode, inputs=[row_state, segments, annotator, notes], outputs=[status, download_file])
    download_btn.click(download_annotations, outputs=download_file)
    demo.load(load_episode, inputs=choices_box, outputs=episode_outputs)

if __name__ == "__main__":
    demo.launch()