agentic-defensor / download_data.py
vichudo's picture
add first approach
b840b29
Raw
History Blame Contribute Delete
5.73 kB
#!/usr/bin/env python3
"""
Data downloading script for cloud deployments.
This script downloads necessary data files from cloud storage
for use in cloud deployments like Replicate or Hugging Face Spaces.
"""
import os
import sys
import argparse
import requests
from tqdm import tqdm
import time
import hashlib
# Configuration

# Local directories that create_directories() makes sure exist.
DATA_DIRS = ["data", "embeddings", "pdfs"]

# Relative paths that download_missing_files() checks for and, when absent
# or invalid, downloads.
REQUIRED_FILES = [
    "embeddings/faiss_index.index",
    "data/doc_chunks.pkl",
    "embeddings/embeddings.pkl",
]

# Replace these URLs with your actual storage URLs.
# Used per file when no --storage-url base is supplied.
DEFAULT_STORAGE_URLS = {
    "embeddings/faiss_index.index": "https://your-storage-url.com/faiss_index.index",
    "data/doc_chunks.pkl": "https://your-storage-url.com/doc_chunks.pkl",
    "embeddings/embeddings.pkl": "https://your-storage-url.com/embeddings.pkl",
}

# File hashes for verification (sha256, hex-encoded).
# NOTE: the values below are placeholders. No real digest can equal them, so
# with verification enabled every file is reported invalid until they are
# replaced (or the script is run with --skip-verify).
FILE_HASHES = {
    "embeddings/faiss_index.index": "your_hash_here",
    "data/doc_chunks.pkl": "your_hash_here",
    "embeddings/embeddings.pkl": "your_hash_here",
}
def create_directories():
    """Create each directory listed in DATA_DIRS that does not exist yet.

    Prints one "Created directory" line for every path that was absent when
    checked. Paths that already exist are skipped silently.
    """
    for directory in DATA_DIRS:
        if os.path.exists(directory):
            continue
        # exist_ok=True closes the check-then-create race: if another process
        # creates the directory between the exists() check above and this
        # call, makedirs no longer raises FileExistsError.
        os.makedirs(directory, exist_ok=True)
        print(f"Created directory: {directory}")
def verify_file(file_path, expected_hash=None):
    """Check that a file exists and, optionally, that its SHA-256 matches.

    Args:
        file_path: Path of the file to check.
        expected_hash: Hex-encoded SHA-256 digest to compare against. When
            falsy (None or an empty string) only existence is checked.

    Returns:
        True if the file exists and, when a hash is given, its digest equals
        expected_hash. False if the path is missing, the digest differs, or
        the file cannot be read (the read error is printed).
    """
    if not os.path.exists(file_path):
        return False
    if not expected_hash:
        return True

    digest = hashlib.sha256()
    try:
        with open(file_path, "rb") as f:
            # Hash in fixed-size chunks so that large files are never held
            # in memory in their entirety.
            for block in iter(lambda: f.read(1024 * 1024), b""):
                digest.update(block)
    except OSError as e:
        print(f"Error verifying file hash: {e}")
        return False
    return digest.hexdigest() == expected_hash
def download_file(url, destination, expected_hash=None, timeout=30):
    """Download a file from a URL, showing a progress bar.

    The body is streamed into "<destination>.part" and only moved onto
    ``destination`` once the transfer has finished and, if requested, its
    hash has been verified. A failed or interrupted download therefore never
    leaves a truncated or corrupt file at ``destination``; a file that was
    already there is left untouched on failure.

    Args:
        url: URL to fetch.
        destination: Local path the downloaded file should end up at.
        expected_hash: Optional hex-encoded SHA-256 digest passed to
            verify_file(); when falsy no hash check is performed.
        timeout: Value passed to requests as its ``timeout`` argument, so a
            stalled server cannot block the script indefinitely.

    Returns:
        True if the file was downloaded (and verified, when a hash was
        given); False on any error or on hash mismatch.
    """
    temp_path = f"{destination}.part"
    try:
        # The context manager releases the streamed connection even when the
        # transfer is aborted part-way through.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # Get file size for the progress bar; 0 means the server did not
            # send a Content-Length header.
            total_size = int(response.headers.get("content-length", 0))
            with open(temp_path, "wb") as f, tqdm(
                desc=os.path.basename(destination),
                total=total_size or None,  # None lets tqdm show an open-ended bar
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    if chunk:
                        f.write(chunk)
                        progress.update(len(chunk))

        # Verify the temporary file before it is allowed to replace anything.
        if expected_hash and not verify_file(temp_path, expected_hash):
            print(f"Warning: Hash verification failed for {destination}")
            os.remove(temp_path)
            return False

        # os.replace overwrites an existing destination in a single step.
        os.replace(temp_path, destination)
        return True
    except Exception as e:
        # Deliberately broad: this is the per-download boundary, and the
        # caller's retry loop relies on getting False back rather than an
        # exception for network, HTTP, header-parsing and filesystem errors.
        print(f"Error downloading {url}: {e}")
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return False
def download_missing_files(storage_base_url=None, verify_hashes=True, max_retries=3):
    """Download every required file that is missing or fails verification.

    Args:
        storage_base_url: Optional base URL. When given, each file is fetched
            from "<base>/<file_path>" (a trailing slash on the base is
            ignored); otherwise the per-file entry in DEFAULT_STORAGE_URLS
            is used.
        verify_hashes: When True, files are checked against FILE_HASHES both
            before deciding to download and after each download.
        max_retries: Maximum number of download attempts per file. Failed
            attempts are separated by an exponentially growing delay
            (1s, 2s, 4s, ...). A value of 0 or less makes no attempt and is
            reported as a failure.

    Returns:
        True if all required files are present (and valid) afterwards,
        False if any file had no URL or could not be downloaded.
    """
    create_directories()

    def expected_hash_for(file_path):
        # Single place that decides whether a hash check applies to a file.
        return FILE_HASHES.get(file_path) if verify_hashes else None

    # Keep track of missing files
    missing_files = [
        file_path
        for file_path in REQUIRED_FILES
        if not verify_file(file_path, expected_hash_for(file_path))
    ]
    if not missing_files:
        print("All required files are present and valid.")
        return True
    print(f"Missing or invalid files: {len(missing_files)}")

    # Strip a trailing slash so the joined URL never contains "//".
    base_url = storage_base_url.rstrip("/") if storage_base_url else None

    # Download missing files
    success = True
    for file_path in missing_files:
        # Determine download URL
        if storage_base_url:
            url = f"{base_url}/{file_path}"
        else:
            url = DEFAULT_STORAGE_URLS.get(file_path)
        if not url:
            print(f"Error: No URL configured for {file_path}")
            success = False
            continue

        print(f"Downloading {file_path}...")
        expected_hash = expected_hash_for(file_path)
        # Try with retries
        for attempt in range(max_retries):
            if download_file(url, file_path, expected_hash):
                print(f"Successfully downloaded {file_path}")
                break
            if attempt < max_retries - 1:
                retry_delay = 2 ** attempt  # Exponential backoff
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
        else:
            # Runs when no attempt succeeded -- including max_retries <= 0,
            # where the loop body never executes and the file would otherwise
            # be silently counted as a success.
            print(f"Failed to download {file_path} after {max_retries} attempts")
            success = False
    return success
def main():
    """Parse the command line, fetch the data files and return an exit code.

    Options:
        --storage-url  Base URL handed to download_missing_files().
        --skip-verify  Disable hash verification.
        --force        Delete every existing required file first so that all
                       of them are downloaded again.

    Returns:
        0 when download_missing_files() reports success, 1 otherwise.
    """
    cli = argparse.ArgumentParser(description="Download data files for Agentic Defensor")
    cli.add_argument("--storage-url", type=str, help="Base URL for data storage")
    cli.add_argument("--skip-verify", action="store_true", help="Skip hash verification")
    cli.add_argument("--force", action="store_true", help="Force redownload all files")
    options = cli.parse_args()

    if options.force:
        # Remove existing files so they all count as missing below.
        for existing in REQUIRED_FILES:
            if not os.path.exists(existing):
                continue
            os.remove(existing)
            print(f"Removed existing file: {existing}")

    # Download missing files
    all_ready = download_missing_files(
        storage_base_url=options.storage_url,
        verify_hashes=not options.skip_verify,
    )
    if not all_ready:
        print("Some files could not be downloaded. Check the logs for details.")
        return 1

    print("All data files are ready for use!")
    return 0


if __name__ == "__main__":
    sys.exit(main())