#!/usr/bin/env python3
"""
NLProxy model bootstrapper: fetch and unpack the ONNX model bundle.

Run this CLI to make every ONNX model NLProxy needs available on local disk so
inference can run offline. A run goes through these stages, in order:

1. Make sure the target ``models/`` directory exists, creating it if needed.
2. Look for ``nlproxy_models.zip`` inside that directory.
3. If the archive is absent (or ``--force`` is given), download it from a
   configurable URL, retrying with exponential backoff on failure.
4. Check the archive's CRCs, reject absolute or ``..`` member paths, and only
   then extract it.
5. Confirm each of the three expected model directories exists and contains at
   least one recognised model file.
6. Delete the archive (unless ``--keep-zip`` is given) and report the outcome.

Usage
-----
    # Download source: --url flag, else NLPROXY_MODELS_URL, else built-in URL
    $ python -m nlproxy download_models

Configuration
-------------
Environment variables:
    NLPROXY_MODELS_URL     Direct download URL for nlproxy_models.zip
    NLPROXY_MODELS_SHA256  Expected SHA256 of the archive (default for --sha256)
    NLPROXY_LOG_LEVEL      Logging level: DEBUG, INFO, WARNING, ERROR
    NLPROXY_PROXY          HTTP/HTTPS proxy URL (optional)

Author: IntelliDeep Labs Team
License: BSL 1.1
"""

from __future__ import annotations

import argparse
import logging
import os
import shutil
import sys
import tempfile
import time
import zipfile
from pathlib import Path
from typing import List, Optional

import requests
from requests.exceptions import RequestException

from nlproxy.core.model_manager import ModelManager

# tqdm is an optional dependency; _TQDM_AVAILABLE records whether it imported.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None  # type: ignore
    _TQDM_AVAILABLE = False
else:
    _TQDM_AVAILABLE = True

logger = logging.getLogger(__name__)

# =============================================================================
# CONFIGURATION CONSTANTS
# =============================================================================

# Default models directory; relative, so it is resolved against the CWD.
DEFAULT_MODELS_DIR: Path = Path("nlproxy", "models")
# Archive name; the archive is placed inside the models directory.
ZIP_FILENAME: str = "nlproxy_models.zip"

# Source of the model bundle: a built-in GitHub release URL is used unless the
# --url CLI flag or the NLPROXY_MODELS_URL environment variable overrides it.
DEFAULT_MODELS_URL: str = os.getenv(
    "NLPROXY_MODELS_URL",
    "https://github.com/intellideep/nlproxy/releases/download/free_models/nlproxy_models.zip",
)

# Expected subdirectories after extraction. Used for strict validation.
EXPECTED_MODEL_DIRS: List[str] = [
    "all-MiniLM-L6-v2",
    "nli-distilroberta-base",
    "distilgpt2",
]

# A model directory counts as populated when its top level contains at least
# one file with one of these suffixes.
MODEL_FILE_SUFFIXES = frozenset({".onnx", ".bin", ".json", ".txt", ".py"})

# Network & Retry Configuration
# Passed as requests' ``timeout``: a per connect/read socket timeout in
# seconds, not a cap on the total download time.
DOWNLOAD_TIMEOUT: int = 600
MAX_RETRIES: int = 3
RETRY_BACKOFF_BASE: float = 2.0
CHUNK_SIZE: int = 8192
TEMP_PREFIX: str = "nlproxy_dl_"


# =============================================================================
# UTILITIES & LOGGING
# =============================================================================

def setup_logging(level: str = "INFO") -> None:
    """Configure console logging on stdout.

    Args:
        level: Level name such as ``"DEBUG"`` (case-insensitive). Unknown
            names fall back to INFO.
    """
    numeric_level = getattr(logging, level.upper(), logging.INFO)
    # getattr can hit non-level attributes of the logging module (e.g.
    # "BASIC_FORMAT"); only accept integer levels.
    if not isinstance(numeric_level, int):
        numeric_level = logging.INFO
    logging.basicConfig(
        level=numeric_level,
        format="%(asctime)s [%(levelname)-8s] %(message)s",
        datefmt="%H:%M:%S",
        stream=sys.stdout,
    )


def log_progress(description: str, total: int, iterable):
    """Wrap *iterable* with a progress indicator.

    With tqdm installed this returns a tqdm bar over *iterable*. Otherwise it
    returns a generator that yields the same items and logs estimated
    progress roughly every 10%.

    This function deliberately contains no ``yield``. Being a plain function
    lets the tqdm branch really return the bar; inside a generator, that
    ``return`` would just end iteration with no items.

    Args:
        description: Label shown on the bar / in log lines.
        total: Expected total size in bytes (0 or negative disables percentages).
        iterable: Items to pass through.
    """
    if _TQDM_AVAILABLE and tqdm:
        return tqdm(
            iterable, total=total, desc=description, unit="B", unit_scale=True, unit_divisor=1024
        )
    return _log_progress_fallback(description, total, iterable)


def _log_progress_fallback(description: str, total: int, iterable):
    """Yield items from *iterable*, logging progress every 10% (assumes CHUNK_SIZE bytes per item)."""
    chunk_count = 0
    last_log = 0.0
    for item in iterable:
        yield item
        chunk_count += 1
        progress = (chunk_count * CHUNK_SIZE / total * 100) if total > 0 else 0.0
        if progress - last_log >= 10.0:
            logger.info(f" {description}: {progress:.1f}%")
            last_log = progress


def error_exit(message: str, code: int = 1) -> None:
    """Log *message* at ERROR level and terminate via ``sys.exit(code)``. Never returns."""
    logger.error(message)
    sys.exit(code)


def success_exit(message: str) -> None:
    """Log *message* at INFO level and terminate via ``sys.exit(0)``. Never returns."""
    logger.info(message)
    sys.exit(0)


# =============================================================================
# CORE WORKFLOW FUNCTIONS
# =============================================================================

def ensure_models_dir(models_dir: Path) -> None:
    """Create *models_dir* (with parents) if missing.

    Exits with code 1 via ``error_exit`` if the path exists but is not a directory.
    """
    if not models_dir.exists():
        logger.info(f"Creating models directory: {models_dir}")
        models_dir.mkdir(parents=True, exist_ok=True)
    elif not models_dir.is_dir():
        error_exit(f"Path exists but is not a directory: {models_dir}")


def _model_dir_looks_valid(dir_path: Path) -> bool:
    """Return True if *dir_path* is a directory with a model-like file at its top level."""
    if not dir_path.is_dir():
        return False
    return any(entry.suffix in MODEL_FILE_SUFFIXES for entry in dir_path.iterdir())


def _discard_temp_file(path: Optional[Path]) -> None:
    """Best-effort removal of a leftover temporary download file (None is ignored)."""
    if path is None:
        return
    try:
        path.unlink(missing_ok=True)
    except OSError as exc:
        logger.warning(f"Failed to remove temporary file {path}: {exc}")


def download_with_retries(url: str, dest_path: Path) -> bool:
    """
    Download a file from URL with exponential backoff retries.

    The body is streamed with ``iter_content`` into a temporary file in
    ``dest_path.parent``. ``iter_content`` undoes any Content-Encoding and
    wraps mid-stream socket errors in ``RequestException``, so those are
    retried. Only after the temp file is closed and its size checked is it
    moved onto ``dest_path`` with ``Path.replace``. An interrupted download
    therefore never leaves a truncated archive at the final path.

    If the ``NLPROXY_PROXY`` environment variable is set, it is used as the
    proxy for both HTTP and HTTPS.

    Args:
        url: Direct download URL. Redirects are followed.
        dest_path: Final location of the downloaded file.

    Returns:
        True on success. False once ``MAX_RETRIES`` attempts have failed, or
        immediately on an unexpected (non-network, non-OS) error.

    Raises:
        ValueError: If *url* is empty.
    """
    if not url:
        raise ValueError(
            "Download URL is not set. Provide via --url flag or NLPROXY_MODELS_URL env var."
        )

    proxy_url = os.getenv("NLPROXY_PROXY")
    proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None

    logger.info(f"Starting download: {url}")

    for attempt in range(1, MAX_RETRIES + 1):
        tmp_path: Optional[Path] = None
        try:
            # requests follows the GitHub -> storage-backend redirects natively.
            with requests.get(
                url,
                stream=True,
                timeout=DOWNLOAD_TIMEOUT,
                headers={"User-Agent": "NLProxy-CLI/1.0"},
                allow_redirects=True,
                proxies=proxies,
            ) as response:
                response.raise_for_status()
                total_size = int(response.headers.get("Content-Length", 0))
                # Content-Length counts bytes on the wire. It only matches the
                # decoded bytes written below when no Content-Encoding applies.
                content_encoding = response.headers.get("Content-Encoding", "").strip().lower()
                expected_size = total_size if content_encoding in ("", "identity") else 0

                # Temp file lives next to dest_path so the final replace() is a
                # same-filesystem rename.
                with tempfile.NamedTemporaryFile(
                    delete=False, prefix=TEMP_PREFIX, dir=dest_path.parent
                ) as tmp_file:
                    tmp_path = Path(tmp_file.name)
                    chunks = response.iter_content(chunk_size=CHUNK_SIZE)
                    if _TQDM_AVAILABLE and tqdm and total_size > 0:
                        with tqdm(
                            total=total_size,
                            unit="B",
                            unit_scale=True,
                            unit_divisor=1024,
                            desc=f"Attempt {attempt}/{MAX_RETRIES}",
                        ) as bar:
                            for chunk in chunks:
                                tmp_file.write(chunk)
                                bar.update(len(chunk))
                    else:
                        for chunk in chunks:
                            tmp_file.write(chunk)

            # The temp file is closed (and flushed) here, so st_size is final.
            downloaded = tmp_path.stat().st_size
            if expected_size > 0 and downloaded < expected_size:
                raise RuntimeError(f"Download incomplete: {downloaded}/{total_size} bytes")

            tmp_path.replace(dest_path)
            logger.info(f"Download complete: {dest_path} ({downloaded / 1024 / 1024:.2f} MB)")
            return True

        except (RequestException, OSError, RuntimeError) as e:
            logger.warning(f"Download failed (attempt {attempt}/{MAX_RETRIES}): {e}")
        except Exception as e:
            logger.error(f"Unexpected error during download: {e}")
            return False
        finally:
            # After a successful replace() the temp path no longer exists. So
            # this only removes leftovers of failed or interrupted attempts.
            _discard_temp_file(tmp_path)

        if attempt < MAX_RETRIES:
            backoff = RETRY_BACKOFF_BASE ** (attempt - 1)
            logger.info(f"Retrying in {backoff:.1f}s...")
            time.sleep(backoff)

    logger.error("Max retries exceeded. Download aborted.")
    return False


def extract_and_validate(zip_path: Path, models_dir: Path, expected_dirs: List[str]) -> bool:
    """
    Extract ZIP archive and validate directory structure.

    The archive's CRCs are checked and every member path is vetted before
    anything is written. After extraction, each entry of *expected_dirs* must
    exist under *models_dir* and hold at least one file whose suffix is in
    ``MODEL_FILE_SUFFIXES``.

    Returns:
        True if extraction and validation succeed; False otherwise (reason logged).
    """
    logger.info(f"Extracting {zip_path.name} to {models_dir}...")
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            # 1. Integrity check: testzip() returns the first member failing its CRC.
            bad_file = zf.testzip()
            if bad_file is not None:
                raise zipfile.BadZipFile(f"Archive corruption detected in: {bad_file}")

            # 2. Path-traversal guard: reject absolute or ".."-containing members
            #    before extracting anything.
            for member in zf.namelist():
                member_path = Path(member)
                if member_path.is_absolute() or ".." in member_path.parts:
                    raise zipfile.BadZipFile(f"Unsafe path in archive: {member}")

            zf.extractall(models_dir)
        logger.info("Extraction successful.")

        # 3. Structural validation
        missing = [d for d in expected_dirs if not (models_dir / d).is_dir()]
        if missing:
            logger.error(f"Validation failed. Missing directories: {missing}")
            return False

        for dir_name in expected_dirs:
            if not _model_dir_looks_valid(models_dir / dir_name):
                logger.error(f"Directory appears empty or invalid: {dir_name}")
                return False

        logger.info("✅ Model structure validated successfully.")
        return True

    except zipfile.BadZipFile as e:
        logger.error(f"Invalid or corrupted ZIP file: {e}")
        return False
    except PermissionError as e:
        logger.error(f"Permission denied during extraction: {e}")
        return False
    except Exception as e:
        logger.error(f"Extraction failed: {e}")
        return False


def cleanup_temp(zip_path: Path, keep_zip: bool = False) -> None:
    """Delete the ZIP archive at *zip_path* unless *keep_zip* is True (failures only warn)."""
    if not keep_zip and zip_path.exists():
        try:
            zip_path.unlink()
            logger.debug(f"Cleaned up ZIP archive: {zip_path}")
        except OSError as e:
            logger.warning(f"Failed to clean up {zip_path}: {e}")


# =============================================================================
# CLI ENTRY POINT
# =============================================================================

def create_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser.

    ``--url`` defaults to ``DEFAULT_MODELS_URL``, which reads
    NLPROXY_MODELS_URL at import time. ``--sha256`` reads
    NLPROXY_MODELS_SHA256 when the parser is built.
    """
    parser = argparse.ArgumentParser(
        prog="nlproxy-download-models",
        description="Enterprise-grade model downloader for NLProxy SDK.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Environment Variables:
  NLPROXY_MODELS_URL    Direct URL to nlproxy_models.zip
  NLPROXY_LOG_LEVEL     DEBUG, INFO, WARNING, ERROR (default: INFO)
  NLPROXY_PROXY         HTTP/HTTPS proxy URL (optional)

Examples:
  python -m nlproxy download_models
  python -m nlproxy download_models --force --url https://example.com/models.zip
  python -m nlproxy download_models --verify-only --models-dir /opt/models
""",
    )
    parser.add_argument(
        "--models-dir",
        type=Path,
        default=DEFAULT_MODELS_DIR,
        # %(default)s keeps the help in sync with DEFAULT_MODELS_DIR.
        help="Target directory for extracted models (default: %(default)s)",
    )
    parser.add_argument(
        "--url",
        type=str,
        default=DEFAULT_MODELS_URL,
        help="Direct download URL for nlproxy_models.zip",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download and overwrite existing ZIP/extracted models",
    )
    parser.add_argument(
        "--verify-only",
        action="store_true",
        help="Only validate existing models, skip download/extraction",
    )
    parser.add_argument(
        "--keep-zip",
        action="store_true",
        help="Keep the downloaded ZIP archive after successful extraction",
    )
    parser.add_argument(
        "--sha256",
        type=str,
        default=os.getenv("NLPROXY_MODELS_SHA256", ""),
        help="Expected SHA256 checksum for the downloaded ZIP archive",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable debug-level logging",
    )
    return parser


def main(argv: Optional[List[str]] = None) -> int:
    """Run the download / verify / extract / validate workflow.

    Exit codes (raised through ``error_exit`` / ``success_exit``):
        0  models ready (setup finished, or --verify-only passed)
        1  target path is not a directory, or --verify-only found problems
        2  download failed (including an empty download URL)
        3  extraction or structural validation failed
        4  ZIP checksum verification failed
    When valid models already exist and --force is not set, returns 0 directly.
    """
    parser = create_parser()
    args = parser.parse_args(argv)

    setup_logging("DEBUG" if args.verbose else os.getenv("NLPROXY_LOG_LEVEL", "INFO"))

    models_dir = args.models_dir.resolve()

    def check_valid() -> bool:
        # Same structural rule extract_and_validate enforces after extraction.
        return all(_model_dir_looks_valid(models_dir / d) for d in EXPECTED_MODEL_DIRS)

    # --verify-only is read-only: it must not create directories. --force only
    # governs re-downloading and must not change its verdict.
    if args.verify_only:
        logger.info(f"Target directory: {models_dir}")
        if check_valid():
            success_exit("✅ All required models are present and valid.")
        error_exit("❌ Model validation failed. Run without --verify-only to download.", code=1)

    # Step 0: Ensure the target directory exists
    ensure_models_dir(models_dir)
    logger.info(f"Target directory: {models_dir}")

    # Step 1: Nothing to do if valid models are already in place
    if check_valid() and not args.force:
        logger.info("✅ Models already extracted. Use --force to re-download/extract.")
        return 0

    # Step 2: Resolve ZIP path & download if missing (or forced)
    zip_path = models_dir / ZIP_FILENAME
    if args.force or not zip_path.exists():
        if zip_path.exists():
            logger.info("--force flag set. Removing old ZIP...")
            zip_path.unlink(missing_ok=True)
        logger.info("ZIP archive not found. Starting download...")
        try:
            downloaded = download_with_retries(args.url, zip_path)
        except ValueError as exc:
            # Raised for an empty URL (e.g. --url ""); report instead of a traceback.
            logger.error(str(exc))
            downloaded = False
        if not downloaded:
            error_exit("❌ Download failed. Aborting.", code=2)
    else:
        logger.info("✅ ZIP archive already present. Skipping download.")

    # Step 3: Verify ZIP checksum if configured
    if args.sha256 or os.getenv("NLPROXY_MODELS_SHA256"):
        try:
            manager = ModelManager.get_instance(models_dir=str(models_dir))
            manager.verify_zip_checksum(zip_path, expected_hash=args.sha256 or None)
        except Exception as e:
            error_exit(f"❌ ZIP checksum verification failed: {e}", code=4)

    # Step 4: Extract & validate
    if not extract_and_validate(zip_path, models_dir, EXPECTED_MODEL_DIRS):
        error_exit("❌ Extraction or validation failed. Check logs.", code=3)

    # Step 5: Cleanup
    cleanup_temp(zip_path, keep_zip=args.keep_zip)
    success_exit("🎉 Model setup complete. NLProxy is ready for offline use.")
    return 0  # Unreachable (success_exit raises SystemExit); kept for type checkers.


if __name__ == "__main__":
    sys.exit(main())