from __future__ import annotations
import json
import os
import tempfile
import time
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
import cv2
import fsspec
import gradio as gr
import numpy as np
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download
# Where the gold benchmark lives on the Hugging Face Hub. Every piece can be
# overridden through an environment variable.
GOLD_REPO = os.environ.get("WASUP_GOLD_REPO", "macrodata/whats_going_on_bench")
GOLD_PARQUET = os.environ.get("WASUP_GOLD_PARQUET", "whats_going_on_bench.parquet")
GOLD_INDEX = os.environ.get("WASUP_GOLD_INDEX", "whats_going_on_bench_index.parquet")
GOLD_PARQUET_URL = f"https://huggingface.co/datasets/{GOLD_REPO}/resolve/main/{GOLD_PARQUET}"

# Scratch directory for extracted episode videos; created at import time.
CACHE_DIR = Path(tempfile.gettempdir()) / "wasup_gold_annotator"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Append-only JSONL file receiving saved annotations. Unless overridden via
# WASUP_ANNOTATION_PATH, it goes under /data when that directory exists and
# under the scratch directory otherwise.
if Path("/data").exists():
    _default_annotation_path = "/data/wasup_gold_annotations.jsonl"
else:
    _default_annotation_path = str(CACHE_DIR / "wasup_gold_annotations.jsonl")
ANNOTATION_PATH = Path(os.environ.get("WASUP_ANNOTATION_PATH", _default_annotation_path))
@dataclass(frozen=True)
class Segment:
    """An immutable labelled time span of an episode, expressed in seconds."""

    start_sec: float
    end_sec: float
    label: str

    def valid(self) -> bool:
        """Return True when the label has non-blank text and the span has positive length."""
        return self.label.strip() != "" and self.end_sec > self.start_sec

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the JSON shape stored in annotation records (times rounded to ms)."""
        start = round(float(self.start_sec), 3)
        end = round(float(self.end_sec), 3)
        cleaned_label = self.label.strip()
        return {"start_sec": start, "end_sec": end, "subtask": cleaned_label}
@lru_cache(maxsize=1)
def load_gold_index() -> list[dict[str, Any]]:
    """Download the gold index parquet once and return its rows as dicts.

    Rows are ordered by (source_dataset, bench_id); missing values sort as "".
    The result is memoized for the lifetime of the process.
    """
    local_file = hf_hub_download(repo_id=GOLD_REPO, repo_type="dataset", filename=GOLD_INDEX)
    index_rows = pq.read_table(local_file).to_pylist()

    def ordering(entry: dict[str, Any]) -> tuple[str, str]:
        return str(entry.get("source_dataset") or ""), str(entry.get("bench_id") or "")

    index_rows.sort(key=ordering)
    return index_rows
@lru_cache(maxsize=1)
def gold_parquet() -> pq.ParquetFile:
    """Open the remote gold parquet over HTTPS once and reuse the handle.

    The fsspec HTTPS filesystem is configured with 1 MiB blocks, so reads
    (presumably ranged requests) fetch only the parts of the file pyarrow
    touches instead of downloading it whole.
    """
    http_fs = fsspec.filesystem("https", block_size=1 << 20)
    remote_handle = http_fs.open(GOLD_PARQUET_URL, "rb")
    return pq.ParquetFile(remote_handle)
@lru_cache(maxsize=1)
def _gold_index_by_id() -> dict[str, dict[str, Any]]:
    """Map bench_id -> index row; built once rather than on every row lookup."""
    return {row["bench_id"]: row for row in load_gold_index()}


@lru_cache(maxsize=512)
def load_gold_row(bench_id: str) -> dict[str, Any]:
    """Fetch the full gold record (including the video bytes) for ``bench_id``.

    The index records which parquet row group holds each episode, so only
    that row group is read from the remote file. Up to 512 records are kept
    in memory by the cache; each one carries its video payload.

    Raises:
        KeyError: if ``bench_id`` is not in the index, or its row group is empty.
    """
    entry = _gold_index_by_id().get(bench_id)
    if entry is None:
        raise KeyError(f"{bench_id} is not in {GOLD_REPO}")
    row_group = int(entry["row_group"])
    rows = gold_parquet().read_row_group(row_group).to_pylist()
    # A row group may hold more than one episode: return the matching record.
    for candidate in rows:
        if candidate.get("bench_id") == bench_id:
            return candidate
    # Keep working with single-episode row groups whose stored bench_id
    # does not match the index exactly.
    if rows:
        return rows[0]
    raise KeyError(f"row group {row_group} for {bench_id} is empty")
def episode_choices() -> list[str]:
    """Build dropdown labels of the form ``"NNNN | source | bench_id | instruction"``.

    The source falls back to the bench_id prefix before the first underscore
    when the index row has no ``source_dataset``.
    """

    def describe(position: int, entry: dict[str, Any]) -> str:
        bench = entry.get("bench_id") or ""
        origin = entry.get("source_dataset") or bench.split("_", 1)[0]
        task = entry.get("instruction") or ""
        return f"{position:04d} | {origin} | {bench} | {task}"

    return [describe(position, entry) for position, entry in enumerate(load_gold_index())]
def bench_id_from_choice(choice: str | None) -> str:
    """Extract the bench_id (third ``|``-separated field) from a dropdown label.

    Falls back to the first episode in the gold index when ``choice`` is
    empty or has fewer than three fields.
    """
    fields = choice.split("|") if choice else []
    if len(fields) >= 3:
        return fields[2].strip()
    return str(load_gold_index()[0]["bench_id"])
def step_episode(choice: str | None, delta: int) -> str:
    """Return the dropdown label ``delta`` positions away from ``choice``.

    An unknown ``choice`` counts as position 0; the result is clamped to the
    first/last episode, and "" is returned when there are no episodes.
    """
    options = episode_choices()
    if not options:
        return ""
    current = options.index(choice) if choice in options else 0
    target = min(max(current + delta, 0), len(options) - 1)
    return options[target]
def write_video(row: dict[str, Any]) -> str:
    """Materialize the episode's ``primary_video`` bytes as an .mp4 in CACHE_DIR.

    The filename combines the bench_id with the first 12 characters of
    ``primary_video_sha256`` (or of the bench_id when no digest is present);
    an existing file with that name is reused rather than rewritten.

    Returns:
        The cached video path as a string.
    """
    digest = str(row.get("primary_video_sha256") or row["bench_id"])
    path = CACHE_DIR / f'{row["bench_id"]}_{digest[:12]}.mp4'
    if path.exists():
        return str(path)
    # Write to a temp file in the same directory and atomically rename it into
    # place: a crash mid-write (or two handlers writing at once) would
    # otherwise leave a truncated .mp4 that the exists() check keeps reusing.
    fd, tmp_name = tempfile.mkstemp(dir=CACHE_DIR, prefix=path.stem, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(bytes(row["primary_video"]))
        os.replace(tmp_name, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; re-raise the real error.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
    return str(path)
def segment_from_mapping(item: dict[str, Any], fps: float | None = None) -> Segment | None:
    """Build a Segment from a gold subtask mapping, or return None.

    The label is taken from ``subtask``, ``label`` or ``text`` (first truthy
    one). Times come from ``start_sec``/``end_sec`` when both are present;
    otherwise, with a truthy ``fps``, from ``start_frame``/``end_frame``
    divided by ``fps``. Non-numeric values yield None. The returned segment
    is not validated.
    """
    text = str(item.get("subtask") or item.get("label") or item.get("text") or "").strip()
    try:
        if "start_sec" in item and "end_sec" in item:
            begin, finish = float(item["start_sec"]), float(item["end_sec"])
        elif fps and "start_frame" in item and "end_frame" in item:
            begin, finish = float(item["start_frame"]) / fps, float(item["end_frame"]) / fps
        else:
            return None
    except (TypeError, ValueError):
        return None
    return Segment(begin, finish, text)
def collapse_frame_annotations(frames: list[dict[str, Any]], fps: float) -> list[Segment]:
    """Collapse per-frame subtask labels into contiguous segments.

    Frames are ordered by ``frame_index`` (missing/falsy -> 0). Each run of
    consecutive entries sharing the same ``subtask`` becomes one Segment from
    the run's first frame to one past its last frame, converted to seconds by
    ``fps``. Invalid segments (blank label or non-positive span) are dropped.
    A run spans any gaps between its frame indices.
    """
    if not frames:
        return []

    def frame_of(entry: dict[str, Any]) -> int:
        return int(entry.get("frame_index") or 0)

    # Each run is [label, first_frame, last_frame].
    runs: list[list[Any]] = []
    for entry in sorted(frames, key=frame_of):
        position = frame_of(entry)
        tag = str(entry.get("subtask") or "")
        if runs and runs[-1][0] == tag:
            runs[-1][2] = position
        else:
            runs.append([tag, position, position])

    collapsed: list[Segment] = []
    for tag, first, last in runs:
        candidate = Segment(first / fps, (last + 1) / fps, tag)
        if candidate.valid():
            collapsed.append(candidate)
    return collapsed
def gold_segments(row: dict[str, Any]) -> list[Segment]:
    """Return the existing gold segments for an episode row.

    Valid entries of ``subtasks`` are preferred; when there are none, the
    per-frame ``frame_annotations`` are collapsed instead. ``fps`` defaults
    to 10.0 when absent or zero.
    """
    rate = float(row.get("fps") or 10.0)
    candidates = (segment_from_mapping(entry, fps=rate) for entry in row.get("subtasks") or [])
    explicit = [seg for seg in candidates if seg and seg.valid()]
    if explicit:
        return explicit
    return collapse_frame_annotations(row.get("frame_annotations") or [], fps=rate)
def table_from_segments(segments: list[Segment]) -> list[list[Any]]:
    """Convert segments into ``[idx, start, end, duration, label]`` table rows.

    Segments are ordered by (start_sec, end_sec); idx is the row position and
    times are rounded to 3 decimals.
    """
    ordered = sorted(segments, key=lambda seg: (seg.start_sec, seg.end_sec))
    table: list[list[Any]] = []
    for position, seg in enumerate(ordered):
        span = seg.end_sec - seg.start_sec
        table.append([position, round(seg.start_sec, 3), round(seg.end_sec, 3), round(span, 3), seg.label])
    return table
def normalize_table(table: Any) -> list[list[Any]]:
    """Coerce a Dataframe value into clean ``[idx, start, end, duration, label]`` rows.

    Accepts ``None``, a ``{"data": [...]}`` dict, a pandas-style object with
    ``.values.tolist()``, an array-like with ``.tolist()``, or a plain list of
    rows. Rows that are missing/too short, have non-numeric start/end, a
    non-positive span, or a blank label are dropped.

    The idx column is the row's position in the returned list, so it matches
    the positional indexing used when editing or deleting segments.
    """
    if table is None:
        return []
    # Test for dict first: dicts also expose a ``values`` attribute (a method),
    # so the DataFrame branch would otherwise swallow them and crash.
    if isinstance(table, dict):
        rows = table.get("data") or []
    elif hasattr(table, "values") and hasattr(table.values, "tolist"):
        rows = table.values.tolist()
    elif hasattr(table, "tolist"):
        # e.g. a numpy array: plain lists avoid ambiguous array truthiness below.
        rows = table.tolist()
    else:
        rows = table
    normalized: list[list[Any]] = []
    for row in rows or []:
        if row is None or len(row) < 5:
            continue
        try:
            start = float(row[1])
            end = float(row[2])
        except (TypeError, ValueError):
            continue
        label = str(row[4] or "").strip()
        if end > start and label:
            normalized.append([len(normalized), start, end, round(end - start, 3), label])
    return normalized
def segments_from_table(table: Any) -> list[Segment]:
    """Parse a Dataframe value into Segments, keeping only rows that normalize cleanly."""
    parsed: list[Segment] = []
    for _, begin, finish, _, text in normalize_table(table):
        parsed.append(Segment(float(begin), float(finish), str(text)))
    return parsed
def render_timeline(table: Any, duration_sec: float) -> str:
    """Render the segment table as an HTML bar for the timeline component.

    Each valid row becomes an absolutely positioned block whose left offset
    and width are percentages of the episode duration (at least 1 second).
    Times are clamped to [0, duration], and every block is at least 0.8%
    wide so very short segments stay visible. Colors cycle through a fixed
    palette, and labels are HTML-escaped before insertion.
    """
    rows = normalize_table(table)
    duration = max(float(duration_sec or 1.0), 1.0)
    colors = ["#bfdbfe", "#bbf7d0", "#fde68a", "#fecaca", "#ddd6fe", "#bae6fd"]
    parts = [
        '<div style="width: 100%; font-family: sans-serif; font-size: 12px;">',
        '<div style="position: relative; height: 48px; background: #f3f4f6; '
        'border: 1px solid #d1d5db; border-radius: 4px; overflow: hidden;">',
    ]
    for idx, row in enumerate(rows):
        start = max(0.0, min(float(row[1]), duration))
        end = max(start, min(float(row[2]), duration))
        left = 100 * start / duration
        width = max(0.8, 100 * (end - start) / duration)
        label = html_escape(str(row[4]))
        color = colors[idx % len(colors)]
        parts.append(
            f'<div title="{idx}: {label}" style="position: absolute; top: 4px; bottom: 4px; '
            f"left: {left:.3f}%; width: {width:.3f}%; background: {color}; "
            "border: 1px solid #6b7280; border-radius: 3px; padding: 2px 4px; "
            'box-sizing: border-box; overflow: hidden; white-space: nowrap; text-overflow: ellipsis;">'
            f"{idx}: {label}"
            "</div>"
        )
    parts.append("</div></div>")
    return "".join(parts)
# Character -> HTML entity map; applied in one pass so the "&" produced by a
# replacement is never escaped a second time.
_HTML_ESCAPE_TABLE = str.maketrans(
    {
        "&": "&amp;",
        "<": "&lt;",
        ">": "&gt;",
        '"': "&quot;",
        "'": "&#39;",
    }
)


def html_escape(value: str) -> str:
    """Escape ``value`` for HTML text content and quoted attribute values.

    Replaces &, <, >, " and ' with character references so user-entered
    subtask labels cannot inject markup into the rendered timeline.
    """
    return value.translate(_HTML_ESCAPE_TABLE)
def render_frame(video_path: str, frame_index: int) -> np.ndarray | None:
    """Decode frame ``frame_index`` of ``video_path`` and return it as an RGB array.

    Returns None when the video cannot be opened or the frame cannot be read
    (presumably e.g. an index past the end of the video).
    """
    capture = cv2.VideoCapture(video_path)
    try:
        if not capture.isOpened():
            return None
        capture.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
        ok, frame = capture.read()
    finally:
        # Always free the decoder, even if seeking or reading raises.
        capture.release()
    if not ok or frame is None:
        return None
    # OpenCV decodes frames in BGR channel order; convert for RGB display.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
def active_segment(table: Any, time_sec: float) -> dict[str, Any]:
    """Describe the first segment whose half-open span [start, end) contains ``time_sec``.

    When no segment matches, returns a dict with ``segment_index`` None, the
    queried ``time_sec`` and an empty ``subtask``.
    """
    hit = next(
        (row for row in normalize_table(table) if float(row[1]) <= time_sec < float(row[2])),
        None,
    )
    if hit is None:
        return {"segment_index": None, "time_sec": time_sec, "subtask": ""}
    return {
        "segment_index": int(hit[0]),
        "start_sec": float(hit[1]),
        "end_sec": float(hit[2]),
        "subtask": str(hit[4]),
    }
def load_episode(choice: str | None) -> tuple[Any, ...]:
    """Load the selected episode and produce values for every episode output.

    Returns, in order: the raw gold row, cached video path, JSON metadata
    summary, segment table, timeline HTML, a frame slider reset to 0, the
    first frame image, the segment active at t=0, and the segment editor
    fields (index 0 plus the first segment's start/end/label, or 0.0/1.0/""
    when there are none), followed by a status message.
    """
    bench_id = bench_id_from_choice(choice)
    row = load_gold_row(bench_id)
    video_path = write_video(row)

    fps = float(row.get("fps") or 10.0)
    num_frames = int(row.get("num_frames") or 1)
    if fps:
        duration = num_frames / fps
    else:
        duration = float(row.get("duration_sec") or 1.0)

    segment_rows = table_from_segments(gold_segments(row))
    metadata = {
        "bench_id": bench_id,
        "source_dataset": row.get("source_dataset"),
        "instruction": row.get("instruction"),
        "fps": fps,
        "num_frames": num_frames,
        "duration_sec": duration,
        "old_segments": len(segment_rows),
    }
    if segment_rows:
        first = segment_rows[0]
        editor_start, editor_end, editor_label = first[1], first[2], first[4]
    else:
        editor_start, editor_end, editor_label = 0.0, 1.0, ""
    last_frame = max(0, num_frames - 1)

    return (
        row,
        video_path,
        json.dumps(metadata, indent=2),
        segment_rows,
        render_timeline(segment_rows, duration),
        gr.Slider(value=0, minimum=0, maximum=last_frame, step=1),
        render_frame(video_path, 0),
        active_segment(segment_rows, 0.0),
        0,
        editor_start,
        editor_end,
        editor_label,
        f"Loaded {bench_id}",
    )
def refresh_frame(row: dict[str, Any], table: Any, frame_index: int) -> tuple[Any, dict[str, Any]]:
    """Return the image for ``frame_index`` (clamped to the episode) and its active segment.

    With no episode loaded, returns ``(None, {})``.
    """
    if not row:
        return None, {}
    path = write_video(row)
    rate = float(row.get("fps") or 10.0)
    last = int(row.get("num_frames") or 1) - 1
    frame = max(0, min(int(frame_index or 0), last))
    seconds = frame / rate if rate else 0.0
    return render_frame(path, frame), active_segment(table, seconds)
def step_frame(row: dict[str, Any], table: Any, frame_index: int, delta: int) -> tuple[int, Any, dict[str, Any]]:
    """Move ``delta`` frames from ``frame_index`` (clamped) and return (frame, image, active segment).

    With no episode loaded, returns ``(0, None, {})``.
    """
    if not row:
        return 0, None, {}
    upper = max(0, int(row.get("num_frames") or 1) - 1)
    target = min(max(int(frame_index or 0) + delta, 0), upper)
    picture, current = refresh_frame(row, table, target)
    return target, picture, current
def time_from_frame(row: dict[str, Any], frame_index: int) -> float:
    """Convert a frame index to seconds (rounded to ms) using the row's fps (default 10.0)."""
    fps = float((row or {}).get("fps") or 10.0)
    if not fps:
        return 0.0
    return round(float(frame_index or 0) / fps, 3)
def set_start(row: dict[str, Any], frame_index: int, end_sec: float, label: str) -> tuple[float, float, str]:
    """Set the editor's start time to the current frame, keeping end and label."""
    new_start = time_from_frame(row, frame_index)
    kept_end = float(end_sec or 0.0)
    return new_start, kept_end, label
def set_end(row: dict[str, Any], frame_index: int, start_sec: float, label: str) -> tuple[float, float, str]:
    """Set the editor's end time to the current frame, keeping start and label."""
    kept_start = float(start_sec or 0.0)
    new_end = time_from_frame(row, frame_index)
    return kept_start, new_end, label
def load_segment_to_editor(table: Any, segment_index: int) -> tuple[float, float, str, str]:
    """Copy row ``segment_index`` of the table into the editor fields.

    Returns (start, end, label, status message). An out-of-range index yields
    the defaults 0.0/1.0/"" and an error message.
    """
    rows = normalize_table(table)
    position = int(segment_index or 0)
    if not 0 <= position < len(rows):
        return 0.0, 1.0, "", f"Segment {position} is out of range."
    _, begin, finish, _, text = rows[position]
    return float(begin), float(finish), str(text), f"Loaded segment {position}."
def upsert_segment(row_state: dict[str, Any], table: Any, segment_index: int, start_sec: float, end_sec: float, label: str) -> tuple[Any, str, str]:
    """Replace segment ``segment_index`` or append a new one, then re-sort the table.

    An index outside the current rows (or None) appends. Invalid input
    (non-numeric times, blank label, end <= start) leaves the table unchanged
    and reports why. Returns (table, timeline HTML, status message).
    """
    rows = normalize_table(table)
    target = int(len(rows) if segment_index is None else segment_index)
    text = str(label or "").strip()

    def reply(message: str) -> tuple[Any, str, str]:
        return rows, timeline_for_state(row_state, rows), message

    try:
        start = float(start_sec)
        end = float(end_sec)
    except (TypeError, ValueError):
        return reply("Start/end must be numbers.")
    if not text:
        return reply("Subtask label is required.")
    if end <= start:
        return reply("end_sec must be greater than start_sec.")

    entry = [target, start, end, round(end - start, 3), text]
    if 0 <= target < len(rows):
        rows[target] = entry
        outcome = f"Updated segment {target}."
    else:
        rows.append(entry)
        outcome = "Added segment."
    resorted = table_from_segments(segments_from_table(rows))
    return resorted, timeline_for_state(row_state, resorted), outcome
def delete_segment(row_state: dict[str, Any], table: Any, segment_index: int) -> tuple[Any, str, str]:
    """Remove segment ``segment_index`` and return (re-sorted table, timeline HTML, status)."""
    rows = normalize_table(table)
    position = int(segment_index or 0)
    if position < 0 or position >= len(rows):
        return rows, timeline_for_state(row_state, rows), f"Segment {position} is out of range."
    remaining = rows[:position] + rows[position + 1:]
    refreshed = table_from_segments(segments_from_table(remaining))
    return refreshed, timeline_for_state(row_state, refreshed), f"Deleted segment {position}."
def sort_segments(row_state: dict[str, Any], table: Any) -> tuple[Any, str, str]:
    """Re-sort and re-index the table; return (table, timeline HTML, status)."""
    ordered = table_from_segments(segments_from_table(table))
    timeline_html = timeline_for_state(row_state, ordered)
    return ordered, timeline_html, "Sorted segments."
def timeline_for_state(row_state: dict[str, Any], table: Any) -> str:
    """Render the timeline for the loaded episode; "" when nothing is loaded."""
    if not row_state:
        return ""
    rate = float(row_state.get("fps") or 10.0)
    seconds = int(row_state.get("num_frames") or 1) / rate if rate else 1.0
    return render_timeline(table, seconds)
def save_episode(row_state: dict[str, Any], table: Any, annotator: str, notes: str) -> tuple[str, str]:
    """Append one JSON line recording the edited segments for the loaded episode.

    The record holds a save timestamp, the annotator, episode identifiers,
    the original gold subtasks and the edited ones, plus free-text notes.
    Returns (status message, annotation file path).
    """
    if not row_state:
        return "No episode loaded.", str(ANNOTATION_PATH)
    new_subtasks = [seg.to_dict() for seg in segments_from_table(table)]
    record = {
        "saved_at_unix": time.time(),
        "annotator": str(annotator or "").strip(),
        "bench_id": row_state.get("bench_id"),
        "source_dataset": row_state.get("source_dataset"),
        "instruction": row_state.get("instruction"),
        "old_subtasks": [seg.to_dict() for seg in gold_segments(row_state)],
        "new_subtasks": new_subtasks,
        "notes": str(notes or "").strip(),
    }
    ANNOTATION_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(ANNOTATION_PATH, "a", encoding="utf-8") as sink:
        sink.write(json.dumps(record, sort_keys=True) + "\n")
    return f"Saved {row_state.get('bench_id')} with {len(new_subtasks)} segments.", str(ANNOTATION_PATH)
def download_annotations() -> str:
    """Ensure the annotation JSONL exists (possibly empty) and return its path for download."""
    target = ANNOTATION_PATH
    target.parent.mkdir(parents=True, exist_ok=True)
    target.touch(exist_ok=True)
    return str(target)
# Gradio UI. Component creation order inside the Blocks/Row/Column contexts
# determines the page layout; the event wiring at the end connects the
# callbacks defined above to these components.
with gr.Blocks(title="Wasup Gold Annotator") as demo:
    gr.Markdown("# Wasup Gold Annotator")
    gr.Markdown("Edit temporal subtask segments. Existing gold is loaded as the starting point; saves append revised annotations to JSONL.")
    # Per-session copy of the loaded gold row; filled by load_episode.
    row_state = gr.State({})
    # Constant -1/+1 inputs so the same step_* callbacks serve both directions.
    minus_one = gr.State(-1)
    plus_one = gr.State(1)
    choices_box = gr.Dropdown(label="Episode", choices=episode_choices(), value=episode_choices()[0] if episode_choices() else None)
    with gr.Row():
        prev_episode = gr.Button("Previous episode")
        load_episode_btn = gr.Button("Load episode", variant="primary")
        next_episode = gr.Button("Next episode")
    with gr.Row():
        video = gr.Video(label="Episode video")
        with gr.Column():
            summary = gr.Code(label="Episode metadata", language="json")
            status = gr.Textbox(label="Status", interactive=False)
    timeline = gr.HTML(label="Timeline")
    # Columns mirror the [idx, start, end, duration, label] rows used throughout.
    segments = gr.Dataframe(
        label="Editable segments",
        headers=["idx", "start_sec", "end_sec", "duration", "subtask"],
        datatype=["number", "number", "number", "number", "str"],
        interactive=True,
    )
    with gr.Row():
        frame_prev = gr.Button("Previous frame")
        frame_next = gr.Button("Next frame")
    # Range is replaced per episode by the gr.Slider returned from load_episode.
    frame_slider = gr.Slider(label="Frame", minimum=0, maximum=1, value=0, step=1)
    with gr.Row():
        frame_image = gr.Image(label="Selected frame", interactive=False)
        active = gr.JSON(label="Current segment")
    # Segment editor: an idx inside the table updates that row; any other idx
    # appends a new segment (see upsert_segment).
    with gr.Row():
        segment_index = gr.Number(label="Segment idx", value=0, precision=0)
        start_sec = gr.Number(label="start_sec", value=0.0)
        end_sec = gr.Number(label="end_sec", value=1.0)
        subtask = gr.Textbox(label="subtask")
    with gr.Row():
        load_segment = gr.Button("Load selected segment")
        set_start_btn = gr.Button("Set start = current frame")
        set_end_btn = gr.Button("Set end = current frame")
    with gr.Row():
        upsert_btn = gr.Button("Add/update segment", variant="primary")
        delete_btn = gr.Button("Delete segment")
        sort_btn = gr.Button("Sort segments")
    with gr.Row():
        annotator = gr.Textbox(label="Annotator")
        notes = gr.Textbox(label="Notes")
    with gr.Row():
        save_btn = gr.Button("Save episode annotation", variant="primary")
        download_btn = gr.Button("Prepare JSONL download")
    download_file = gr.File(label="Annotations JSONL")
    # Order must match the 13-element tuple returned by load_episode.
    episode_outputs = [
        row_state,
        video,
        summary,
        segments,
        timeline,
        frame_slider,
        frame_image,
        active,
        segment_index,
        start_sec,
        end_sec,
        subtask,
        status,
    ]
    load_episode_btn.click(load_episode, inputs=choices_box, outputs=episode_outputs)
    choices_box.change(load_episode, inputs=choices_box, outputs=episode_outputs)
    # NOTE(review): step_episode writes a new value into choices_box, and
    # Dropdown.change also fires on function updates, so the chained
    # .then(load_episode) likely loads the episode a second time — confirm
    # and drop one of the two triggers.
    prev_episode.click(step_episode, inputs=[choices_box, minus_one], outputs=choices_box).then(load_episode, inputs=choices_box, outputs=episode_outputs)
    next_episode.click(step_episode, inputs=[choices_box, plus_one], outputs=choices_box).then(load_episode, inputs=choices_box, outputs=episode_outputs)
    frame_slider.change(refresh_frame, inputs=[row_state, segments, frame_slider], outputs=[frame_image, active])
    frame_prev.click(step_frame, inputs=[row_state, segments, frame_slider, minus_one], outputs=[frame_slider, frame_image, active])
    frame_next.click(step_frame, inputs=[row_state, segments, frame_slider, plus_one], outputs=[frame_slider, frame_image, active])
    load_segment.click(load_segment_to_editor, inputs=[segments, segment_index], outputs=[start_sec, end_sec, subtask, status])
    set_start_btn.click(set_start, inputs=[row_state, frame_slider, end_sec, subtask], outputs=[start_sec, end_sec, subtask])
    set_end_btn.click(set_end, inputs=[row_state, frame_slider, start_sec, subtask], outputs=[start_sec, end_sec, subtask])
    upsert_btn.click(upsert_segment, inputs=[row_state, segments, segment_index, start_sec, end_sec, subtask], outputs=[segments, timeline, status])
    delete_btn.click(delete_segment, inputs=[row_state, segments, segment_index], outputs=[segments, timeline, status])
    sort_btn.click(sort_segments, inputs=[row_state, segments], outputs=[segments, timeline, status])
    save_btn.click(save_episode, inputs=[row_state, segments, annotator, notes], outputs=[status, download_file])
    download_btn.click(download_annotations, outputs=download_file)
    # Load the initially selected episode when the page opens.
    demo.load(load_episode, inputs=choices_box, outputs=episode_outputs)
if __name__ == "__main__":
    demo.launch()