| |
|
|
| import argparse |
| import os |
| import sys |
| import subprocess |
| from datasets import load_dataset, DatasetDict, Features, Value |
| from huggingface_hub import HfApi, HfFolder, login, HfApi |
| |
|
|
| |
def check_git_lfs_installed():
    """Report whether the ``git lfs`` command can be executed.

    Runs ``git lfs --version`` with its output captured.

    Returns:
        True if the command exits successfully. False (after printing
        installation hints) if ``git`` cannot be found or the command
        exits with a non-zero status.
    """
    probe_cmd = ["git", "lfs", "--version"]
    hints = (
        "Warning: git-lfs command not found or not configured.",
        " Please install git-lfs and run 'git lfs install --system' (or --user).",
        " See: https://git-lfs.com/",
    )
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        subprocess.run(probe_cmd, check=True, capture_output=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        for hint in hints:
            print(hint)
        return False
    else:
        return True
|
|
| |
def _build_arg_parser():
    """Create the command-line parser for the upload script."""
    parser = argparse.ArgumentParser(
        description="Upload CSV dataset splits from a local directory to the Hugging Face Hub."
    )

    # Required arguments.
    parser.add_argument(
        "--local_dir",
        type=str,
        required=True,
        help="Path to the local directory containing the dataset CSV files.",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="The Hugging Face Hub repository ID (e.g., 'username/my-equation-dataset').",
    )
    parser.add_argument(
        "--data_column",
        type=str,
        required=True,
        help="Name of the column in the CSV files containing the actual data (e.g., 'text', 'equation').",
    )

    # Optional arguments.
    parser.add_argument(
        "--train_filename",
        type=str,
        default=None,
        help="Filename of the training CSV within local_dir (e.g., 'train_data.csv').",
    )
    parser.add_argument(
        "--val_filename",
        type=str,
        default=None,
        help="Filename of the validation CSV within local_dir (e.g., 'validation_set.csv').",
    )
    parser.add_argument(
        "--test_filename",
        type=str,
        default=None,
        help="Filename of the test CSV within local_dir (optional, e.g., 'test_examples.csv').",
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=None,
        help="Your Hugging Face Hub access token (with write permissions). If not provided, script will try to use cached token or prompt login.",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Set the Hugging Face repository to private.",
    )
    return parser


def _resolve_token(cli_token):
    """Return a Hub token from ``--hf_token``, the local cache, or a login prompt.

    Exits the process with status 1 if no token can be obtained.
    """
    token = cli_token or HfFolder.get_token()
    if token:
        print("Using provided/cached Hugging Face token.")
        return token

    print("\nAttempting Hugging Face login...")
    try:
        login()
        token = HfFolder.get_token()
        if not token:
            raise RuntimeError("Login seemed successful but token could not be retrieved.")
    except Exception as e:
        print(f"Error during Hugging Face login: {e}")
        print("Please ensure you are logged in via 'huggingface-cli login' or provide a token using --hf_token.")
        sys.exit(1)
    return token


def _select_data_files(args):
    """Build the ``split name -> CSV path`` mapping for ``load_dataset``.

    File names default to ``train_<dir>.csv``, ``val_<dir>.csv`` and
    ``test_<dir>.csv``, where ``<dir>`` is the last component of
    ``args.local_dir``. The train and validation files are mandatory (the
    process exits with status 1 if either is missing); the test file is
    included only when it exists.
    """
    dir_name = os.path.basename(os.path.normpath(args.local_dir))

    train_file = args.train_filename or f"train_{dir_name}.csv"
    val_file = args.val_filename or f"val_{dir_name}.csv"
    test_file = args.test_filename or f"test_{dir_name}.csv"

    print(f"Using directory: {args.local_dir}")
    print(f"Target Hub repo: {args.repo_id}")
    print(f"Expecting data column: '{args.data_column}'")
    print(f"Using train file: '{train_file}'")
    print(f"Using validation file: '{val_file}'")

    # Check the test file exactly once so the message printed here always
    # matches what is actually uploaded (an explicitly named but missing file
    # is reported as skipped, never as "Using test file").
    test_path = os.path.join(args.local_dir, test_file)
    has_test = os.path.exists(test_path)
    if has_test:
        print(f"Using test file: '{test_file}'")
    elif args.test_filename:
        print(f"Warning: Specified test file '{args.test_filename}' not found at '{test_path}'. Skipping test split.")
    else:
        print("No test file specified or default test file not found, skipping.")

    required_splits = (
        ("train", "Training", os.path.join(args.local_dir, train_file)),
        ("validation", "Validation", os.path.join(args.local_dir, val_file)),
    )
    data_files = {}
    for split, label, path in required_splits:
        if not os.path.exists(path):
            print(f"Error: {label} file not found at '{path}'")
            sys.exit(1)
        data_files[split] = path

    if has_test:
        data_files["test"] = test_path
    return data_files


def _load_splits(data_files, data_column):
    """Load the CSV splits and verify each one contains ``data_column``.

    Exits the process with status 1 on any loading error or missing column.
    """
    print("\nLoading local CSV files...")
    try:
        # NOTE(review): the declared schema contains only `data_column`; CSVs
        # with additional columns may be rejected by the loader - confirm this
        # restriction is intended.
        features = Features({data_column: Value("string")})
        dataset_dict = load_dataset("csv", data_files=data_files, features=features)
        print("Local dataset loaded successfully:")
        print(dataset_dict)

        for split in dataset_dict:
            columns = dataset_dict[split].column_names
            if data_column not in columns:
                print(f"Error: Column '{data_column}' not found in loaded '{split}' split.")
                print(f"Available columns: {columns}")
                # SystemExit is not an Exception subclass, so it is not
                # swallowed by the handler below.
                sys.exit(1)
    except Exception as e:
        print(f"Error loading dataset from CSV files: {e}")
        print("Please check file paths, CSV format, and column names.")
        sys.exit(1)
    return dataset_dict


def _rename_data_column(dataset_dict, data_column):
    """Rename ``data_column`` to ``text`` when it has a different name.

    Renaming is best-effort: on failure the error is printed and the
    dataset is returned unchanged so the upload can still proceed.
    """
    if data_column == "text":
        return dataset_dict

    print(f"Renaming column '{data_column}' to 'text'...")
    try:
        dataset_dict = dataset_dict.rename_column(data_column, "text")
        print("Column renamed successfully.")
        print(dataset_dict)
    except Exception as e:
        print(f"Error renaming column: {e}")
    return dataset_dict


def _push_dataset(dataset_dict, repo_id, private, token):
    """Push the dataset to the Hub; exit with status 1 on failure."""
    print(f"\nAttempting to push dataset to Hub repository: {repo_id}...")
    try:
        dataset_dict.push_to_hub(
            repo_id=repo_id,
            private=private,
            token=token,
        )
        print("\n--- Upload Successful! ---")
        hub_url = f"https://huggingface.co/datasets/{repo_id}"
        print(f"Dataset available at: {hub_url}")
    except Exception as e:
        print("\n--- Error During Upload ---")
        print(f"An error occurred: {e}")
        print("Possible causes:")
        print("- Invalid Hugging Face token or insufficient permissions (needs write access).")
        print("- Repository ID format incorrect (should be 'username/dataset_name').")
        print("- Network issues.")
        print("- Git LFS not installed or properly configured.")
        print("- Conflicts if the repository already exists with incompatible content.")
        sys.exit(1)


def main():
    """Upload local CSV splits to a Hugging Face Hub dataset repository.

    Steps: parse the command line, warn if git-lfs is unavailable, obtain an
    access token, locate the split files, load them as a dataset, rename the
    data column to ``text``, and push the result to the Hub. Any fatal
    problem is printed and terminates the process with exit status 1.
    """
    args = _build_arg_parser().parse_args()

    print("--- Starting Dataset Upload Script ---")

    # The result is intentionally ignored: a missing git-lfs only produces a
    # warning and does not stop the upload attempt.
    print("Checking for git-lfs...")
    check_git_lfs_installed()

    token = _resolve_token(args.hf_token)
    data_files = _select_data_files(args)
    dataset_dict = _load_splits(data_files, args.data_column)
    dataset_dict = _rename_data_column(dataset_dict, args.data_column)
    _push_dataset(dataset_dict, args.repo_id, args.private, token)
|
|
# Run the uploader only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()