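"""Utilities for repairing LeRobot v2.x datasets on the Hugging Face Hub.

The functions below download a dataset snapshot, delete selected episodes,
re-index the remaining parquet/video files and metadata, and upload the
repaired dataset. They assume the single-chunk v2.x layout used throughout
this script (the `<video_key>` folder name is illustrative; there is one
subfolder per video stream):

    meta/info.json
    meta/episodes.jsonl
    meta/episodes_stats.jsonl
    data/chunk-000/episode_XXXXXX.parquet
    videos/chunk-000/<video_key>/episode_XXXXXX.mp4
"""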

import glob
import importlib.util
import json
import logging
import os
import re
import subprocess
from pathlib import Path
from typing import List, Optional

import pandas as pd
from huggingface_hub import create_repo, snapshot_download, upload_folder


logger = logging.getLogger(__name__)
if not logger.handlers:
    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")


def _enable_hf_transfer():
    """Enable hf_transfer download/upload acceleration if the package is installed."""
    if importlib.util.find_spec("hf_transfer") is None:
        # huggingface_hub raises at transfer time if the env var is set
        # without the package being available, so only opt in when it exists.
        return
    if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1":
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        logger.info("Enabled hf_transfer acceleration (HF_HUB_ENABLE_HF_TRANSFER=1)")


def download_dataset(
    repo_id: str,
    local_dir: str,
    hf_token: Optional[str] = None,
) -> str:
    """Download a Hugging Face dataset snapshot by repo_id.

    Returns the local directory path.
    """
    _enable_hf_transfer()

    local_path = Path(local_dir)
    local_path.mkdir(parents=True, exist_ok=True)

    logger.info(f"Downloading dataset '{repo_id}' to '{local_dir}' ...")

    path = snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        token=hf_token,
        local_dir=str(local_path),
        local_dir_use_symlinks=False,
    )

    logger.info(f"Downloaded: {repo_id} -> {path}")
    return str(local_path)


def check_v2_format(dataset_path: str) -> bool:
    """Check that the dataset at `dataset_path` is in the v2.x format."""
    info_path = os.path.join(dataset_path, "meta", "info.json")

    if not os.path.exists(info_path):
        raise ValueError(f"Error: {info_path} does not exist")

    with open(info_path, "r") as f:
        try:
            info = json.load(f)
        except json.JSONDecodeError:
            raise ValueError(f"Error: {info_path} is not a valid JSON file")

    if "codebase_version" not in info:
        raise ValueError(f"Error: {info_path} is not a valid v2.x dataset")

    version = info["codebase_version"]
    if not version.startswith("v2."):
        raise ValueError(f"Error: {info_path} is not a v2.x dataset, found {version}")

    logger.info(f"Dataset version: {version}")
    return True


def update_info_counts(dataset_path: str):
    """Update total_episodes and total_videos in info.json to match the files on disk.

    Args:
        dataset_path: Path to the dataset
    """
    info_path = os.path.join(dataset_path, "meta", "info.json")

    if not os.path.exists(info_path):
        raise ValueError(f"Error: {info_path} does not exist")

    logger.info("Updating info.json counts to reflect actual dataset state...")

    # Count episodes from the parquet files on disk.
    episodes = list_episodes(dataset_path)
    new_episode_count = len(episodes)

    # Count videos across every video subfolder (one per video stream).
    videos_folder = os.path.join(dataset_path, "videos", "chunk-000")
    video_count = 0
    if os.path.exists(videos_folder):
        video_folders = [
            d
            for d in os.listdir(videos_folder)
            if os.path.isdir(os.path.join(videos_folder, d))
        ]
        for folder in video_folders:
            video_files = glob.glob(
                os.path.join(videos_folder, folder, "episode_*.mp4")
            )
            video_count += len(video_files)

    with open(info_path, "r") as f:
        info = json.load(f)

    old_episodes = info.get("total_episodes", 0)
    old_videos = info.get("total_videos", 0)

    info["total_episodes"] = new_episode_count
    info["total_videos"] = video_count

    with open(info_path, "w") as f:
        json.dump(info, f, indent=4)

    logger.info(f"Updated total_episodes: {old_episodes} → {new_episode_count}")
    logger.info(f"Updated total_videos: {old_videos} → {video_count}")


def list_episodes(dataset_path: str) -> List[int]:
    """List all episode numbers present in the dataset, in ascending order."""
    parquets_folder = os.path.join(dataset_path, "data", "chunk-000")

    if not os.path.exists(parquets_folder):
        return []

    parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet"))

    episode_numbers = []
    for file in parquet_files:
        match = re.search(r"episode_(\d+)\.parquet", file)
        if match:
            episode_numbers.append(int(match.group(1)))

    return sorted(episode_numbers)


def delete_ds_store(dataset_path: str):
    """Delete all .DS_Store files in the dataset path and its subdirectories."""
    logger.info("Deleting .DS_Store files...")
    ds_store_files = glob.glob(
        os.path.join(dataset_path, "**", ".DS_Store"), recursive=True
    )

    if not ds_store_files:
        logger.info("No .DS_Store files found")
        return

    for file in ds_store_files:
        os.remove(file)
        logger.info(f"Deleted {file}")

    logger.info(".DS_Store files deleted")


def _filter_and_reindex_jsonl(path: str, indexes_to_delete: List[int], label: str):
    """Drop entries whose episode_index is being deleted, then re-index from 0."""
    if not os.path.exists(path):
        logger.warning(f"{label} not found at {path}")
        return

    logger.info(f"Updating {label}...")
    entries = []
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                entry = json.loads(line)
                if entry["episode_index"] not in indexes_to_delete:
                    entries.append(entry)

    # Re-index the surviving entries so episode indexes are contiguous from 0.
    for new_index, entry in enumerate(entries):
        entry["episode_index"] = new_index

    with open(path, "w") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")

    logger.info(f"Updated {label}: {len(entries)} entries remaining")


def update_meta_jsonl_files(dataset_path: str, indexes_to_delete: List[int]):
    """Update episodes.jsonl and episodes_stats.jsonl by removing deleted episodes and re-indexing."""
    meta_folder = os.path.join(dataset_path, "meta")
    _filter_and_reindex_jsonl(
        os.path.join(meta_folder, "episodes.jsonl"),
        indexes_to_delete,
        "episodes.jsonl",
    )
    _filter_and_reindex_jsonl(
        os.path.join(meta_folder, "episodes_stats.jsonl"),
        indexes_to_delete,
        "episodes_stats.jsonl",
    )


def delete_episode_files(dataset_path: str, indexes: List[int]):
    """Delete the parquet and video files for the specified episode indexes."""
    parquets_folder = os.path.join(dataset_path, "data", "chunk-000")
    videos_folder = os.path.join(dataset_path, "videos", "chunk-000")

    logger.info("Deleting parquet files...")
    for index in indexes:
        parquet_file = os.path.join(parquets_folder, f"episode_{index:06d}.parquet")
        if os.path.exists(parquet_file):
            os.remove(parquet_file)
            logger.info(f"Deleted file {parquet_file}")

    logger.info("Deleting video files...")
    if os.path.exists(videos_folder):
        # One subfolder per video stream, each holding episode_XXXXXX.mp4 files.
        for index in indexes:
            for folder in os.listdir(videos_folder):
                video_file = os.path.join(
                    videos_folder, folder, f"episode_{index:06d}.mp4"
                )
                if os.path.exists(video_file):
                    os.remove(video_file)
                    logger.info(f"Deleted file {video_file}")


def process_parquet_files(dataset_path: str):
    """Re-index all parquet files so episode_index, frame_index and index are consistent."""
    parquets_folder = os.path.join(dataset_path, "data", "chunk-000")
    videos_folder = os.path.join(dataset_path, "videos", "chunk-000")

    logger.info("Processing parquet files...")
    parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet"))

    if not parquet_files:
        logger.info(f"No parquet files found in {parquets_folder}")
        return

    logger.info(f"Found {len(parquet_files)} parquet files to process")

    # Sort files by episode number so renaming and re-indexing are deterministic.
    parquet_files.sort(
        key=lambda x: int(re.search(r"episode_(\d+)\.parquet", x).group(1))
    )
    episode_numbers = [
        int(re.search(r"episode_(\d+)\.parquet", file).group(1))
        for file in parquet_files
    ]

    video_folders = []
    if os.path.exists(videos_folder):
        video_folders = os.listdir(videos_folder)

    # If deletions left gaps, rename the files (and matching videos) to 0..N-1.
    if episode_numbers != list(range(len(episode_numbers))):
        logger.info(
            "Episode numbers are not continuous or starting from 0. Renaming files and videos..."
        )
        for i, file in enumerate(parquet_files):
            new_episode_number = i
            new_file = os.path.join(
                parquets_folder, f"episode_{new_episode_number:06d}.parquet"
            )
            os.rename(file, new_file)
            logger.info(f"Renamed {file} to {new_file}")

            for folder in video_folders:
                video_file = os.path.join(
                    videos_folder, folder, f"episode_{episode_numbers[i]:06d}.mp4"
                )
                new_video_file = os.path.join(
                    videos_folder, folder, f"episode_{new_episode_number:06d}.mp4"
                )
                if os.path.exists(video_file):
                    os.rename(video_file, new_video_file)
                    logger.info(f"Renamed {video_file} to {new_video_file}")

        # Re-list the files under their new names.
        parquet_files = glob.glob(os.path.join(parquets_folder, "episode_*.parquet"))
        parquet_files.sort(
            key=lambda x: int(re.search(r"episode_(\d+)\.parquet", x).group(1))
        )
        logger.info("Updated parquet files list after renaming")

    # Rewrite the index columns: episode_index from the filename, frame_index
    # per-episode from 0, and a global running index across all episodes.
    total_index = 0
    for file_path in parquet_files:
        filename = os.path.basename(file_path)
        match = re.search(r"episode_(\d+)\.parquet", filename)

        if match:
            episode_number = int(match.group(1))
            logger.info(f"Processing {filename} - Episode {episode_number}")

            try:
                df = pd.read_parquet(file_path, engine="pyarrow")

                df["episode_index"] = episode_number
                df["frame_index"] = range(len(df))
                df["index"] = range(total_index, total_index + len(df))
                total_index += len(df)

                df.to_parquet(file_path, index=False)
                logger.info(f"Successfully updated {filename}")
            except Exception as e:
                raise RuntimeError(f"Error processing {filename}: {str(e)}")
        else:
            logger.info(f"Skipping {filename} - doesn't match expected pattern")

    logger.info("Parquet processing complete")


def run_stats_computation(dataset_path: str):
    """Run the lerobot stats computation script, if it is available."""
    script_path = "lerobot_stats_compute.py"

    if not os.path.exists(script_path):
        logger.warning(
            f"Stats script '{script_path}' not found, skipping stats computation"
        )
        return

    logger.info("Running lerobot_stats_compute.py...")
    try:
        subprocess.run(
            ["uv", "run", script_path, "--dataset-path", dataset_path],
            check=True,
        )
        logger.info(f"Successfully executed {script_path}")
    except subprocess.CalledProcessError as e:
        logger.warning(f"Error executing stats script: {str(e)}")
    except FileNotFoundError:
        logger.warning("uv not found, skipping stats computation")


def delete_episodes_and_repair(
    dataset_path: str,
    episode_indexes: List[int],
    run_stats: bool = True,
) -> str:
    """Delete the specified episodes and repair the dataset.

    Args:
        dataset_path: Path to the dataset
        episode_indexes: List of episode indexes to delete
        run_stats: Whether to run stats computation after the repair

    Returns:
        Path to the repaired dataset
    """
    if not episode_indexes:
        raise ValueError("No episode indexes provided for deletion")

    check_v2_format(dataset_path)

    logger.info(f"Deleting episodes: {episode_indexes}")

    # 1. Remove macOS metadata files that would confuse the file listings.
    delete_ds_store(dataset_path)

    # 2. Delete the parquet and video files of the selected episodes.
    delete_episode_files(dataset_path, episode_indexes)

    # 3. Drop the deleted episodes from the jsonl metadata and re-index.
    update_meta_jsonl_files(dataset_path, episode_indexes)

    # 4. Rename the remaining files and rewrite their index columns.
    process_parquet_files(dataset_path)

    # 5. Sync total_episodes / total_videos in info.json with what is on disk.
    update_info_counts(dataset_path)

    if run_stats:
        run_stats_computation(dataset_path)

    logger.info("Episode deletion and repair complete")
    return dataset_path
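

# Example (illustrative path and indexes): remove episodes 3 and 7 from a
# local copy and recompute stats afterwards:
#
#     delete_episodes_and_repair("./my_dataset", [3, 7], run_stats=True)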


def upload_dataset(
    local_dir: str,
    dest_repo_id: str,
    hf_token: Optional[str] = None,
    commit_message: Optional[str] = None,
    private: bool = False,
) -> str:
    """Upload a local dataset folder to a destination HF dataset repo.

    Returns the destination repo id.
    """
    if not dest_repo_id:
        raise ValueError("dest_repo_id must be provided")

    token = hf_token or os.environ.get("HF_TOKEN")
    create_repo(
        repo_id=dest_repo_id,
        repo_type="dataset",
        private=private,
        exist_ok=True,
        token=token,
    )

    _enable_hf_transfer()
    msg = commit_message or "Updated dataset after episode deletion"
    logger.info(f"Uploading '{local_dir}' to '{dest_repo_id}' (private={private}) ...")

    upload_folder(
        repo_id=dest_repo_id,
        repo_type="dataset",
        folder_path=local_dir,
        path_in_repo=".",
        commit_message=msg,
        token=token,
    )

    logger.info(f"Uploaded to: {dest_repo_id}")
    return dest_repo_id
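

# A minimal end-to-end sketch of how these helpers compose. The repo ids,
# flag names, and default paths below are placeholders, not fixed conventions
# of any real project.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Delete episodes from a LeRobot v2.x dataset and re-upload it."
    )
    parser.add_argument("--repo-id", required=True, help="Source dataset repo id")
    parser.add_argument("--dest-repo-id", required=True, help="Destination repo id")
    parser.add_argument("--local-dir", default="./dataset", help="Working directory")
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        required=True,
        help="Episode indexes to delete",
    )
    parser.add_argument("--private", action="store_true", help="Create a private repo")
    args = parser.parse_args()

    token = os.environ.get("HF_TOKEN")
    local = download_dataset(args.repo_id, args.local_dir, hf_token=token)
    delete_episodes_and_repair(local, args.episodes)
    upload_dataset(local, args.dest_repo_id, hf_token=token, private=args.private)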