#!/usr/bin/env python3
"""
Dataset Setup Script - Download data from Hugging Face instead of storing it in git
"""
import os
import shutil
import tempfile
from pathlib import Path

from huggingface_hub import snapshot_download
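
# Environment variables this script reads (all optional), as used below:
#   HF_TOKEN                  - Hugging Face token; needed for private dataset repos
#   STREAMLIT_SERVER_HEADLESS - "true" signals a Streamlit deployment environment
#   HF_HOME                   - "/tmp/huggingface" also signals a deployment environment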


def download_datasets(data_dir: Path = Path("data"), force_retry: bool = False):
    """Download all datasets from Hugging Face to the local data directory

    Args:
        data_dir: Directory to download data to
        force_retry: Force re-download even if data already exists
    """
    print("🚀 Setting up AI Due Diligence datasets from Hugging Face...")

    # Ensure the data directory exists (including any missing parents)
    data_dir.mkdir(parents=True, exist_ok=True)

    # Skip the download if we already have data, unless forced
    if not force_retry:
        file_count = sum(1 for f in data_dir.rglob("*") if f.is_file())
        if file_count > 100:  # Reasonable threshold for "has data"
            print(f"✅ Data directory already contains {file_count} files, skipping download")
            print(" Use --force to re-download anyway")
            return
    # Note: the "files" lists below are descriptive only; the actual download
    # filtering is done by allow_patterns in snapshot_download() below.
    datasets = [
        {
            "repo_id": "jmzlx/dd-framework",
            "description": "Methodology and templates",
            "local_path": data_dir,
            "files": ["data/checklist/*", "data/questions/*", "data/strategy/*"],
        },
        {
            "repo_id": "jmzlx/dd-indexes",
            "description": "Search indexes and ML artifacts",
            "local_path": data_dir,
            "files": ["data/search_indexes/*"],
        },
        {
            "repo_id": "jmzlx/dd-vdrs",
            "description": "Virtual data room documents",
            "local_path": data_dir,
            "files": ["data/vdrs/*"],
        },
    ]

    for dataset in datasets:
        print(f"\n📥 Downloading {dataset['description']}...")
        print(f" Repository: {dataset['repo_id']}")
        try:
            # Use HF_TOKEN if available, for private repos
            token = os.getenv("HF_TOKEN")
            if token:
                print(" 🔑 Using Hugging Face token for authentication")
            else:
                print(" ⚠️ No HF_TOKEN found - may fail for private repositories")

            # Download to a temporary directory first to handle the nested structure
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                snapshot_download(
                    repo_id=dataset["repo_id"],
                    repo_type="dataset",
                    local_dir=temp_path,
                    allow_patterns="data/**",  # Download only the data directory
                    token=token,  # Pass the token if available
                )

                # Move contents from temp_dir/data into the target data_dir
                temp_data_dir = temp_path / "data"
                if temp_data_dir.exists():
                    for item in temp_data_dir.iterdir():
                        target_path = dataset["local_path"] / item.name
                        # Replace any existing file or directory at the target
                        if target_path.exists():
                            if target_path.is_dir():
                                shutil.rmtree(target_path)
                            else:
                                target_path.unlink()
                        shutil.move(str(item), str(target_path))

            print(" ✅ Downloaded successfully")
        except Exception as e:
            print(f" ❌ Error downloading {dataset['repo_id']}: {type(e).__name__}: {e}")
            if "401" in str(e) or "403" in str(e) or "private" in str(e).lower():
                print(" 🔒 This appears to be a private repository requiring authentication")
                if not token:
                    print(" 💡 Set the HF_TOKEN environment variable with read access to this repository")
                else:
                    print(" 💡 Check that your HF_TOKEN has access to this repository")
            elif "network" in str(e).lower() or "connection" in str(e).lower():
                print(" 🌐 Network connectivity issue - check your internet connection")
            print(f" 💡 Manual download: https://huggingface.co/datasets/{dataset['repo_id']}")
            # Continue with the other datasets even if one fails
            continue

    print(f"\n🎉 Dataset setup complete! Data available in: {data_dir.absolute()}")

    # Summarize what was downloaded
    total_size = sum(f.stat().st_size for f in data_dir.rglob("*") if f.is_file())
    file_count = sum(1 for f in data_dir.rglob("*") if f.is_file())
    print(f"📊 Downloaded {file_count:,} files, {total_size / (1024 * 1024):.1f} MB total")

    # Detect deployment environments and provide extra guidance
    is_deployment = (
        os.getenv("STREAMLIT_SERVER_HEADLESS") == "true"
        or os.getenv("HF_HOME") == "/tmp/huggingface"
    )
    if is_deployment:
        print("\n🌐 Deployment Environment Detected")
        if file_count < 50:
            print("⚠️ Few files downloaded - this may indicate missing authentication")
            print("💡 Ensure HF_TOKEN is set in your deployment environment secrets")
            print("💡 The token needs read access to the private repositories: jmzlx/dd-framework, jmzlx/dd-indexes, jmzlx/dd-vdrs")
        else:
            print("✅ Data download appears successful for the deployment environment")


def clean_old_data(data_dir: Path = Path("data")):
    """Remove the old data directory (use with caution!)"""
    if data_dir.exists():
        print(f"🗑️ Removing old data directory: {data_dir}")
        shutil.rmtree(data_dir)
        print("✅ Old data removed")


def main():
    """Main entry point"""
    import argparse

    parser = argparse.ArgumentParser(description="Setup datasets from Hugging Face")
    parser.add_argument("--clean", action="store_true", help="Remove existing data directory first")
    parser.add_argument("--force", action="store_true", help="Force re-download even if data exists")
    parser.add_argument("--data-dir", default="data", help="Data directory path")
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    if args.clean:
        clean_old_data(data_dir)
    download_datasets(data_dir, force_retry=args.force)


if __name__ == "__main__":
    main()
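
# Example invocations (a sketch; the script's filename is assumed here, the
# flags come from the argparse setup in main() above):
#   python setup_data.py                     # download only if data/ looks empty
#   python setup_data.py --force             # re-download even if data exists
#   python setup_data.py --clean --force     # wipe data/ first, then re-download
#   HF_TOKEN=hf_xxx python setup_data.py     # authenticate for the private repos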