Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Data downloading script for cloud deployments. | |
| This script downloads necessary data files from cloud storage | |
| for use in cloud deployments like Replicate or Hugging Face Spaces. | |
| """ | |
| import os | |
| import sys | |
| import argparse | |
| import requests | |
| from tqdm import tqdm | |
| import time | |
| import hashlib | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Local directories the deployment expects to find next to this script.
DATA_DIRS = ["data", "embeddings", "pdfs"]

# Relative paths of the files that must be present before the app can start.
REQUIRED_FILES = [
    "embeddings/faiss_index.index",
    "data/doc_chunks.pkl",
    "embeddings/embeddings.pkl",
]

# Per-file download locations, used when no base storage URL is supplied.
# These are placeholders: replace them with your actual storage URLs.
DEFAULT_STORAGE_URLS = {
    "embeddings/faiss_index.index": "https://your-storage-url.com/faiss_index.index",
    "data/doc_chunks.pkl": "https://your-storage-url.com/doc_chunks.pkl",
    "embeddings/embeddings.pkl": "https://your-storage-url.com/embeddings.pkl",
}

# Expected SHA-256 hex digests, keyed by the same relative paths.
# NOTE: the placeholder value below can never equal a real digest, so hash
# verification fails for every file until these are filled in.
FILE_HASHES = {
    "embeddings/faiss_index.index": "your_hash_here",
    "data/doc_chunks.pkl": "your_hash_here",
    "embeddings/embeddings.pkl": "your_hash_here",
}
def create_directories():
    """Make sure each entry of DATA_DIRS exists, announcing every one created."""
    # Entries are handled strictly in order so that a directory created as a
    # parent of an earlier entry is seen as already present later on.
    for dir_name in DATA_DIRS:
        if os.path.exists(dir_name):
            continue
        os.makedirs(dir_name)
        print(f"Created directory: {dir_name}")
def verify_file(file_path, expected_hash=None):
    """Check that a file exists and, optionally, that its SHA-256 digest matches.

    Args:
        file_path: Path of the file to check.
        expected_hash: Hex-encoded SHA-256 digest to compare against. When
            falsy (``None`` or ``""``) only existence is checked. String
            digests are compared case-insensitively.

    Returns:
        True if the path exists and (when a digest is given) the file's
        SHA-256 matches it; False otherwise, including when the file cannot
        be read.
    """
    if not os.path.exists(file_path):
        return False
    if not expected_hash:
        return True

    digest = hashlib.sha256()
    try:
        with open(file_path, "rb") as f:
            # Hash in fixed-size blocks so large data files are never loaded
            # into memory in one piece.
            for block in iter(lambda: f.read(1024 * 1024), b""):
                digest.update(block)
    except OSError as e:
        print(f"Error verifying file hash: {e}")
        return False

    # hexdigest() is always lower-case; accept upper-case expected digests too.
    if isinstance(expected_hash, str):
        expected_hash = expected_hash.lower()
    return digest.hexdigest() == expected_hash
def download_file(url, destination, expected_hash=None, timeout=30):
    """Download a file from a URL to ``destination``, showing a progress bar.

    The payload is streamed into a temporary ``<destination>.part`` file and
    only moved onto ``destination`` once the transfer (and the optional hash
    check) has succeeded, so a failed or interrupted download never leaves a
    truncated or corrupt file at ``destination``.

    Args:
        url: Source URL.
        destination: Local path the file should end up at.
        expected_hash: Optional SHA-256 hex digest; when given, the downloaded
            data must match it or the download is treated as failed.
        timeout: Seconds passed to ``requests.get`` as its connect/read
            timeout, so a stalled server cannot block the script forever.

    Returns:
        True if the file was downloaded (and verified, when a hash was given);
        False on any error. An existing ``destination`` is left untouched on
        failure.
    """
    partial_path = f"{destination}.part"
    try:
        # Callers may invoke this directly without preparing the directory.
        parent_dir = os.path.dirname(destination)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # The context manager releases the connection even on early exit.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # Size for the progress bar; None tells tqdm the total is unknown.
            content_length = response.headers.get("content-length")
            total_size = int(content_length) if content_length else None

            with open(partial_path, "wb") as f, tqdm(
                desc=os.path.basename(destination),
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    if chunk:
                        f.write(chunk)
                        progress.update(len(chunk))

        # Verify before publishing so a corrupt payload never reaches
        # ``destination``.
        if expected_hash and not verify_file(partial_path, expected_hash):
            print(f"Warning: Hash verification failed for {destination}")
            return False

        os.replace(partial_path, destination)
        return True
    except Exception as e:
        # Deliberately broad: callers rely on a boolean result to drive their
        # retry logic, so any failure is reported rather than propagated.
        print(f"Error downloading {url}: {e}")
        return False
    finally:
        # Runs on every exit path (including KeyboardInterrupt); after a
        # successful os.replace the partial file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def download_missing_files(storage_base_url=None, verify_hashes=True, max_retries=3):
    """Download every required data file that is missing or fails verification.

    Args:
        storage_base_url: Optional base URL; when given, each file is fetched
            from ``<storage_base_url>/<relative path>``. Otherwise the per-file
            entry in DEFAULT_STORAGE_URLS is used.
        verify_hashes: When True, files are checked against FILE_HASHES both
            before deciding to download and after downloading.
        max_retries: Maximum download attempts per file. Values below 1 are
            treated as 1 so that a file is never reported as downloaded
            without a single attempt having been made.

    Returns:
        True if all required files are present (and valid) afterwards, False
        if any file had no URL configured or could not be downloaded.
    """
    create_directories()

    # Look up each expected hash once; None disables verification for a file.
    expected_hashes = {
        file_path: FILE_HASHES.get(file_path) if verify_hashes else None
        for file_path in REQUIRED_FILES
    }

    missing_files = [
        file_path
        for file_path in REQUIRED_FILES
        if not verify_file(file_path, expected_hashes[file_path])
    ]

    if not missing_files:
        print("All required files are present and valid.")
        return True

    print(f"Missing or invalid files: {len(missing_files)}")

    attempts = max(1, max_retries)
    success = True
    for file_path in missing_files:
        # Determine download URL
        if storage_base_url:
            # Tolerate a trailing slash on the base URL to avoid "//" in the path.
            url = f"{storage_base_url.rstrip('/')}/{file_path}"
        else:
            url = DEFAULT_STORAGE_URLS.get(file_path)

        if not url:
            print(f"Error: No URL configured for {file_path}")
            success = False
            continue

        print(f"Downloading {file_path}...")

        for attempt in range(attempts):
            if download_file(url, file_path, expected_hashes[file_path]):
                print(f"Successfully downloaded {file_path}")
                break
            if attempt < attempts - 1:
                retry_delay = 2 ** attempt  # Exponential backoff
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
        else:
            # Loop finished without a successful download.
            print(f"Failed to download {file_path} after {attempts} attempts")
            success = False

    return success
def main():
    """Parse command-line options, fetch the data files, and return an exit code.

    Returns:
        0 when every required file is available afterwards, 1 otherwise.
    """
    cli = argparse.ArgumentParser(description="Download data files for Agentic Defensor")
    cli.add_argument("--storage-url", type=str, help="Base URL for data storage")
    cli.add_argument("--skip-verify", action="store_true", help="Skip hash verification")
    cli.add_argument("--force", action="store_true", help="Force redownload all files")
    options = cli.parse_args()

    if options.force:
        # Deleting the local copies makes the downloader treat them as missing.
        for stale_path in filter(os.path.exists, REQUIRED_FILES):
            os.remove(stale_path)
            print(f"Removed existing file: {stale_path}")

    all_ready = download_missing_files(
        storage_base_url=options.storage_url,
        verify_hashes=not options.skip_verify,
    )

    if not all_ready:
        print("Some files could not be downloaded. Check the logs for details.")
        return 1

    print("All data files are ready for use!")
    return 0
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())