# upload_dataset_to_hf.py
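"""Upload CSV dataset splits from a local directory to the Hugging Face Hub.

Example invocation (paths, repo name, and column name are illustrative):

    python upload_dataset_to_hf.py \
        --local_dir data/prefix_682k \
        --repo_id username/my-equation-dataset \
        --data_column equation \
        --private
"""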
import argparse
import os
import sys
import subprocess
from datasets import load_dataset, Features, Value
from huggingface_hub import HfFolder, login
# --- Helper Function to Check Git LFS ---
def check_git_lfs_installed():
    """Checks if git-lfs is installed and configured."""
    try:
        # Check if the git-lfs command exists
        subprocess.run(["git", "lfs", "--version"], check=True, capture_output=True)
        # Checking that git-lfs is initialized for the user is optional but good practice;
        # the exact command can vary by setup, e.g.:
        # subprocess.run(["git", "config", "--global", "--get", "filter.lfs.smudge"], check=True, capture_output=True)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Warning: git-lfs command not found or not configured.")
        print("         Please install git-lfs and run 'git lfs install --system' (or --user).")
        print("         See: https://git-lfs.com/")
        # Optionally exit here if git-lfs is strictly required:
        # sys.exit(1)
        return False  # Allow the script to continue, but warn the user
# --- Main Script Logic ---
def main():
    parser = argparse.ArgumentParser(
        description="Upload CSV dataset splits from a local directory to the Hugging Face Hub."
    )

    # --- Required Arguments ---
    parser.add_argument(
        "--local_dir",
        type=str,
        required=True,
        help="Path to the local directory containing the dataset CSV files."
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="The Hugging Face Hub repository ID (e.g., 'username/my-equation-dataset')."
    )
    parser.add_argument(
        "--data_column",
        type=str,
        required=True,
        help="Name of the column in the CSV files containing the actual data (e.g., 'text', 'equation')."
    )

    # --- Optional Arguments ---
    parser.add_argument(
        "--train_filename",
        type=str,
        default=None,
        help="Filename of the training CSV within local_dir (e.g., 'train_data.csv')."
    )
    parser.add_argument(
        "--val_filename",
        type=str,
        default=None,
        help="Filename of the validation CSV within local_dir (e.g., 'validation_set.csv')."
    )
    parser.add_argument(
        "--test_filename",
        type=str,
        default=None,
        help="Filename of the test CSV within local_dir (optional, e.g., 'test_examples.csv')."
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=None,
        help="Your Hugging Face Hub access token (with write permissions). If not provided, the script tries the cached token or prompts for login."
    )
    parser.add_argument(
        "--private",
        action="store_true",  # Makes the repo private if the flag is present
        help="Set the Hugging Face repository to private."
    )
    args = parser.parse_args()
print("--- Starting Dataset Upload Script ---")
# 1. Check Git LFS
print("Checking for git-lfs...")
check_git_lfs_installed() # Warns if not found
# 2. Handle Authentication
token = args.hf_token
if not token:
token = HfFolder.get_token() # Try to get cached token
if not token:
print("\nAttempting Hugging Face login...")
try:
login() # Will prompt user if not logged in via CLI
token = HfFolder.get_token() # Get token after successful login
if not token:
raise Exception("Login seemed successful but token could not be retrieved.")
except Exception as e:
print(f"Error during Hugging Face login: {e}")
print("Please ensure you are logged in via 'huggingface-cli login' or provide a token using --hf_token.")
sys.exit(1)
else:
print("Using provided/cached Hugging Face token.")
# Optionally verify token validity here if needed, though push_to_hub will fail if invalid
    # 3. Determine Filenames
    dir_name = os.path.basename(os.path.normpath(args.local_dir))  # Last component of the path
    train_file = args.train_filename if args.train_filename else f"train_{dir_name}.csv"
    val_file = args.val_filename if args.val_filename else f"val_{dir_name}.csv"  # Using 'val' as the abbreviation
    test_file = args.test_filename if args.test_filename else f"test_{dir_name}.csv"
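    # For example (illustrative values): --local_dir data/prefix_682k with no
    # filename overrides yields train_prefix_682k.csv, val_prefix_682k.csv,
    # and test_prefix_682k.csv as the expected filenames.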
print(f"Using directory: {args.local_dir}")
print(f"Target Hub repo: {args.repo_id}")
print(f"Expecting data column: '{args.data_column}'")
print(f"Using train file: '{train_file}'")
print(f"Using validation file: '{val_file}'")
# Test file is optional, only check if default or specific name provided
if args.test_filename or os.path.exists(os.path.join(args.local_dir, test_file)):
print(f"Using test file: '{test_file}'")
else:
print("No test file specified or default test file not found, skipping.")
test_file = None # Ensure test_file is None if not used
    # 4. Construct Full Paths and Check Existence
    train_path = os.path.join(args.local_dir, train_file)
    val_path = os.path.join(args.local_dir, val_file)
    test_path = os.path.join(args.local_dir, test_file) if test_file else None

    data_files = {}
    if os.path.exists(train_path):
        data_files["train"] = train_path
    else:
        print(f"Error: Training file not found at '{train_path}'")
        sys.exit(1)
    if os.path.exists(val_path):
        data_files["validation"] = val_path
    else:
        print(f"Error: Validation file not found at '{val_path}'")
        sys.exit(1)
    if test_path and os.path.exists(test_path):
        data_files["test"] = test_path
    elif args.test_filename:  # The user specified a test file, but it was not found
        print(f"Warning: Specified test file '{args.test_filename}' not found at '{test_path}'. Skipping test split.")
    # 5. Load Dataset Locally
    print("\nLoading local CSV files...")
    try:
        # Define features to ensure the data column is read as a string
        features = Features({args.data_column: Value("string")})
        dataset_dict = load_dataset("csv", data_files=data_files, features=features)
        print("Local dataset loaded successfully:")
        print(dataset_dict)
        # Verify that the data column exists in each loaded split
        for split in dataset_dict:
            if args.data_column not in dataset_dict[split].column_names:
                print(f"Error: Column '{args.data_column}' not found in loaded '{split}' split.")
                print(f"Available columns: {dataset_dict[split].column_names}")
                sys.exit(1)
    except Exception as e:
        print(f"Error loading dataset from CSV files: {e}")
        print("Please check file paths, CSV format, and column names.")
        sys.exit(1)
    # 6. Rename the data column if necessary (standardizing on 'text' for the Hub)
    if args.data_column != "text":
        print(f"Renaming column '{args.data_column}' to 'text'...")
        try:
            dataset_dict = dataset_dict.rename_column(args.data_column, "text")
            print("Column renamed successfully.")
            print(dataset_dict)
        except Exception as e:
            print(f"Error renaming column: {e}")
            # Decide whether to exit or proceed with the original column name:
            # sys.exit(1)
    # 7. Push to Hub
    print(f"\nAttempting to push dataset to Hub repository: {args.repo_id}...")
    try:
        dataset_dict.push_to_hub(
            repo_id=args.repo_id,
            private=args.private,
            token=token  # Pass the token explicitly
        )
        print("\n--- Upload Successful! ---")
        hub_url = f"https://huggingface.co/datasets/{args.repo_id}"
        print(f"Dataset available at: {hub_url}")
    except Exception as e:
        print("\n--- Error During Upload ---")
        print(f"An error occurred: {e}")
        print("Possible causes:")
        print("- Invalid Hugging Face token or insufficient permissions (write access is needed).")
        print("- Incorrect repository ID format (should be 'username/dataset_name').")
        print("- Network issues.")
        print("- Git LFS not installed or not properly configured.")
        print("- Conflicts if the repository already exists with incompatible content.")
        sys.exit(1)

if __name__ == "__main__":
    main()