|
|
import os |
|
|
import threading |
|
|
import queue |
|
|
import time |
|
|
import tempfile |
|
|
import shutil |
|
|
from contextlib import redirect_stdout, redirect_stderr |
|
|
from typing import List |
|
|
|
|
|
import gradio as gr |
|
|
from huggingface_hub import HfApi |
|
|
|
|
|
from delete_episodes import ( |
|
|
download_dataset, |
|
|
list_episodes, |
|
|
delete_episodes_and_repair, |
|
|
upload_dataset, |
|
|
) |
|
|
|
|
|
|
|
|
class _StreamToQueue: |
|
|
def __init__(self, q: "queue.Queue[str]"): |
|
|
self.q = q |
|
|
self._buffer = "" |
|
|
|
|
|
def write(self, s: str): |
|
|
if not isinstance(s, str): |
|
|
s = str(s) |
|
|
self._buffer += s |
|
|
while "\n" in self._buffer: |
|
|
line, self._buffer = self._buffer.split("\n", 1) |
|
|
self.q.put(line + "\n") |
|
|
|
|
|
def flush(self): |
|
|
if self._buffer: |
|
|
self.q.put(self._buffer) |
|
|
self._buffer = "" |
|
|
|
|
|
|
|
|
def search_datasets_fn(query: str, limit: int = 500) -> List[str]:
    """Search the Hugging Face Hub for datasets matching *query*.

    Args:
        query: Free-text search string (e.g. an organization name). A blank
            or ``None`` query searches without a filter.
        limit: Maximum number of repo ids to return (default 500).

    Returns:
        Up to *limit* unique dataset repo ids, in the order the Hub returned
        them. On any error the exception is printed and an empty list is
        returned.
    """
    api = HfApi()
    try:
        if limit <= 0:
            return []
        items = api.list_datasets(search=(query or "").strip() or None)
        # A dict keeps insertion order, so it serves as an ordered set.
        unique = {}
        for item in items:
            repo_id = getattr(item, "id", None) or getattr(item, "repo_id", None)
            if not repo_id:
                continue
            unique[repo_id] = None
            # Stop once we have enough instead of consuming the whole listing;
            # for a blank/broad query that listing can cover the entire Hub.
            if len(unique) >= limit:
                break
        return list(unique)
    except Exception as e:
        print(f"Error searching datasets: {e}")
        return []
|
|
|
|
|
|
|
|
def load_episodes_for_dataset(repo_id: str, progress=gr.Progress()):
    """Download *repo_id* to a temporary directory and summarize its episodes.

    Args:
        repo_id: Dataset repo id selected in the UI. Empty/None is a no-op.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``""`` when *repo_id* is empty; ``"No episodes found"`` when
        ``list_episodes`` returns nothing; otherwise a ``"Found N episodes: ..."``
        summary. On any exception the error text plus traceback is printed and
        returned instead. The temporary directory is always removed.
    """
    if not repo_id:
        return ""

    token = os.environ.get("HF_TOKEN")
    temp_dir = tempfile.mkdtemp(prefix="episode_delete_")

    try:
        progress(0, desc="Downloading dataset...")
        download_dataset(repo_id, temp_dir, hf_token=token)

        progress(0.7, desc="Listing episodes...")
        episodes = list_episodes(temp_dir)

        if not episodes:
            return "No episodes found"
        return f"Found {len(episodes)} episodes: {', '.join(map(str, episodes))}"
    except Exception as e:
        import traceback

        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
    finally:
        # Single cleanup point: runs on success, on handled errors, and on
        # exits that bypass `except Exception`.
        shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
|
|
|
|
def _parse_episode_indexes(text: str) -> List[int]:
    """Parse a comma-separated list of non-negative episode indexes.

    Blank entries (e.g. from a trailing comma) are skipped and duplicates are
    dropped, keeping first-seen order.

    Raises:
        ValueError: with the offending stripped token as its only argument
            when an entry is not a non-negative integer.
    """
    parsed: List[int] = []
    for raw in text.split(","):
        token = raw.strip()
        if not token:
            continue
        try:
            value = int(token)
        except ValueError:
            raise ValueError(token) from None
        if value < 0:
            raise ValueError(token)
        parsed.append(value)
    return list(dict.fromkeys(parsed))


def delete_episodes_stream(repo_id: str, episode_indexes_str: str, dest_repo_id: str):
    """Delete selected episodes from *repo_id* and upload the result to *dest_repo_id*.

    Generator used as a Gradio streaming handler. The download / delete /
    upload work runs in a background thread whose stdout/stderr are captured
    and streamed back.

    Args:
        repo_id: Source dataset repo id.
        episode_indexes_str: Comma-separated non-negative episode indexes;
            blank entries and duplicates are ignored.
        dest_repo_id: Destination repo id; the result is uploaded with
            ``private=False``.

    Yields:
        A single message if validation fails. Otherwise ``"Starting process...\\n"``,
        then the accumulated worker log each time new output arrives, and
        finally that log with a separator and a summary line appended.
    """
    repo_id = (repo_id or "").strip()
    if not repo_id:
        yield "Please provide a source dataset repo ID."
        return

    if not episode_indexes_str or not episode_indexes_str.strip():
        yield "Please provide at least one episode index to delete."
        return

    dest_repo_id = (dest_repo_id or "").strip()
    if not dest_repo_id:
        yield "Please provide a destination repo ID."
        return

    try:
        episode_indexes = _parse_episode_indexes(episode_indexes_str)
    except ValueError as err:
        yield f"Invalid episode index: {err.args[0]}"
        return
    if not episode_indexes:
        # Input consisted only of separators/blanks, e.g. " , ,".
        yield "Please provide at least one episode index to delete."
        return

    token = os.environ.get("HF_TOKEN")
    q: "queue.Queue[str]" = queue.Queue()
    # Summary line written by the worker and appended after it finishes.
    done = {"msg": ""}

    def _worker():
        try:
            stream = _StreamToQueue(q)
            temp_dir = tempfile.mkdtemp(prefix="episode_delete_")
        except Exception as e:
            # Without this the thread would die silently and the UI would
            # never show why nothing happened.
            done["msg"] = f"Error: {e}"
            return

        # NOTE: redirect_stdout/redirect_stderr swap the process-wide
        # sys.stdout/sys.stderr, so output from other threads printed while
        # this runs also lands in this log.
        # The except/finally sit inside the redirect so their messages reach the UI log.
        with redirect_stdout(stream), redirect_stderr(stream):
            try:
                print("Downloading dataset...", flush=True)
                download_dataset(repo_id, temp_dir, hf_token=token)

                print(f"\nDeleting episodes: {episode_indexes}", flush=True)
                delete_episodes_and_repair(
                    dataset_path=temp_dir,
                    episode_indexes=episode_indexes,
                    run_stats=False,
                )

                print(f"\nUploading to {dest_repo_id}...", flush=True)
                upload_dataset(
                    local_dir=temp_dir,
                    dest_repo_id=dest_repo_id,
                    hf_token=token,
                    commit_message=f"Deleted episodes: {episode_indexes}",
                    private=False,
                )

                print("\nUpload complete!", flush=True)
                done["msg"] = f"Successfully deleted {len(episode_indexes)} episodes and uploaded to {dest_repo_id}"
            except Exception as e:
                print(f"\nError: {e}", flush=True)
                done["msg"] = f"Error: {e}"
            finally:
                if os.path.isdir(temp_dir):
                    shutil.rmtree(temp_dir, ignore_errors=True)
                    print(f"\nCleaned up temp directory: {temp_dir}", flush=True)
                # Push out any trailing partial line (e.g. a progress bar).
                stream.flush()

    worker = threading.Thread(target=_worker, daemon=True)
    worker.start()

    yield "Starting process...\n"

    log = ""
    # Keep polling until the worker has exited AND everything it queued is consumed.
    while worker.is_alive() or not q.empty():
        pieces = []
        try:
            # Block briefly for the first item, then take whatever else is
            # already queued so bursts of output produce a single update.
            pieces.append(q.get(timeout=0.1))
            while True:
                pieces.append(q.get_nowait())
        except queue.Empty:
            pass
        if pieces:
            log += "".join(pieces)
            yield log
    worker.join()

    if done["msg"]:
        separator = "" if log.endswith("\n") else "\n"
        log += separator + "=" * 50 + "\n" + done["msg"]
        yield log
|
|
|
|
|
|
|
|
|
|
|
# Default Hub search used both to pre-populate the dataset dropdown and as
# the initial value of the search box, so the two stay in sync.
DEFAULT_DATASET_QUERY = "griffinlabs-cortex"

with gr.Blocks(title="LeRobot Episode Deleter") as demo:
    gr.Markdown("**Delete specific episodes from a Hugging Face dataset (LeRobot format).**")

    # NOTE: this runs a Hub search when the module is imported so the
    # dropdown starts populated.
    _initial_choices = search_datasets_fn(DEFAULT_DATASET_QUERY)

    with gr.Row():
        org_input = gr.Textbox(
            label="Organization or keyword",
            value=DEFAULT_DATASET_QUERY,
            placeholder="e.g., lerobot, griffinlabs-cortex",
        )
        load_btn = gr.Button("Load Datasets")

    dataset_dropdown = gr.Dropdown(
        label="Select dataset",
        choices=_initial_choices,
        interactive=True,
    )

    episodes_info = gr.Textbox(
        label="Available episodes",
        interactive=False,
        lines=2,
    )

    episodes_input = gr.Textbox(
        label="Episode indexes to delete (comma-separated)",
        placeholder="0, 1, 2",
    )

    dest_repo_input = gr.Textbox(
        label="Destination repo id (required)",
        placeholder="org/cleaned_dataset",
    )

    execute_btn = gr.Button("Delete Episodes and Upload")

    progress_log = gr.Textbox(label="Progress log", lines=20)

    def load_datasets_from_org(org_name):
        """Re-run the dataset search and replace the dropdown choices, clearing the selection."""
        results = search_datasets_fn(org_name)
        return gr.update(choices=results, value=None)

    load_btn.click(
        load_datasets_from_org,
        inputs=org_input,
        outputs=dataset_dropdown,
    )

    # Selecting a dataset calls load_episodes_for_dataset, which downloads the
    # dataset to list its episodes.
    dataset_dropdown.change(
        load_episodes_for_dataset,
        inputs=dataset_dropdown,
        outputs=episodes_info,
    )

    # delete_episodes_stream is a generator, so the log textbox updates live.
    execute_btn.click(
        delete_episodes_stream,
        inputs=[dataset_dropdown, episodes_input, dest_repo_input],
        outputs=progress_log,
    )


if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|