import os
import threading
import queue
import time
import tempfile
import shutil
from contextlib import redirect_stdout, redirect_stderr
from typing import List
import gradio as gr
from huggingface_hub import HfApi
from delete_episodes import (
    download_dataset,
    list_episodes,
    delete_episodes_and_repair,
    upload_dataset,
)
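
# _StreamToQueue is a minimal line-buffered file-like object: write() splits
# incoming text on newlines and pushes complete lines onto a queue, so prints
# from a background thread can be streamed to the Gradio UI as they happen.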
class _StreamToQueue:
    def __init__(self, q: "queue.Queue[str]"):
        self.q = q
        self._buffer = ""

    def write(self, s: str):
        if not isinstance(s, str):
            s = str(s)
        self._buffer += s
        while "\n" in self._buffer:
            line, self._buffer = self._buffer.split("\n", 1)
            self.q.put(line + "\n")

    def flush(self):
        if self._buffer:
            self.q.put(self._buffer)
            self._buffer = ""
def search_datasets_fn(query: str) -> List[str]:
    """Search for datasets on HuggingFace"""
    api = HfApi()
    try:
        items = api.list_datasets(search=(query or "").strip() or None)
        repo_ids = [getattr(d, "id", None) or getattr(d, "repo_id", None) for d in items]
        repo_ids = [r for r in repo_ids if r]
        # Remove duplicates while preserving order
        seen = set()
        unique = []
        for r in repo_ids:
            if r not in seen:
                unique.append(r)
                seen.add(r)
        return unique[:500]
    except Exception as e:
        print(f"Error searching datasets: {e}")
        return []
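
# Note: this downloads a full snapshot of the dataset into a temp directory
# just to enumerate the episodes, then deletes the copy again.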
def load_episodes_for_dataset(repo_id: str, progress=gr.Progress()):
    """Download dataset and list available episodes"""
    if not repo_id:
        return ""
    token = os.environ.get("HF_TOKEN")
    temp_dir = tempfile.mkdtemp(prefix="episode_delete_")
    try:
        progress(0, desc="Downloading dataset...")
        download_dataset(repo_id, temp_dir, hf_token=token)
        progress(0.7, desc="Listing episodes...")
        episodes = list_episodes(temp_dir)
        if not episodes:
            return "No episodes found"
        # Return info about available episodes
        return f"Found {len(episodes)} episodes: {', '.join(map(str, episodes))}"
    except Exception as e:
        import traceback

        error_msg = f"Error: {e}\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
    finally:
        # Clean up the temp directory on both success and failure
        shutil.rmtree(temp_dir, ignore_errors=True)
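
# Streaming pattern: run the pipeline in a background thread with stdout/stderr
# redirected into a queue; the generator below drains the queue and re-yields
# the growing log buffer so Gradio can update the progress textbox live.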
def delete_episodes_stream(repo_id: str, episode_indexes_str: str, dest_repo_id: str):
    """Delete selected episodes and upload to destination repo"""
    if not repo_id:
        yield "Please provide a source dataset repo ID."
        return
    if not episode_indexes_str or not episode_indexes_str.strip():
        yield "Please provide at least one episode index to delete."
        return
    if not dest_repo_id or not dest_repo_id.strip():
        yield "Please provide a destination repo ID."
        return

    # Parse comma-separated episode indexes
    episode_indexes = []
    for ep_str in episode_indexes_str.split(","):
        try:
            ep_num = int(ep_str.strip())
            episode_indexes.append(ep_num)
        except ValueError:
            yield f"Invalid episode index: {ep_str.strip()}"
            return

    token = os.environ.get("HF_TOKEN")
    q: "queue.Queue[str]" = queue.Queue()
    done = {"ok": False, "msg": ""}
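
    # Background worker: capture everything printed by the download/delete/
    # upload steps via _StreamToQueue and record the final status in `done`.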
    def _worker():
        stream = _StreamToQueue(q)
        temp_dir = tempfile.mkdtemp(prefix="episode_delete_")
        try:
            with redirect_stdout(stream), redirect_stderr(stream):
                print("Downloading dataset...", flush=True)
                download_dataset(repo_id, temp_dir, hf_token=token)

                print(f"\nDeleting episodes: {episode_indexes}", flush=True)
                delete_episodes_and_repair(
                    dataset_path=temp_dir,
                    episode_indexes=episode_indexes,
                    run_stats=False,  # Skip stats for now as script may not be available
                )

                print(f"\nUploading to {dest_repo_id}...", flush=True)
                upload_dataset(
                    local_dir=temp_dir,
                    dest_repo_id=dest_repo_id,
                    hf_token=token,
                    commit_message=f"Deleted episodes: {episode_indexes}",
                    private=False,
                )
                print("\nUpload complete!", flush=True)
            done["ok"] = True
            done["msg"] = f"Successfully deleted {len(episode_indexes)} episodes and uploaded to {dest_repo_id}"
        except Exception as e:
            print(f"\nError: {e}", flush=True)
            done["ok"] = False
            done["msg"] = f"Error: {e}"
        finally:
            # Cleanup
            try:
                if os.path.isdir(temp_dir):
                    shutil.rmtree(temp_dir, ignore_errors=True)
                    print(f"\nCleaned up temp directory: {temp_dir}", flush=True)
            except Exception:
                pass
            try:
                stream.flush()
            except Exception:
                pass
    t = threading.Thread(target=_worker, daemon=True)
    t.start()

    buffer = ""
    yield "Starting process...\n"
    while t.is_alive() or not q.empty():
        try:
            line = q.get(timeout=0.1)
            buffer += line
            yield buffer
        except queue.Empty:
            pass
        time.sleep(0.05)

    # Final status
    if done["msg"]:
        buffer += ("\n" if not buffer.endswith("\n") else "") + "=" * 50 + "\n" + done["msg"]
        yield buffer
# Build the Gradio interface
with gr.Blocks(title="LeRobot Episode Deleter") as demo:
    gr.Markdown("**Delete specific episodes from a Hugging Face dataset (LeRobot format).**")

    # Load initial datasets
    _initial_choices = search_datasets_fn("griffinlabs-cortex")

    with gr.Row():
        org_input = gr.Textbox(
            label="Organization or keyword",
            value="griffinlabs-cortex",
            placeholder="e.g., lerobot, griffinlabs-cortex",
        )
        load_btn = gr.Button("Load Datasets")

    dataset_dropdown = gr.Dropdown(
        label="Select dataset",
        choices=_initial_choices,
        interactive=True,
    )
    episodes_info = gr.Textbox(
        label="Available episodes",
        interactive=False,
        lines=2,
    )
    episodes_input = gr.Textbox(
        label="Episode indexes to delete (comma-separated)",
        placeholder="0, 1, 2",
    )
    dest_repo_input = gr.Textbox(
        label="Destination repo id (required)",
        placeholder="org/cleaned_dataset",
    )
    execute_btn = gr.Button("Delete Episodes and Upload")
    progress_log = gr.Textbox(label="Progress log", lines=20)
    # Event handlers
    def load_datasets_from_org(org_name):
        results = search_datasets_fn(org_name)
        return gr.update(choices=results, value=None)

    load_btn.click(
        load_datasets_from_org,
        inputs=org_input,
        outputs=dataset_dropdown,
    )
    dataset_dropdown.change(
        load_episodes_for_dataset,
        inputs=dataset_dropdown,
        outputs=episodes_info,
    )
    execute_btn.click(
        delete_episodes_stream,
        inputs=[dataset_dropdown, episodes_input, dest_repo_input],
        outputs=progress_log,
    )

if __name__ == "__main__":
    demo.launch()
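
# A sketch of running this locally, assuming this file is saved as app.py,
# delete_episodes.py sits next to it, and HF_TOKEN holds a token with write
# access to the destination repo:
#   HF_TOKEN=hf_xxx python app.py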