#!/usr/bin/env python3
"""
Dataset Setup Script - Download data from Hugging Face instead of storing it in git
"""
import argparse
import os
import shutil
import tempfile
from pathlib import Path

from huggingface_hub import snapshot_download
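# huggingface_hub is the only third-party dependency here; if the import above
# fails, install it first (e.g. `pip install huggingface_hub`).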


def download_datasets(data_dir: Path = Path("data"), force: bool = False):
    """Download all datasets from Hugging Face into a local data directory.

    Args:
        data_dir: Directory to download the data into
        force: Re-download even if the directory already contains data
    """
print("πŸš€ Setting up AI Due Diligence datasets from Hugging Face...")
# Ensure data directory exists
data_dir.mkdir(exist_ok=True)
# Check if we already have data and skip unless forced
if not force_retry:
file_count = sum(1 for f in data_dir.rglob("*") if f.is_file())
if file_count > 100: # Reasonable threshold for "has data"
print(f"βœ… Data directory already contains {file_count} files, skipping download")
print(f" Use --force to re-download anyway")
return
    datasets = [
        {
            "repo_id": "jmzlx/dd-framework",
            "description": "Methodology and templates",
            "local_path": data_dir,
            "files": ["data/checklist/*", "data/questions/*", "data/strategy/*"]
        },
        {
            "repo_id": "jmzlx/dd-indexes",
            "description": "Search indexes and ML artifacts",
            "local_path": data_dir,
            "files": ["data/search_indexes/*"]
        },
        {
            "repo_id": "jmzlx/dd-vdrs",
            "description": "Virtual data room documents",
            "local_path": data_dir,
            "files": ["data/vdrs/*"]
        }
    ]
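    # Note: the "files" entries above document what each repo contributes; the
    # actual download pulls everything under data/ via allow_patterns below, so
    # they act as documentation rather than as filters.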
    for dataset in datasets:
        print(f"\nπŸ“‹ Downloading {dataset['description']}...")
        print(f"   Repository: {dataset['repo_id']}")
        try:
            # Use HF_TOKEN if available, for private repos
            token = os.getenv("HF_TOKEN")
            if token:
                print("   πŸ”‘ Using Hugging Face token for authentication")
            else:
                print("   ⚠️ No HF_TOKEN found - may fail for private repositories")
            # Download to a temporary directory first to handle the nested data/ structure
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                snapshot_download(
                    repo_id=dataset["repo_id"],
                    repo_type="dataset",
                    local_dir=temp_path,
                    allow_patterns="data/**",  # Everything under the repo's data/ directory
                    token=token  # Pass token if available
                )
                # Move contents from temp_dir/data into the target data_dir,
                # replacing any existing entry with the same name
                temp_data_dir = temp_path / "data"
                if temp_data_dir.exists():
                    for item in temp_data_dir.iterdir():
                        target_path = dataset["local_path"] / item.name
                        if target_path.exists():
                            if target_path.is_dir():
                                shutil.rmtree(target_path)
                            else:
                                target_path.unlink()
                        shutil.move(str(item), str(target_path))
            print("   βœ… Downloaded successfully")
        except Exception as e:
            print(f"   ❌ Error downloading {dataset['repo_id']}: {type(e).__name__}: {e}")
            if "401" in str(e) or "403" in str(e) or "private" in str(e).lower():
                print("   πŸ”’ This appears to be a private repository requiring authentication")
                if not token:
                    print("   πŸ’‘ Set the HF_TOKEN environment variable with read access to this repository")
                else:
                    print("   πŸ’‘ Check that your HF_TOKEN has access to this repository")
            elif "network" in str(e).lower() or "connection" in str(e).lower():
                print("   🌐 Network connectivity issue - check your internet connection")
            print(f"   πŸ’‘ Manual download: https://huggingface.co/datasets/{dataset['repo_id']}")
            # Continue with the other datasets even if one fails
            continue
print(f"\nπŸŽ‰ Dataset setup complete! Data available in: {data_dir.absolute()}")
# Show what was downloaded
total_size = sum(f.stat().st_size for f in data_dir.rglob("*") if f.is_file())
file_count = sum(1 for f in data_dir.rglob("*") if f.is_file())
print(f"πŸ“Š Downloaded {file_count:,} files, {total_size/(1024*1024):.1f}MB total")
# Check if we're in a deployment environment and provide guidance
is_deployment = os.getenv('STREAMLIT_SERVER_HEADLESS') == 'true' or os.getenv('HF_HOME') == '/tmp/huggingface'
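    # (STREAMLIT_SERVER_HEADLESS=true is set for headless/server Streamlit runs; the
    # HF_HOME=/tmp/huggingface check is assumed to be this project's Spaces convention.)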
    if is_deployment:
        print("\nπŸš€ Deployment Environment Detected")
        if file_count < 50:
            print("⚠️ Few files downloaded - this may indicate missing authentication")
            print("πŸ’‘ Ensure HF_TOKEN is set in your deployment environment secrets")
            print("πŸ’‘ The token needs read access to the private repositories: jmzlx/dd-framework, jmzlx/dd-indexes, jmzlx/dd-vdrs")
        else:
            print("βœ… Data download appears successful for a deployment environment")


def clean_old_data(data_dir: Path = Path("data")):
    """Remove the old data directory (use with caution!)"""
    if data_dir.exists():
        print(f"πŸ—‘οΈ Removing old data directory: {data_dir}")
        shutil.rmtree(data_dir)
        print("βœ… Old data removed")


def main():
    """Parse command-line arguments and run the dataset setup."""
    parser = argparse.ArgumentParser(description="Set up datasets from Hugging Face")
    parser.add_argument("--clean", action="store_true", help="Remove existing data directory first")
    parser.add_argument("--force", action="store_true", help="Force re-download even if data exists")
    parser.add_argument("--data-dir", default="data", help="Data directory path")
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    if args.clean:
        clean_old_data(data_dir)
    download_datasets(data_dir, force=args.force)


if __name__ == "__main__":
    main()
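
# Programmatic use is also possible (a sketch; the import path depends on how
# this repo is laid out on sys.path):
#   from setup_datasets import download_datasets
#   download_datasets(Path("data"), force=True)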