delete-episodes-from-dataset / delete_episodes.py
import os
import re
import glob
import json
import logging
import subprocess
from pathlib import Path
from typing import List, Optional

import pandas as pd
from huggingface_hub import create_repo, snapshot_download, upload_folder
logger = logging.getLogger(__name__)
# Configure the root logger only if nothing else has already done so
# (basicConfig is a no-op when the root logger already has handlers).
if not logging.getLogger().handlers:
    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
def _enable_hf_transfer():
    """Enable hf_transfer acceleration if the package is installed."""
    try:
        import hf_transfer  # noqa: F401
    except ImportError:
        # hf_transfer is optional; fall back to the default downloader.
        return
    if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1":
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        logger.info("Enabled hf_transfer acceleration (HF_HUB_ENABLE_HF_TRANSFER=1)")
def download_dataset(
repo_id: str,
local_dir: str,
hf_token: Optional[str] = None,
) -> str:
"""Download a Hugging Face dataset by repo_id.
Returns the local directory path.
"""
_enable_hf_transfer()
local_path = Path(local_dir)
local_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Downloading dataset '{repo_id}' to '{local_dir}' ...")
path = snapshot_download(
repo_id=repo_id,
repo_type="dataset",
token=hf_token,
local_dir=str(local_dir),
local_dir_use_symlinks=False,
)
logger.info(f"Downloaded: {repo_id} -> {path}")
return str(local_path)
def check_v2_format(dataset_path: str) -> bool:
"""Check if dataset is in v2.x format"""
info_path = os.path.join(dataset_path, "meta", "info.json")
if not os.path.exists(info_path):
raise ValueError(f"Error: {info_path} does not exist")
with open(info_path, "r") as f:
try:
info = json.load(f)
if "codebase_version" not in info:
raise ValueError(f"Error: {info_path} is not a valid v2.x dataset")
version = info["codebase_version"]
# Accept any v2.x version (v2.0, v2.1, etc.)
if not version.startswith("v2."):
raise ValueError(
f"Error: {info_path} is not a v2.x dataset, found {version}"
)
logger.info(f"Dataset version: {version}")
return True
except json.JSONDecodeError:
raise ValueError(f"Error: {info_path} is not a valid JSON file")
def update_info_counts(dataset_path: str):
"""Update total_episodes and total_videos counts in info.json to reflect actual counts.
Args:
dataset_path: Path to the dataset
"""
info_path = os.path.join(dataset_path, "meta", "info.json")
if not os.path.exists(info_path):
raise ValueError(f"Error: {info_path} does not exist")
logger.info("Updating info.json counts to reflect actual dataset state...")
# Count actual episodes
episodes = list_episodes(dataset_path)
new_episode_count = len(episodes)
# Count actual videos
videos_folder = os.path.join(dataset_path, "videos", "chunk-000")
video_count = 0
if os.path.exists(videos_folder):
video_folders = [d for d in os.listdir(videos_folder)
if os.path.isdir(os.path.join(videos_folder, d))]
for folder in video_folders:
video_files = glob.glob(
os.path.join(videos_folder, folder, "episode_*.mp4")
)
video_count += len(video_files)
# Read and update info.json
with open(info_path, "r") as f:
info = json.load(f)
old_episodes = info.get("total_episodes", 0)
old_videos = info.get("total_videos", 0)
info["total_episodes"] = new_episode_count
info["total_videos"] = video_count
with open(info_path, "w") as f:
json.dump(info, f, indent=4)
    logger.info(f"Updated total_episodes: {old_episodes} -> {new_episode_count}")
    logger.info(f"Updated total_videos: {old_videos} -> {video_count}")
def list_episodes(dataset_path: str) -> List[int]:
"""List all episode numbers in the dataset"""
parquets_folder = os.path.join(dataset_path, "data", "chunk-000")
if not os.path.exists(parquets_folder):
return []
parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet"))
episode_numbers = []
for file in parquet_files:
match = re.search(r"episode_(\d+)\.parquet", file)
if match:
episode_numbers.append(int(match.group(1)))
return sorted(episode_numbers)
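# On-disk layout assumed throughout this script (single chunk-000, v2.x
# conventions; the camera key name below is illustrative):
#
#   meta/info.json, meta/episodes.jsonl, meta/episodes_stats.jsonl
#   data/chunk-000/episode_000000.parquet
#   videos/chunk-000/observation.images.top/episode_000000.mp4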
def delete_ds_store(dataset_path: str):
"""Delete all .DS_Store files in the given dataset path and its subdirectories"""
logger.info("Deleting .DS_Store files...")
ds_store_files = glob.glob(
os.path.join(dataset_path, "**", ".DS_Store"), recursive=True
)
if not ds_store_files:
logger.info("No .DS_Store files found")
return
for file in ds_store_files:
os.remove(file)
logger.info(f"Deleted {file}")
logger.info(".DS_Store files deleted")
def update_meta_jsonl_files(dataset_path: str, indexes_to_delete: List[int]):
"""Update episodes.jsonl and episodes_stats.jsonl by removing deleted episodes and re-indexing"""
meta_folder = os.path.join(dataset_path, "meta")
episodes_file = os.path.join(meta_folder, "episodes.jsonl")
episodes_stats_file = os.path.join(meta_folder, "episodes_stats.jsonl")
# Process episodes.jsonl
if os.path.exists(episodes_file):
logger.info("Updating episodes.jsonl...")
episodes = []
with open(episodes_file, "r") as f:
for line in f:
line = line.strip()
if line: # Skip empty lines
episode = json.loads(line)
if episode["episode_index"] not in indexes_to_delete:
episodes.append(episode)
# Re-index episodes
for new_index, episode in enumerate(episodes):
episode["episode_index"] = new_index
# Write back
with open(episodes_file, "w") as f:
for episode in episodes:
f.write(json.dumps(episode) + "\n")
logger.info(f"Updated episodes.jsonl: {len(episodes)} episodes remaining")
else:
logger.warning(f"episodes.jsonl not found at {episodes_file}")
# Process episodes_stats.jsonl
if os.path.exists(episodes_stats_file):
logger.info("Updating episodes_stats.jsonl...")
stats = []
with open(episodes_stats_file, "r") as f:
for line in f:
line = line.strip()
if line: # Skip empty lines
stat = json.loads(line)
if stat["episode_index"] not in indexes_to_delete:
stats.append(stat)
# Re-index stats
for new_index, stat in enumerate(stats):
stat["episode_index"] = new_index
# Write back
with open(episodes_stats_file, "w") as f:
for stat in stats:
f.write(json.dumps(stat) + "\n")
logger.info(f"Updated episodes_stats.jsonl: {len(stats)} episode stats remaining")
else:
logger.warning(f"episodes_stats.jsonl not found at {episodes_stats_file}")
def delete_episode_files(dataset_path: str, indexes: List[int]):
"""Delete parquet and video files for specified episode indexes"""
parquets_folder = os.path.join(dataset_path, "data", "chunk-000")
videos_folder = os.path.join(dataset_path, "videos", "chunk-000")
# Delete parquet files
logger.info("Deleting parquet files...")
parquet_files = glob.glob(os.path.join(parquets_folder, "*.parquet"))
for index in indexes:
for file in parquet_files:
if f"episode_{index:06d}.parquet" in file:
os.remove(file)
logger.info(f"Deleted file {file}")
# Delete video files
logger.info("Deleting video files...")
if os.path.exists(videos_folder):
video_folders = os.listdir(videos_folder)
for index in indexes:
for folder in video_folders:
video_files = glob.glob(
os.path.join(videos_folder, folder, f"episode_{index:06d}.mp4")
)
for video_file in video_files:
os.remove(video_file)
logger.info(f"Deleted file {video_file}")
def process_parquet_files(dataset_path: str):
"""Process all parquet files by correcting the episode_index column"""
parquets_folder = os.path.join(dataset_path, "data", "chunk-000")
videos_folder = os.path.join(dataset_path, "videos", "chunk-000")
logger.info("Processing parquet files...")
parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet"))
if not parquet_files:
logger.info(f"No parquet files found in {parquets_folder}")
return
logger.info(f"Found {len(parquet_files)} parquet files to process")
# Order files by episode number
parquet_files.sort(
key=lambda x: int(re.search(r"episode_(\d+)\.parquet", x).group(1))
)
# Check if episode numbers are continuous
episode_numbers = [
int(re.search(r"episode_(\d+)\.parquet", file).group(1))
for file in parquet_files
]
episode_numbers.sort()
# Get video folders if they exist
video_folders = []
if os.path.exists(videos_folder):
video_folders = os.listdir(videos_folder)
if episode_numbers != list(range(len(episode_numbers))):
logger.info(
"Episode numbers are not continuous or starting from 0. Renaming files and videos..."
)
for i, file in enumerate(parquet_files):
new_episode_number = i
new_file = os.path.join(
parquets_folder, f"episode_{new_episode_number:06d}.parquet"
)
os.rename(file, new_file)
logger.info(f"Renamed {file} to {new_file}")
# Rename corresponding video files
for folder in video_folders:
video_file = os.path.join(
videos_folder, folder, f"episode_{episode_numbers[i]:06d}.mp4"
)
new_video_file = os.path.join(
videos_folder, folder, f"episode_{new_episode_number:06d}.mp4"
)
if os.path.exists(video_file):
os.rename(video_file, new_video_file)
logger.info(f"Renamed {video_file} to {new_video_file}")
# Update list after renaming
parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet"))
parquet_files.sort(
key=lambda x: int(re.search(r"episode_(\d+)\.parquet", x).group(1))
)
logger.info("Updated parquet files list after renaming")
# Process each parquet file
total_index = 0
for file_path in parquet_files:
filename = os.path.basename(file_path)
match = re.search(r"episode_(\d+)\.parquet", filename)
if match:
episode_number = int(match.group(1))
logger.info(f"Processing {filename} - Episode {episode_number}")
try:
df = pd.read_parquet(file_path, engine="pyarrow")
df["episode_index"] = episode_number
df["frame_index"] = range(len(df))
df["index"] = range(total_index, total_index + len(df))
total_index += len(df)
df.to_parquet(file_path, index=False)
logger.info(f"Successfully updated {filename}")
except Exception as e:
                raise RuntimeError(f"Error processing {filename}: {e}") from e
else:
logger.info(f"Skipping {filename} - doesn't match expected pattern")
logger.info("Parquet processing complete")
def run_stats_computation(dataset_path: str):
"""Run the lerobot stats computation script"""
script_path = "lerobot_stats_compute.py"
if not os.path.exists(script_path):
logger.warning(f"Stats script '{script_path}' not found, skipping stats computation")
return
logger.info("Running lerobot_stats_compute.py...")
try:
subprocess.run(
["uv", "run", script_path, "--dataset-path", dataset_path],
check=True,
)
logger.info(f"Successfully executed {script_path}")
except subprocess.CalledProcessError as e:
logger.warning(f"Error executing stats script: {str(e)}")
except FileNotFoundError:
logger.warning("uv not found, skipping stats computation")
def delete_episodes_and_repair(
dataset_path: str,
episode_indexes: List[int],
run_stats: bool = True,
) -> str:
"""Delete specified episodes and repair the dataset.
Args:
dataset_path: Path to the dataset
episode_indexes: List of episode indexes to delete
run_stats: Whether to run stats computation after repair
Returns:
Path to the repaired dataset
"""
if not episode_indexes:
raise ValueError("No episode indexes provided for deletion")
    # Check v2.x format
check_v2_format(dataset_path)
logger.info(f"Deleting episodes: {episode_indexes}")
# Delete .DS_Store files
delete_ds_store(dataset_path)
# Delete episode files
delete_episode_files(dataset_path, episode_indexes)
# Update meta JSONL files (episodes.jsonl and episodes_stats.jsonl)
update_meta_jsonl_files(dataset_path, episode_indexes)
# Process and repair remaining parquet files
process_parquet_files(dataset_path)
# Update info.json with new episode and video counts
update_info_counts(dataset_path)
# Run stats computation
if run_stats:
run_stats_computation(dataset_path)
logger.info("Episode deletion and repair complete")
return dataset_path
def upload_dataset(
local_dir: str,
dest_repo_id: str,
hf_token: Optional[str] = None,
commit_message: Optional[str] = None,
private: bool = False,
) -> str:
"""Upload a local dataset folder to a destination HF dataset repo.
Returns the repo URL/identifier.
"""
if not dest_repo_id:
raise ValueError("dest_repo_id must be provided")
token = hf_token or os.environ.get("HF_TOKEN")
create_repo(
repo_id=dest_repo_id,
repo_type="dataset",
private=private,
exist_ok=True,
token=token,
)
_enable_hf_transfer()
msg = commit_message or "Updated dataset after episode deletion"
logger.info(f"Uploading '{local_dir}' to '{dest_repo_id}' (private={private}) ...")
upload_folder(
repo_id=dest_repo_id,
repo_type="dataset",
folder_path=local_dir,
path_in_repo=".",
commit_message=msg,
token=token,
)
logger.info(f"Uploaded to: {dest_repo_id}")
return dest_repo_id
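

# Minimal end-to-end sketch of how these helpers compose. The repo IDs, local
# directory, and episode list below are placeholders, not values from a real
# run.
if __name__ == "__main__":
    token = os.environ.get("HF_TOKEN")
    local = download_dataset(
        repo_id="your-username/source-dataset",
        local_dir="./source-dataset",
        hf_token=token,
    )
    delete_episodes_and_repair(local, episode_indexes=[2, 5])
    upload_dataset(
        local_dir=local,
        dest_repo_id="your-username/cleaned-dataset",
        hf_token=token,
        commit_message="Remove episodes 2 and 5",
    )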