"""
Dataset Setup Script - Download data from Hugging Face instead of storing in git
"""

import os
import shutil
import tempfile
from pathlib import Path

from huggingface_hub import snapshot_download


def download_datasets(data_dir: Path = Path("data"), force_retry: bool = False):
    """Download all datasets from Hugging Face to the local data directory.

    Args:
        data_dir: Directory to download data into
        force_retry: Force re-download even if data already exists
    """
    print("🚀 Setting up AI Due Diligence datasets from Hugging Face...")

    data_dir.mkdir(exist_ok=True)

    # Heuristic: a populated data directory holds well over 100 files, so
    # skip the download unless the caller explicitly forces a retry.
    if not force_retry:
        file_count = sum(1 for f in data_dir.rglob("*") if f.is_file())
        if file_count > 100:
            print(f"✅ Data directory already contains {file_count} files, skipping download")
            print("   Use --force to re-download anyway")
            return

    datasets = [
        {
            "repo_id": "jmzlx/dd-framework",
            "description": "Methodology and templates",
            "local_path": data_dir,
            "files": ["data/checklist/*", "data/questions/*", "data/strategy/*"],
        },
        {
            "repo_id": "jmzlx/dd-indexes",
            "description": "Search indexes and ML artifacts",
            "local_path": data_dir,
            "files": ["data/search_indexes/*"],
        },
        {
            "repo_id": "jmzlx/dd-vdrs",
            "description": "Virtual data room documents",
            "local_path": data_dir,
            "files": ["data/vdrs/*"],
        },
    ]
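
    # Note: the per-dataset "files" lists above are descriptive only; the
    # download below passes allow_patterns="data/**", which fetches everything
    # under each repository's data/ directory, a superset of those patterns.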
    for dataset in datasets:
        print(f"\n📥 Downloading {dataset['description']}...")
        print(f"   Repository: {dataset['repo_id']}")

        try:
            token = os.getenv("HF_TOKEN")
            if token:
                print("   🔑 Using Hugging Face token for authentication")
            else:
                print("   ⚠️ No HF_TOKEN found - may fail for private repositories")

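            # Download into a temporary directory first and move files into
            # place afterwards, so a failed or partial download never leaves
            # half-written files inside data_dir.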
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                snapshot_download(
                    repo_id=dataset["repo_id"],
                    repo_type="dataset",
                    local_dir=temp_path,
                    allow_patterns="data/**",
                    token=token,
                )

                # Move the downloaded data/ contents into data_dir, replacing
                # any existing file or directory with the same name.
                temp_data_dir = temp_path / "data"
                if temp_data_dir.exists():
                    for item in temp_data_dir.iterdir():
                        target_path = dataset["local_path"] / item.name
                        if target_path.exists():
                            if target_path.is_dir():
                                shutil.rmtree(target_path)
                            else:
                                target_path.unlink()
                        shutil.move(str(item), str(target_path))
                print("   ✅ Downloaded successfully")
        except Exception as e:
            print(f"   ❌ Error downloading {dataset['repo_id']}: {type(e).__name__}: {e}")
            if "401" in str(e) or "403" in str(e) or "private" in str(e).lower():
                print("   🔒 This appears to be a private repository requiring authentication")
                if not token:
                    print("   💡 Set the HF_TOKEN environment variable with read access to this repository")
                else:
                    print("   💡 Check that your HF_TOKEN has access to this repository")
            elif "network" in str(e).lower() or "connection" in str(e).lower():
                print("   🌐 Network connectivity issue - check your internet connection")
            print(f"   💡 Manual download: https://huggingface.co/datasets/{dataset['repo_id']}")
            continue

    print(f"\n🎉 Dataset setup complete! Data available in: {data_dir.absolute()}")

    total_size = sum(f.stat().st_size for f in data_dir.rglob("*") if f.is_file())
    file_count = sum(1 for f in data_dir.rglob("*") if f.is_file())
    print(f"📊 Downloaded {file_count:,} files, {total_size/(1024*1024):.1f}MB total")

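    # Heuristic deployment check, based on environment variables the hosting
    # platform is assumed to set (Streamlit's headless flag, or an HF_HOME
    # override pointing at /tmp/huggingface).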
    is_deployment = os.getenv("STREAMLIT_SERVER_HEADLESS") == "true" or os.getenv("HF_HOME") == "/tmp/huggingface"
    if is_deployment:
        print("\n🌍 Deployment Environment Detected")
        if file_count < 50:
            print("⚠️ Few files downloaded - this may indicate missing authentication")
            print("💡 Ensure HF_TOKEN is set in your deployment environment secrets")
            print("💡 Token should have read access to the private repositories: jmzlx/dd-framework, jmzlx/dd-indexes, jmzlx/dd-vdrs")
        else:
            print("✅ Data download appears successful for deployment environment")


def clean_old_data(data_dir: Path = Path("data")):
    """Remove the old data directory (use with caution!)."""
    if data_dir.exists():
        print(f"🗑️ Removing old data directory: {data_dir}")
        shutil.rmtree(data_dir)
        print("✅ Old data removed")


def main():
    """Parse command-line arguments and run the dataset setup."""
    import argparse

    parser = argparse.ArgumentParser(description="Setup datasets from Hugging Face")
    parser.add_argument("--clean", action="store_true", help="Remove existing data directory first")
    parser.add_argument("--force", action="store_true", help="Force re-download even if data exists")
    parser.add_argument("--data-dir", default="data", help="Data directory path")

    args = parser.parse_args()

    data_dir = Path(args.data_dir)

    if args.clean:
        clean_old_data(data_dir)

    download_datasets(data_dir, force_retry=args.force)


if __name__ == "__main__":
    main()