Spaces:
Running
Running
| import os | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| import zipfile | |
| import shutil | |
| import tempfile | |
def _find_extraction_root(extract_dir: str, extracted_items: list) -> str:
    """Return the lone top-level directory of an extracted archive, or extract_dir itself."""
    # Archives are often packed as `name.zip -> name/...`; in that case the contents
    # of `name/` should land directly in the destination rather than under `name/`.
    if len(extracted_items) == 1:
        candidate = os.path.join(extract_dir, extracted_items[0])
        if os.path.isdir(candidate):
            return candidate
    return extract_dir


def _move_files_preserving_layout(src_root: str, dst_root: str) -> int:
    """Move every file under src_root to the same relative path under dst_root; return the count moved."""
    moved = 0
    for current_dir, _, file_names in os.walk(src_root):
        for file_name in file_names:
            src_path = os.path.join(current_dir, file_name)
            dst_path = os.path.join(dst_root, os.path.relpath(src_path, src_root))
            parent_dir = os.path.dirname(dst_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            if os.path.isdir(dst_path):
                # shutil.move() onto an existing directory would silently move the
                # file *inside* that directory, i.e. to the wrong place.
                raise IsADirectoryError(
                    f"Cannot place file '{src_path}': '{dst_path}' is an existing directory."
                )
            if os.path.isfile(dst_path):
                print(f"Warning: Overwriting existing file '{dst_path}'.")
            shutil.move(src_path, dst_path)
            moved += 1
    return moved


def _cleanup_temp_dir(path: str) -> None:
    """Best-effort removal of a temporary directory; reports, but never raises, on failure."""
    if not os.path.exists(path):
        return
    print(f"Cleaning up temporary extraction directory: {path}")
    try:
        shutil.rmtree(path)
    except OSError as e:
        print(f"Warning: Could not remove temporary directory '{path}': {e}")


def download_and_unzip_hf_file(repo_id: str, filename: str, destination_dir: str) -> bool:
    """
    Download a zip file from a Hugging Face dataset repository and unpack it into destination_dir.

    The archive is first extracted into a temporary directory. If it contains exactly
    one top-level directory, that directory's *contents* are moved into destination_dir;
    otherwise everything at the archive root is moved. Relative paths are preserved,
    existing files at the same path are overwritten (with a warning), and only files
    are moved: directories that contain no files are not recreated.

    Args:
        repo_id (str): The Hugging Face repository ID (e.g., "jaxaht/eval-teammates").
        filename (str): The name of the zip file to download from the repository.
        destination_dir (str): Directory that receives the archive contents; created
            if it does not exist.

    Returns:
        bool: True if successful, False otherwise. Errors are reported with print()
        rather than raised.
    """
    print(f"Starting download & extraction: {repo_id}/{filename} -> {destination_dir}")
    os.makedirs(destination_dir, exist_ok=True)

    try:
        # Specify repo_type="dataset" for dataset repositories.
        downloaded_file_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
        print(f"Downloaded {filename} to {downloaded_file_path}")
    except Exception as e:
        print(f"Error during hf_hub_download for {repo_id}/{filename}: {e}")
        return False

    downloaded_size = os.path.getsize(downloaded_file_path) if os.path.exists(downloaded_file_path) else 0
    if downloaded_size == 0:
        print(f"Error: Download failed or file is empty: {downloaded_file_path}")
        return False
    print(f"Downloaded {downloaded_file_path} ({downloaded_size} bytes).")

    temp_dir_for_extraction = tempfile.mkdtemp()
    try:
        print(f"Unzipping {downloaded_file_path} to temporary directory {temp_dir_for_extraction}...")
        with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref:
            zip_ref.extractall(temp_dir_for_extraction)
        print(f"Successfully unzipped {downloaded_file_path} to {temp_dir_for_extraction}.")

        extracted_items = os.listdir(temp_dir_for_extraction)
        source_root = _find_extraction_root(temp_dir_for_extraction, extracted_items)

        print(f"Processing and moving files from '{source_root}' to '{destination_dir}'...")
        files_moved_count = _move_files_preserving_layout(source_root, destination_dir)

        if files_moved_count > 0:
            print(f"Successfully moved {files_moved_count} file(s) to {destination_dir}.")
        elif not extracted_items:
            # Nothing at all was extracted from the archive.
            print(f"Note: The zip file '{filename}' appears to be completely empty.")
        elif source_root != temp_dir_for_extraction and not os.listdir(source_root):
            # A single root folder was identified, and it was empty.
            print(f"Note: The single root folder '{os.path.basename(source_root)}' (from zip) was empty, so no files were moved.")
        else:
            # The archive yielded directories but no files.
            print(f"Note: No files found to move from '{source_root}'. The zip may have contained only empty directories.")
        return True
    except zipfile.BadZipFile:
        print(f"Error: File {downloaded_file_path} (size: {downloaded_size} bytes) is not a valid zip file.")
        return False
    except Exception as e_unzip:
        # Any other failure while unzipping or moving files.
        print(f"Error during unzipping or moving of {downloaded_file_path}: {e_unzip}")
        return False
    finally:
        # Cleanup must never turn the bool result into an exception.
        _cleanup_temp_dir(temp_dir_for_extraction)
def download_hf_directory(repo_id: str, remote_dir: str, destination_dir: str) -> bool:
    """
    Download a directory from a Hugging Face dataset repository to a local directory,
    preserving the remote directory structure under destination_dir.

    Leading/trailing slashes in remote_dir are ignored. After the download, the
    function checks that destination_dir/remote_dir exists locally, so a remote_dir
    that matches nothing is reported as a failure rather than a silent success.

    Args:
        repo_id (str): The Hugging Face repository ID (e.g., "jaxaht/eval-teammates").
        remote_dir (str): The directory in the HF repo to download (e.g., "lbf").
        destination_dir (str): Local directory to download into; remote_dir becomes a
            subdirectory of this (e.g., destination_dir/lbf/...).

    Returns:
        bool: True if successful, False otherwise. Errors are reported with print()
        rather than raised.
    """
    # "lbf/" would otherwise produce the pattern "lbf//**", which matches no repo path.
    remote_dir = remote_dir.strip("/")
    local_target = os.path.join(destination_dir, remote_dir)
    print(f"Starting download: {repo_id}/{remote_dir} -> {local_target}")
    if not remote_dir:
        print("Error: remote_dir must name a directory in the repository, got an empty path.")
        return False

    os.makedirs(destination_dir, exist_ok=True)
    pattern = f"{remote_dir}/**"
    try:
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=destination_dir,
            allow_patterns=pattern,
        )
    except Exception as e:
        print(f"Error downloading {repo_id}/{remote_dir}: {e}")
        return False

    # NOTE(review): a pattern that matches no files (e.g. a typo in remote_dir) is
    # assumed not to raise in snapshot_download, so confirm something actually landed.
    # A directory left over from an earlier run also passes this check.
    if not os.path.isdir(local_target):
        print(f"Error: No files matching '{pattern}' were downloaded from {repo_id}.")
        return False

    print(f"Successfully downloaded {remote_dir} to {destination_dir}.")
    return True
if __name__ == "__main__":
    import sys

    default_repo_id = "jaxaht/eval-teammates"
    # Each entry:
    #   "type":             "zip" -> download_and_unzip_hf_file; any other value -> download_hf_directory
    #   "filename":         the zip file name ("zip") or the remote directory name (otherwise)
    #   "target_directory": local destination directory
    #   "repo_id":          optional override of default_repo_id
    data_files = {
        "best_returns_teammates": {
            "type": "zip",
            "filename": "best_heldout_returns.zip",
            "target_directory": "results/",
        },
        "lbf_teammates": {
            "type": "dir",
            "filename": "lbf",
            "target_directory": "eval_teammates/",
        },
        "lbf_12x12_teammates": {
            "type": "dir",
            "filename": "lbf_12x12",
            "target_directory": "eval_teammates/",
        },
        "overcooked-v1_teammates": {
            "type": "dir",
            "filename": "overcooked-v1",
            "target_directory": "eval_teammates/",
        },
        "hanabi_teammates": {
            "type": "dir",
            "filename": "hanabi",
            "target_directory": "eval_teammates/",
            "repo_id": "lainwired/jaxaht-hanabi",
        },
        "hanabi_obl_weights": {
            "type": "dir",
            "filename": "obl-r2d2-flax",
            "target_directory": "agents/hanabi/",
            "repo_id": "lainwired/jaxaht-hanabi",
        },
        "hanabi_bc_weights": {
            "type": "dir",
            "filename": "bc_weights",
            "target_directory": "agents/",
            "repo_id": "lainwired/jaxaht-hanabi",
        },
    }

    failed = []
    for data_name, data_info in data_files.items():
        repo_id = data_info.get("repo_id", default_repo_id)
        if data_info["type"] == "zip":
            success = download_and_unzip_hf_file(
                repo_id=repo_id,
                filename=data_info["filename"],
                destination_dir=data_info["target_directory"],
            )
        else:
            success = download_hf_directory(
                repo_id=repo_id,
                remote_dir=data_info["filename"],
                destination_dir=data_info["target_directory"],
            )
        if success:
            print(f"Download completed successfully for {data_name}.")
        else:
            print(f"Download failed for {data_name}.")
            failed.append(data_name)

    # Exit non-zero so shell scripts / CI notice an incomplete setup.
    if failed:
        print(f"{len(failed)} of {len(data_files)} download(s) failed: {', '.join(failed)}")
        sys.exit(1)