"""Gradio app for deleting episodes from a LeRobot-format Hugging Face dataset."""

import os
import threading
import queue
import time
import tempfile
import shutil
from contextlib import redirect_stdout, redirect_stderr
from typing import List

import gradio as gr
from huggingface_hub import HfApi

from delete_episodes import (
    download_dataset,
    list_episodes,
    delete_episodes_and_repair,
    upload_dataset,
)


class _StreamToQueue:
    """File-like object that forwards complete lines of output to a queue."""

    def __init__(self, q: "queue.Queue[str]"):
        self.q = q
        self._buffer = ""

    def write(self, s: str):
        if not isinstance(s, str):
            s = str(s)
        self._buffer += s
        # Emit only complete lines; keep any trailing partial line buffered.
        while "\n" in self._buffer:
            line, self._buffer = self._buffer.split("\n", 1)
            self.q.put(line + "\n")

    def flush(self):
        if self._buffer:
            self.q.put(self._buffer)
            self._buffer = ""


def search_datasets_fn(query: str) -> List[str]:
    """Search for datasets on the Hugging Face Hub."""
    api = HfApi()
    try:
        items = api.list_datasets(search=(query or "").strip() or None)
        repo_ids = [getattr(d, "id", None) or getattr(d, "repo_id", None) for d in items]
        repo_ids = [r for r in repo_ids if r]
        # Remove duplicates while preserving order
        seen = set()
        unique = []
        for r in repo_ids:
            if r not in seen:
                unique.append(r)
                seen.add(r)
        return unique[:500]
    except Exception as e:
        print(f"Error searching datasets: {e}")
        return []


def load_episodes_for_dataset(repo_id: str, progress=gr.Progress()):
    """Download a dataset and list its available episodes."""
    if not repo_id:
        return ""
    token = os.environ.get("HF_TOKEN")
    temp_dir = tempfile.mkdtemp(prefix="episode_delete_")
    try:
        progress(0, desc="Downloading dataset...")
        download_dataset(repo_id, temp_dir, hf_token=token)
        progress(0.7, desc="Listing episodes...")
        episodes = list_episodes(temp_dir)
        # Cleanup temp directory
        shutil.rmtree(temp_dir, ignore_errors=True)
        if not episodes:
            return "No episodes found"
        # Return info about available episodes
        return f"Found {len(episodes)} episodes: {', '.join(map(str, episodes))}"
    except Exception as e:
        import traceback

        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        # Cleanup on error
        try:
            if temp_dir and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir, ignore_errors=True)
        except Exception:
            pass
        return error_msg


def delete_episodes_stream(repo_id: str, episode_indexes_str: str, dest_repo_id: str):
    """Delete the selected episodes and upload the result to the destination repo."""
    if not repo_id:
        yield "Please provide a source dataset repo ID."
        return
    if not episode_indexes_str or not episode_indexes_str.strip():
        yield "Please provide at least one episode index to delete."
        return
    if not dest_repo_id or not dest_repo_id.strip():
        yield "Please provide a destination repo ID."
        return

    # Parse comma-separated episode indexes
    episode_indexes = []
    for ep_str in episode_indexes_str.split(","):
        try:
            ep_num = int(ep_str.strip())
            episode_indexes.append(ep_num)
        except ValueError:
            yield f"Invalid episode index: {ep_str.strip()}"
            return

    token = os.environ.get("HF_TOKEN")
    q: "queue.Queue[str]" = queue.Queue()
    done = {"ok": False, "msg": ""}

    def _worker():
        stream = _StreamToQueue(q)
        temp_dir = tempfile.mkdtemp(prefix="episode_delete_")
        try:
            # Redirect all prints from the pipeline into the queue for streaming.
            with redirect_stdout(stream), redirect_stderr(stream):
                print("Downloading dataset...", flush=True)
                download_dataset(repo_id, temp_dir, hf_token=token)

                print(f"\nDeleting episodes: {episode_indexes}", flush=True)
                delete_episodes_and_repair(
                    dataset_path=temp_dir,
                    episode_indexes=episode_indexes,
                    run_stats=False,  # Skip stats for now as the script may not be available
                )

                print(f"\nUploading to {dest_repo_id}...", flush=True)
                upload_dataset(
                    local_dir=temp_dir,
                    dest_repo_id=dest_repo_id,
                    hf_token=token,
                    commit_message=f"Deleted episodes: {episode_indexes}",
                    private=False,
                )
                print("\nUpload complete!", flush=True)

            done["ok"] = True
            done["msg"] = (
                f"Successfully deleted {len(episode_indexes)} episodes "
                f"and uploaded to {dest_repo_id}"
            )
        except Exception as e:
            print(f"\nError: {e}", flush=True)
            done["ok"] = False
            done["msg"] = f"Error: {e}"
        finally:
            # Cleanup
            try:
                if os.path.isdir(temp_dir):
                    shutil.rmtree(temp_dir, ignore_errors=True)
                    print(f"\nCleaned up temp directory: {temp_dir}", flush=True)
            except Exception:
                pass
            try:
                stream.flush()
            except Exception:
                pass

    t = threading.Thread(target=_worker, daemon=True)
    t.start()

    buffer = ""
    yield "Starting process...\n"
    while t.is_alive() or not q.empty():
        try:
            line = q.get(timeout=0.1)
            buffer += line
            yield buffer
        except queue.Empty:
            # Avoid busy-waiting while the worker is quiet.
            time.sleep(0.05)

    # Final status
    if done["msg"]:
        buffer += ("\n" if not buffer.endswith("\n") else "") + "=" * 50 + "\n" + done["msg"]
        yield buffer


# Build the Gradio interface
with gr.Blocks(title="LeRobot Episode Deleter") as demo:
    gr.Markdown("**Delete specific episodes from a Hugging Face dataset (LeRobot format).**")

    # Load initial datasets
    _initial_choices = search_datasets_fn("griffinlabs-cortex")

    with gr.Row():
        org_input = gr.Textbox(
            label="Organization or keyword",
            value="griffinlabs-cortex",
            placeholder="e.g., lerobot, griffinlabs-cortex",
        )
        load_btn = gr.Button("Load Datasets")

    dataset_dropdown = gr.Dropdown(
        label="Select dataset",
        choices=_initial_choices,
        interactive=True,
    )
    episodes_info = gr.Textbox(
        label="Available episodes",
        interactive=False,
        lines=2,
    )
    episodes_input = gr.Textbox(
        label="Episode indexes to delete (comma-separated)",
        placeholder="0, 1, 2",
    )
    dest_repo_input = gr.Textbox(
        label="Destination repo id (required)",
        placeholder="org/cleaned_dataset",
    )
    execute_btn = gr.Button("Delete Episodes and Upload")
    progress_log = gr.Textbox(label="Progress log", lines=20)

    # Event handlers
    def load_datasets_from_org(org_name):
        results = search_datasets_fn(org_name)
        return gr.update(choices=results, value=None)

    load_btn.click(
        load_datasets_from_org,
        inputs=org_input,
        outputs=dataset_dropdown,
    )
    dataset_dropdown.change(
        load_episodes_for_dataset,
        inputs=dataset_dropdown,
        outputs=episodes_info,
    )
    execute_btn.click(
        delete_episodes_stream,
        inputs=[dataset_dropdown, episodes_input, dest_repo_input],
        outputs=progress_log,
    )


if __name__ == "__main__":
    demo.launch()
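
# Usage sketch (assumptions not stated in the source: this file is saved as
# app.py, delete_episodes.py lives in the same directory, and HF_TOKEN holds a
# token with write access to the destination repo):
#
#   export HF_TOKEN=hf_xxx
#   python app.py   # Gradio serves the UI on http://127.0.0.1:7860 by default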