# Source: jaxaht-benchmark / download_eval_data.py
# Uploaded by: lainwired
# Commit cd20d20 (verified): re-deploy: study.tsx full play.tsx parity + runtime weight download
import os
from huggingface_hub import hf_hub_download, snapshot_download
import zipfile
import shutil
import tempfile
def download_and_unzip_hf_file(repo_id: str, filename: str, destination_dir: str) -> bool:
    """
    Download a zip archive from a Hugging Face dataset repository and unpack it into a directory.

    The archive is first extracted into a temporary directory. If the top level of the
    archive consists of exactly one folder, that folder is treated as the root and only
    its contents are moved, so the wrapper folder does not appear under ``destination_dir``.
    Files are moved one by one with their relative paths preserved; a file that already
    exists at the destination is replaced after printing a warning.

    Args:
        repo_id (str): The Hugging Face repository ID (e.g., "jaxaht/eval-teammates").
        filename (str): The name of the zip file to download from the repository.
        destination_dir (str): The directory that receives the extracted files.
            It is created if it does not exist.

    Returns:
        bool: True if the archive was downloaded and unpacked (even when it contained
        no files), False if the download, the extraction, or the move failed.
    """
    print(f"Starting download & extraction: {repo_id}/{filename} -> {destination_dir}")
    os.makedirs(destination_dir, exist_ok=True)

    try:
        # Download the file from Hugging Face Hub (specify repo_type="dataset" for dataset repositories)
        downloaded_file_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
        print(f"Downloaded {filename} to {downloaded_file_path}")
    except Exception as e:
        # Best-effort script: report the failure and let the caller decide what to do.
        print(f"Error during hf_hub_download for {repo_id}/{filename}: {e}")
        return False

    if not os.path.exists(downloaded_file_path) or os.path.getsize(downloaded_file_path) == 0:
        print(f"Error: Download failed or file is empty: {downloaded_file_path}")
        return False

    downloaded_size = os.path.getsize(downloaded_file_path)
    print(f"Downloaded {downloaded_file_path} ({downloaded_size} bytes).")

    temp_dir_for_extraction = tempfile.mkdtemp()
    try:
        print(f"Unzipping {downloaded_file_path} to temporary directory {temp_dir_for_extraction}...")
        with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref:
            zip_ref.extractall(temp_dir_for_extraction)
        print(f"Successfully unzipped {downloaded_file_path} to {temp_dir_for_extraction}.")

        # Determine the source of files to move: unwrap a single top-level folder if present.
        extracted_items = os.listdir(temp_dir_for_extraction)
        source_path_for_moving = temp_dir_for_extraction
        if len(extracted_items) == 1:
            potential_single_folder = os.path.join(temp_dir_for_extraction, extracted_items[0])
            if os.path.isdir(potential_single_folder):
                source_path_for_moving = potential_single_folder

        print(f"Processing and moving files from '{source_path_for_moving}' to '{destination_dir}'...")

        files_moved_count = 0
        # os.walk will iterate through all files and directories in source_path_for_moving
        for root, _, files_in_dir in os.walk(source_path_for_moving):
            # NOTE: the loop variable must not be called `filename`; that would shadow
            # the archive name used in the messages below.
            for extracted_name in files_in_dir:
                src_file_full_path = os.path.join(root, extracted_name)

                # The path relative to the extraction root is re-created under destination_dir.
                relative_path_to_file = os.path.relpath(src_file_full_path, source_path_for_moving)
                dst_file_full_path = os.path.join(destination_dir, relative_path_to_file)

                # Ensure the parent directory for the destination file exists
                os.makedirs(os.path.dirname(dst_file_full_path), exist_ok=True)

                if os.path.isfile(dst_file_full_path):
                    print(f"Warning: Overwriting existing file '{dst_file_full_path}'.")

                shutil.move(src_file_full_path, dst_file_full_path)
                files_moved_count += 1

        if files_moved_count > 0:
            print(f"Successfully moved {files_moved_count} file(s) to {destination_dir}.")
        else:
            # Provide a more specific note if no files were moved.
            if not extracted_items:  # Nothing was extracted from the zip initially
                print(f"Note: The zip file '{filename}' appears to be completely empty.")
            elif source_path_for_moving != temp_dir_for_extraction and not os.listdir(source_path_for_moving):
                # This means a single root folder was identified, and it was empty.
                print(f"Note: The single root folder '{os.path.basename(source_path_for_moving)}' (from zip) was empty, so no files were moved.")
            else:  # Zip either contained only empty directories, or the structure didn't yield files from source_path_for_moving
                print(f"Note: No files found to move from '{source_path_for_moving}'. The zip may have contained only empty directories.")

        return True

    except zipfile.BadZipFile:
        print(f"Error: File {downloaded_file_path} (size: {downloaded_size} bytes) is not a valid zip file.")
        return False
    except Exception as e_unzip:
        # This catches other errors during unzipping or the file moving logic.
        print(f"Error during unzipping or moving of {downloaded_file_path}: {e_unzip}")
        return False
    finally:
        # Always try to clean up the temporary extraction directory. Cleanup is
        # best-effort: a failure here must not replace the boolean result with an exception.
        if os.path.exists(temp_dir_for_extraction):
            print(f"Cleaning up temporary extraction directory: {temp_dir_for_extraction}")
            shutil.rmtree(temp_dir_for_extraction, ignore_errors=True)
def download_hf_directory(repo_id: str, remote_dir: str, destination_dir: str) -> bool:
    """
    Download a directory from a Hugging Face dataset repository to a local directory,
    preserving the remote directory structure under destination_dir.

    Args:
        repo_id (str): The Hugging Face repository ID (e.g., "jaxaht/eval-teammates").
        remote_dir (str): The directory in the HF repo to download (e.g., "lbf").
            Leading and trailing slashes are ignored.
        destination_dir (str): Local directory to download into; remote_dir becomes a
            subdirectory of this (e.g., destination_dir/lbf/...).

    Returns:
        bool: True if the download succeeded and destination_dir/remote_dir exists
        afterwards, False if remote_dir is empty, the download raised, or nothing
        in the repository matched remote_dir.
    """
    # Repo paths are relative and slash-separated. A stray slash would produce a
    # pattern such as "lbf//**" that matches nothing, and a leading slash would make
    # os.path.join() below discard destination_dir.
    remote_dir = remote_dir.strip("/")

    print(f"Starting download: {repo_id}/{remote_dir} -> {destination_dir}/{remote_dir}")
    os.makedirs(destination_dir, exist_ok=True)

    if not remote_dir:
        print(f"Error downloading {repo_id}/{remote_dir}: remote_dir must not be empty")
        return False

    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=destination_dir,
            allow_patterns=f"{remote_dir}/**",
        )
    except Exception as e:
        # Best-effort script: report the failure and let the caller decide what to do.
        print(f"Error downloading {repo_id}/{remote_dir}: {e}")
        return False

    # A pattern that matches no files is not treated as an error by the download call,
    # so confirm that the directory actually materialised locally.
    if not os.path.isdir(os.path.join(destination_dir, remote_dir)):
        print(f"Error downloading {repo_id}/{remote_dir}: no files matched '{remote_dir}/**'")
        return False

    print(f"Successfully downloaded {remote_dir} to {destination_dir}.")
    return True
if __name__ == "__main__":
    # Script entry point: fetch every asset listed in `data_files`, one after another.
    # A failed download is reported on stdout and the loop moves on to the next entry;
    # the process exit status is not changed by failures.
    default_repo_id = "jaxaht/eval-teammates"
    hanabi_repo_id = "lainwired/jaxaht-hanabi"

    # label -> download spec:
    #   type             "zip" = single archive that is downloaded and unpacked;
    #                    any other value is handled as a remote directory
    #   filename         archive name (zip) or remote directory name (dir)
    #   target_directory local directory handed to the downloader
    #   repo_id          optional; falls back to default_repo_id when absent
    data_files = {
        "best_returns_teammates": {"type": "zip", "filename": "best_heldout_returns.zip", "target_directory": "results/"},
        "lbf_teammates": {"type": "dir", "filename": "lbf", "target_directory": "eval_teammates/"},
        "lbf_12x12_teammates": {"type": "dir", "filename": "lbf_12x12", "target_directory": "eval_teammates/"},
        "overcooked-v1_teammates": {"type": "dir", "filename": "overcooked-v1", "target_directory": "eval_teammates/"},
        "hanabi_teammates": {
            "type": "dir",
            "filename": "hanabi",
            "target_directory": "eval_teammates/",
            "repo_id": hanabi_repo_id,
        },
        "hanabi_obl_weights": {
            "type": "dir",
            "filename": "obl-r2d2-flax",
            "target_directory": "agents/hanabi/",
            "repo_id": hanabi_repo_id,
        },
        "hanabi_bc_weights": {
            "type": "dir",
            "filename": "bc_weights",
            "target_directory": "agents/",
            "repo_id": hanabi_repo_id,
        },
    }

    for label, spec in data_files.items():
        source_repo = spec.get("repo_id", default_repo_id)
        is_archive = spec["type"] == "zip"
        remote_name = spec["filename"]
        local_target = spec["target_directory"]

        if not is_archive:
            success = download_hf_directory(
                repo_id=source_repo,
                remote_dir=remote_name,
                destination_dir=local_target,
            )
        else:
            success = download_and_unzip_hf_file(
                repo_id=source_repo,
                filename=remote_name,
                destination_dir=local_target,
            )

        if not success:
            print(f"Download failed for {label}.")
            continue
        print(f"Download completed successfully for {label}.")