import os import re import sys import glob import json import logging import shutil import subprocess from pathlib import Path from typing import List, Optional, Tuple from huggingface_hub import snapshot_download, upload_folder, create_repo import pandas as pd logger = logging.getLogger(__name__) if not logger.handlers: logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") def _enable_hf_transfer(): """Enable hf_transfer acceleration if the package is installed""" if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" logger.info("Enabled hf_transfer acceleration (HF_HUB_ENABLE_HF_TRANSFER=1)") def download_dataset( repo_id: str, local_dir: str, hf_token: Optional[str] = None, ) -> str: """Download a Hugging Face dataset by repo_id. Returns the local directory path. """ _enable_hf_transfer() local_path = Path(local_dir) local_path.mkdir(parents=True, exist_ok=True) logger.info(f"Downloading dataset '{repo_id}' to '{local_dir}' ...") path = snapshot_download( repo_id=repo_id, repo_type="dataset", token=hf_token, local_dir=str(local_dir), local_dir_use_symlinks=False, ) logger.info(f"Downloaded: {repo_id} -> {path}") return str(local_path) def check_v2_format(dataset_path: str) -> bool: """Check if dataset is in v2.x format""" info_path = os.path.join(dataset_path, "meta", "info.json") if not os.path.exists(info_path): raise ValueError(f"Error: {info_path} does not exist") with open(info_path, "r") as f: try: info = json.load(f) if "codebase_version" not in info: raise ValueError(f"Error: {info_path} is not a valid v2.x dataset") version = info["codebase_version"] # Accept any v2.x version (v2.0, v2.1, etc.) if not version.startswith("v2."): raise ValueError( f"Error: {info_path} is not a v2.x dataset, found {version}" ) logger.info(f"Dataset version: {version}") return True except json.JSONDecodeError: raise ValueError(f"Error: {info_path} is not a valid JSON file") def update_info_counts(dataset_path: str): """Update total_episodes and total_videos counts in info.json to reflect actual counts. Args: dataset_path: Path to the dataset """ info_path = os.path.join(dataset_path, "meta", "info.json") if not os.path.exists(info_path): raise ValueError(f"Error: {info_path} does not exist") logger.info("Updating info.json counts to reflect actual dataset state...") # Count actual episodes episodes = list_episodes(dataset_path) new_episode_count = len(episodes) # Count actual videos videos_folder = os.path.join(dataset_path, "videos", "chunk-000") video_count = 0 if os.path.exists(videos_folder): video_folders = [d for d in os.listdir(videos_folder) if os.path.isdir(os.path.join(videos_folder, d))] for folder in video_folders: video_files = glob.glob( os.path.join(videos_folder, folder, "episode_*.mp4") ) video_count += len(video_files) # Read and update info.json with open(info_path, "r") as f: info = json.load(f) old_episodes = info.get("total_episodes", 0) old_videos = info.get("total_videos", 0) info["total_episodes"] = new_episode_count info["total_videos"] = video_count with open(info_path, "w") as f: json.dump(info, f, indent=4) logger.info( f"Updated total_episodes: {old_episodes} → {new_episode_count}" ) logger.info( f"Updated total_videos: {old_videos} → {video_count}" ) def list_episodes(dataset_path: str) -> List[int]: """List all episode numbers in the dataset""" parquets_folder = os.path.join(dataset_path, "data", "chunk-000") if not os.path.exists(parquets_folder): return [] parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet")) episode_numbers = [] for file in parquet_files: match = re.search(r"episode_(\d+)\.parquet", file) if match: episode_numbers.append(int(match.group(1))) return sorted(episode_numbers) def delete_ds_store(dataset_path: str): """Delete all .DS_Store files in the given dataset path and its subdirectories""" logger.info("Deleting .DS_Store files...") ds_store_files = glob.glob( os.path.join(dataset_path, "**", ".DS_Store"), recursive=True ) if not ds_store_files: logger.info("No .DS_Store files found") return for file in ds_store_files: os.remove(file) logger.info(f"Deleted {file}") logger.info(".DS_Store files deleted") def update_meta_jsonl_files(dataset_path: str, indexes_to_delete: List[int]): """Update episodes.jsonl and episodes_stats.jsonl by removing deleted episodes and re-indexing""" meta_folder = os.path.join(dataset_path, "meta") episodes_file = os.path.join(meta_folder, "episodes.jsonl") episodes_stats_file = os.path.join(meta_folder, "episodes_stats.jsonl") # Process episodes.jsonl if os.path.exists(episodes_file): logger.info("Updating episodes.jsonl...") episodes = [] with open(episodes_file, "r") as f: for line in f: line = line.strip() if line: # Skip empty lines episode = json.loads(line) if episode["episode_index"] not in indexes_to_delete: episodes.append(episode) # Re-index episodes for new_index, episode in enumerate(episodes): episode["episode_index"] = new_index # Write back with open(episodes_file, "w") as f: for episode in episodes: f.write(json.dumps(episode) + "\n") logger.info(f"Updated episodes.jsonl: {len(episodes)} episodes remaining") else: logger.warning(f"episodes.jsonl not found at {episodes_file}") # Process episodes_stats.jsonl if os.path.exists(episodes_stats_file): logger.info("Updating episodes_stats.jsonl...") stats = [] with open(episodes_stats_file, "r") as f: for line in f: line = line.strip() if line: # Skip empty lines stat = json.loads(line) if stat["episode_index"] not in indexes_to_delete: stats.append(stat) # Re-index stats for new_index, stat in enumerate(stats): stat["episode_index"] = new_index # Write back with open(episodes_stats_file, "w") as f: for stat in stats: f.write(json.dumps(stat) + "\n") logger.info(f"Updated episodes_stats.jsonl: {len(stats)} episode stats remaining") else: logger.warning(f"episodes_stats.jsonl not found at {episodes_stats_file}") def delete_episode_files(dataset_path: str, indexes: List[int]): """Delete parquet and video files for specified episode indexes""" parquets_folder = os.path.join(dataset_path, "data", "chunk-000") videos_folder = os.path.join(dataset_path, "videos", "chunk-000") # Delete parquet files logger.info("Deleting parquet files...") parquet_files = glob.glob(os.path.join(parquets_folder, "*.parquet")) for index in indexes: for file in parquet_files: if f"episode_{index:06d}.parquet" in file: os.remove(file) logger.info(f"Deleted file {file}") # Delete video files logger.info("Deleting video files...") if os.path.exists(videos_folder): video_folders = os.listdir(videos_folder) for index in indexes: for folder in video_folders: video_files = glob.glob( os.path.join(videos_folder, folder, f"episode_{index:06d}.mp4") ) for video_file in video_files: os.remove(video_file) logger.info(f"Deleted file {video_file}") def process_parquet_files(dataset_path: str): """Process all parquet files by correcting the episode_index column""" parquets_folder = os.path.join(dataset_path, "data", "chunk-000") videos_folder = os.path.join(dataset_path, "videos", "chunk-000") logger.info("Processing parquet files...") parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet")) if not parquet_files: logger.info(f"No parquet files found in {parquets_folder}") return logger.info(f"Found {len(parquet_files)} parquet files to process") # Order files by episode number parquet_files.sort( key=lambda x: int(re.search(r"episode_(\d+)\.parquet", x).group(1)) ) # Check if episode numbers are continuous episode_numbers = [ int(re.search(r"episode_(\d+)\.parquet", file).group(1)) for file in parquet_files ] episode_numbers.sort() # Get video folders if they exist video_folders = [] if os.path.exists(videos_folder): video_folders = os.listdir(videos_folder) if episode_numbers != list(range(len(episode_numbers))): logger.info( "Episode numbers are not continuous or starting from 0. Renaming files and videos..." ) for i, file in enumerate(parquet_files): new_episode_number = i new_file = os.path.join( parquets_folder, f"episode_{new_episode_number:06d}.parquet" ) os.rename(file, new_file) logger.info(f"Renamed {file} to {new_file}") # Rename corresponding video files for folder in video_folders: video_file = os.path.join( videos_folder, folder, f"episode_{episode_numbers[i]:06d}.mp4" ) new_video_file = os.path.join( videos_folder, folder, f"episode_{new_episode_number:06d}.mp4" ) if os.path.exists(video_file): os.rename(video_file, new_video_file) logger.info(f"Renamed {video_file} to {new_video_file}") # Update list after renaming parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet")) parquet_files.sort( key=lambda x: int(re.search(r"episode_(\d+)\.parquet", x).group(1)) ) logger.info("Updated parquet files list after renaming") # Process each parquet file total_index = 0 for file_path in parquet_files: filename = os.path.basename(file_path) match = re.search(r"episode_(\d+)\.parquet", filename) if match: episode_number = int(match.group(1)) logger.info(f"Processing {filename} - Episode {episode_number}") try: df = pd.read_parquet(file_path, engine="pyarrow") df["episode_index"] = episode_number df["frame_index"] = range(len(df)) df["index"] = range(total_index, total_index + len(df)) total_index += len(df) df.to_parquet(file_path, index=False) logger.info(f"Successfully updated {filename}") except Exception as e: raise RuntimeError(f"Error processing {filename}: {str(e)}") else: logger.info(f"Skipping {filename} - doesn't match expected pattern") logger.info("Parquet processing complete") def run_stats_computation(dataset_path: str): """Run the lerobot stats computation script""" script_path = "lerobot_stats_compute.py" if not os.path.exists(script_path): logger.warning(f"Stats script '{script_path}' not found, skipping stats computation") return logger.info("Running lerobot_stats_compute.py...") try: subprocess.run( ["uv", "run", script_path, "--dataset-path", dataset_path], check=True, ) logger.info(f"Successfully executed {script_path}") except subprocess.CalledProcessError as e: logger.warning(f"Error executing stats script: {str(e)}") except FileNotFoundError: logger.warning("uv not found, skipping stats computation") def delete_episodes_and_repair( dataset_path: str, episode_indexes: List[int], run_stats: bool = True, ) -> str: """Delete specified episodes and repair the dataset. Args: dataset_path: Path to the dataset episode_indexes: List of episode indexes to delete run_stats: Whether to run stats computation after repair Returns: Path to the repaired dataset """ if not episode_indexes: raise ValueError("No episode indexes provided for deletion") # Check v2.0 format check_v2_format(dataset_path) logger.info(f"Deleting episodes: {episode_indexes}") # Delete .DS_Store files delete_ds_store(dataset_path) # Delete episode files delete_episode_files(dataset_path, episode_indexes) # Update meta JSONL files (episodes.jsonl and episodes_stats.jsonl) update_meta_jsonl_files(dataset_path, episode_indexes) # Process and repair remaining parquet files process_parquet_files(dataset_path) # Update info.json with new episode and video counts update_info_counts(dataset_path) # Run stats computation if run_stats: run_stats_computation(dataset_path) logger.info("Episode deletion and repair complete") return dataset_path def upload_dataset( local_dir: str, dest_repo_id: str, hf_token: Optional[str] = None, commit_message: Optional[str] = None, private: bool = False, ) -> str: """Upload a local dataset folder to a destination HF dataset repo. Returns the repo URL/identifier. """ if not dest_repo_id: raise ValueError("dest_repo_id must be provided") token = hf_token or os.environ.get("HF_TOKEN") create_repo( repo_id=dest_repo_id, repo_type="dataset", private=private, exist_ok=True, token=token, ) _enable_hf_transfer() msg = commit_message or "Updated dataset after episode deletion" logger.info(f"Uploading '{local_dir}' to '{dest_repo_id}' (private={private}) ...") upload_folder( repo_id=dest_repo_id, repo_type="dataset", folder_path=local_dir, path_in_repo=".", commit_message=msg, token=token, ) logger.info(f"Uploaded to: {dest_repo_id}") return dest_repo_id