Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Enterprise-grade model downloader and extractor for NLProxy. | |
| This CLI utility ensures that all required ONNX models are available locally | |
| for offline inference. It follows a strict, idempotent, and atomic workflow: | |
| 1. Verifies/creates the target `models/` directory. | |
| 2. Checks if `nlproxy_models.zip` exists inside it. | |
| 3. If missing, downloads the ZIP from a configurable URL with exponential backoff. | |
| 4. Extracts the archive safely (prevents path traversal, validates integrity). | |
| 5. Validates that all 3 expected model directories are present and non-empty. | |
| 6. Cleans up temporary files and reports status. | |
| Usage | |
| ----- | |
| # Basic usage (uses NLPROXY_MODELS_URL env var or --url flag) | |
| $ python -m nlproxy download_models | |
| Configuration | |
| ------------- | |
| Environment variables: | |
| NLPROXY_MODELS_URL Direct download URL for nlproxy_models.zip | |
| NLPROXY_MODELS_SHA256 Expected SHA256 checksum of nlproxy_models.zip (optional) | |
| NLPROXY_LOG_LEVEL Logging level: DEBUG, INFO, WARNING, ERROR | |
| NLPROXY_PROXY HTTP/HTTPS proxy URL (optional) | |
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import os | |
| import shutil | |
| import sys | |
| import tempfile | |
| import time | |
| import zipfile | |
| from pathlib import Path | |
| from typing import List, Optional | |
| import requests | |
| from requests.exceptions import RequestException | |
| from nlproxy.core.model_manager import ModelManager | |
| try: | |
| from tqdm import tqdm | |
| _TQDM_AVAILABLE = True | |
| except ImportError: | |
| _TQDM_AVAILABLE = False | |
| tqdm = None # type: ignore | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================= | |
| # CONFIGURATION CONSTANTS | |
| # ============================================================================= | |
# Default target directory for the models (relative path; main() resolves it).
DEFAULT_MODELS_DIR: Path = Path("nlproxy", "models")

# File name of the model archive inside the models directory.
ZIP_FILENAME: str = "nlproxy_models.zip"

# Default download URL: the NLPROXY_MODELS_URL environment variable when set,
# otherwise the release asset below. The --url CLI flag overrides both.
DEFAULT_MODELS_URL: str = os.getenv(
    "NLPROXY_MODELS_URL",
    "https://github.com/intellideep/nlproxy/releases/download/free_models/nlproxy_models.zip",
)

# Subdirectories that must exist after extraction; used for validation.
EXPECTED_MODEL_DIRS: List[str] = [
    "all-MiniLM-L6-v2",
    "nli-distilroberta-base",
    "distilgpt2",
]

# Network & retry configuration.
DOWNLOAD_TIMEOUT: int = 10 * 60  # seconds, handed to requests as ``timeout``
MAX_RETRIES: int = 3  # total download attempts
RETRY_BACKOFF_BASE: float = 2.0  # wait is RETRY_BACKOFF_BASE ** (attempt - 1) seconds
CHUNK_SIZE: int = 8192  # bytes assumed per chunk by the progress fallback
TEMP_PREFIX: str = "nlproxy_dl_"  # prefix of the temporary download file
| # ============================================================================= | |
| # UTILITIES & LOGGING | |
| # ============================================================================= | |
def setup_logging(level: str = "INFO") -> None:
    """Configure console logging on the root logger.

    Args:
        level: Case-insensitive level name such as ``"DEBUG"`` or
            ``"warning"``. Anything that does not name a numeric logging
            level falls back to ``INFO``.
    """
    candidate = getattr(logging, level.strip().upper(), None)
    # The logging module also exposes upper-case attributes that are not
    # levels (e.g. BASIC_FORMAT, a string). Passing one of those on would
    # make basicConfig() raise, so only integers are accepted as levels.
    numeric_level = candidate if isinstance(candidate, int) else logging.INFO
    logging.basicConfig(
        level=numeric_level,
        format="%(asctime)s [%(levelname)-8s] %(message)s",
        datefmt="%H:%M:%S",
        stream=sys.stdout,
    )
def log_progress(description: str, total: int, iterable):
    """Yield every item of *iterable* while reporting progress.

    Progress is shown with a tqdm bar when tqdm is available; otherwise a
    log line is written each time progress advances by at least 10%.

    Args:
        description: Label shown on the bar / in the log lines.
        total: Expected total size. The bar uses byte units, so this is
            treated as a byte count.
        iterable: Source of items. Presumably byte chunks: the length of each
            item is counted as its size, and items without a length are
            assumed to be ``CHUNK_SIZE`` bytes.

    Yields:
        The items of *iterable*, unchanged and in order.
    """

    def _size_of(item) -> int:
        # Prefer the real chunk length; fall back to the nominal chunk size.
        try:
            return len(item)
        except TypeError:
            return CHUNK_SIZE

    if _TQDM_AVAILABLE and tqdm:
        # This function is a generator, so ``return tqdm(...)`` would end the
        # iteration immediately and the caller would receive no items at all.
        # The bar is therefore driven manually while items are yielded.
        with tqdm(
            total=total,
            desc=description,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for item in iterable:
                yield item
                bar.update(_size_of(item))
        return

    # Fallback: periodic logging every 10%.
    done = 0
    last_log = 0.0
    for item in iterable:
        yield item
        done += _size_of(item)
        # Clamp so a short final chunk or an inexact total never shows >100%.
        progress = min(done / total * 100, 100.0) if total > 0 else 0.0
        if progress - last_log >= 10.0:
            logger.info(f" {description}: {progress:.1f}%")
            last_log = progress
def error_exit(message: str, code: int = 1) -> None:
    """Log *message* at ERROR level and terminate with exit status *code*."""
    logger.error(message)
    # Same effect as sys.exit(code): both raise SystemExit.
    raise SystemExit(code)
def success_exit(message: str) -> None:
    """Log *message* at INFO level and terminate with exit status 0."""
    logger.info(message)
    # Same effect as sys.exit(0): both raise SystemExit.
    raise SystemExit(0)
| # ============================================================================= | |
| # CORE WORKFLOW FUNCTIONS | |
| # ============================================================================= | |
def ensure_models_dir(models_dir: Path) -> None:
    """Make sure *models_dir* is a usable directory.

    Creates the directory (including parents) when the path does not exist.
    When the path exists but is not a directory, the error is reported via
    ``error_exit``.
    """
    if models_dir.exists():
        if not models_dir.is_dir():
            error_exit(f"Path exists but is not a directory: {models_dir}")
        return

    logger.info(f"Creating models directory: {models_dir}")
    models_dir.mkdir(parents=True, exist_ok=True)
def _remove_partial(path: Optional[Path]) -> None:
    """Best-effort removal of a partially written temporary download."""
    if path is None:
        return
    try:
        path.unlink(missing_ok=True)
    except OSError as e:
        logger.debug(f"Could not remove temporary file {path}: {e}")


def download_with_retries(url: str, dest_path: Path) -> bool:
    """
    Download a file from *url* with exponential-backoff retries.

    The payload is streamed into a temporary file created next to
    *dest_path* and moved into place only after the transfer finished, so
    *dest_path* never holds a partial download.

    Args:
        url: Source URL. Redirects are followed.
        dest_path: Final location of the downloaded file.

    Returns:
        True when the file was downloaded and moved to *dest_path*; False
        when all ``MAX_RETRIES`` attempts failed or an unexpected error
        occurred.

    Raises:
        ValueError: If *url* is empty.
    """
    if not url:
        raise ValueError(
            "Download URL is not set. Provide via --url flag or NLPROXY_MODELS_URL env var."
        )

    # Honour the documented NLPROXY_PROXY variable. ``None`` keeps requests'
    # default proxy handling.
    proxy_url = os.getenv("NLPROXY_PROXY")
    proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None

    logger.info(f"Starting download: {url}")
    for attempt in range(1, MAX_RETRIES + 1):
        tmp_path: Optional[Path] = None
        try:
            with requests.get(
                url,
                stream=True,
                timeout=DOWNLOAD_TIMEOUT,
                headers={"User-Agent": "NLProxy-CLI/1.0"},
                allow_redirects=True,
                proxies=proxies,
            ) as response:
                response.raise_for_status()
                total_size = int(response.headers.get("Content-Length", 0))

                # Write to a temp file first for atomicity.
                with tempfile.NamedTemporaryFile(
                    delete=False, prefix=TEMP_PREFIX, dir=dest_path.parent
                ) as tmp_file:
                    tmp_path = Path(tmp_file.name)

                    bar = None
                    if _TQDM_AVAILABLE and tqdm and total_size > 0:
                        bar = tqdm(
                            total=total_size,
                            unit="B",
                            unit_scale=True,
                            unit_divisor=1024,
                            desc=f"Attempt {attempt}/{MAX_RETRIES}",
                        )
                    try:
                        # iter_content() converts low-level transport errors
                        # into RequestException subclasses, so a connection
                        # that drops mid-transfer is retried below. Reading
                        # response.raw directly would surface urllib3
                        # exceptions that bypass the retry branch.
                        for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
                            if not chunk:
                                continue
                            tmp_file.write(chunk)
                            if bar is not None:
                                bar.update(len(chunk))
                    finally:
                        if bar is not None:
                            bar.close()

            # The temp file is closed (and flushed) here, so its size on disk
            # is final.
            downloaded = tmp_path.stat().st_size
            if total_size > 0 and downloaded < total_size:
                raise RuntimeError(f"Download incomplete: {downloaded}/{total_size} bytes")

            # Atomic move to final destination.
            tmp_path.replace(dest_path)
            logger.info(f"Download complete: {dest_path} ({downloaded / 1024 / 1024:.2f} MB)")
            return True
        except (RequestException, OSError, RuntimeError) as e:
            logger.warning(f"Download failed (attempt {attempt}/{MAX_RETRIES}): {e}")
            _remove_partial(tmp_path)
            if attempt < MAX_RETRIES:
                backoff = RETRY_BACKOFF_BASE ** (attempt - 1)
                logger.info(f"Retrying in {backoff:.1f}s...")
                time.sleep(backoff)
            else:
                logger.error("Max retries exceeded. Download aborted.")
        except Exception as e:
            # Anything else (e.g. a malformed Content-Length header) is not
            # expected to succeed on a retry.
            logger.error(f"Unexpected error during download: {e}")
            _remove_partial(tmp_path)
            return False
    return False
def extract_and_validate(zip_path: Path, models_dir: Path, expected_dirs: List[str]) -> bool:
    """
    Unpack *zip_path* into *models_dir* and check the resulting layout.

    The archive is integrity-tested and every member name is screened for
    absolute paths and ``..`` components before anything is extracted.
    Afterwards each name in *expected_dirs* must be a directory under
    *models_dir* containing at least one file with a recognised suffix.

    Returns:
        True when extraction and validation both succeed, False otherwise.
        Failures are logged, not raised.
    """
    recognised_suffixes = {".onnx", ".bin", ".json", ".txt", ".py"}

    logger.info(f"Extracting {zip_path.name} to {models_dir}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as archive:
            # 1. Integrity check: testzip() returns the first bad member, if any.
            corrupt_member = archive.testzip()
            if corrupt_member is not None:
                raise zipfile.BadZipFile(f"Archive corruption detected in: {corrupt_member}")

            # 2. Reject member names that could escape the target directory.
            for entry in archive.namelist():
                parsed = Path(entry)
                if ".." in parsed.parts or parsed.is_absolute():
                    raise zipfile.BadZipFile(f"Unsafe path in archive: {entry}")

            archive.extractall(models_dir)
        logger.info("Extraction successful.")

        # 3. Every expected directory has to exist ...
        absent = [name for name in expected_dirs if not models_dir.joinpath(name).is_dir()]
        if absent:
            logger.error(f"Validation failed. Missing directories: {absent}")
            return False

        # ... and hold at least one file with a recognised suffix.
        for name in expected_dirs:
            children = models_dir.joinpath(name).iterdir()
            if all(child.suffix not in recognised_suffixes for child in children):
                logger.error(f"Directory appears empty or invalid: {name}")
                return False

        logger.info("β Model structure validated successfully.")
        return True
    except zipfile.BadZipFile as exc:
        logger.error(f"Invalid or corrupted ZIP file: {exc}")
        return False
    except PermissionError as exc:
        logger.error(f"Permission denied during extraction: {exc}")
        return False
    except Exception as exc:
        logger.error(f"Extraction failed: {exc}")
        return False
def cleanup_temp(zip_path: Path, keep_zip: bool = False) -> None:
    """Delete the ZIP archive at *zip_path* unless *keep_zip* is set.

    A failed deletion is logged as a warning and otherwise ignored.
    """
    if keep_zip or not zip_path.exists():
        return

    try:
        zip_path.unlink()
    except OSError as exc:
        logger.warning(f"Failed to clean up {zip_path}: {exc}")
    else:
        logger.debug(f"Cleaned up ZIP archive: {zip_path}")
| # ============================================================================= | |
| # CLI ENTRY POINT | |
| # ============================================================================= | |
def create_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser.

    Returns:
        A parser exposing ``--models-dir``, ``--url``, ``--force``,
        ``--verify-only``, ``--keep-zip``, ``--sha256`` and ``-v/--verbose``.
    """
    # RawDescriptionHelpFormatter prints the epilog verbatim, so the layout
    # is spelled out explicitly.
    epilog = (
        "Environment Variables:\n"
        "  NLPROXY_MODELS_URL     Direct URL to nlproxy_models.zip\n"
        "  NLPROXY_MODELS_SHA256  Expected SHA256 checksum of the ZIP (optional)\n"
        "  NLPROXY_LOG_LEVEL      DEBUG, INFO, WARNING, ERROR (default: INFO)\n"
        "  NLPROXY_PROXY          HTTP/HTTPS proxy URL (optional)\n"
        "\n"
        "Examples:\n"
        "  python -m nlproxy download_models\n"
        "  python -m nlproxy download_models --force --url https://example.com/models.zip\n"
        "  python -m nlproxy download_models --verify-only --models-dir /opt/models\n"
    )
    parser = argparse.ArgumentParser(
        prog="nlproxy-download-models",
        description="Enterprise-grade model downloader for NLProxy SDK.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    parser.add_argument(
        "--models-dir", type=Path, default=DEFAULT_MODELS_DIR,
        # %(default)s keeps the help text in sync with the actual default.
        help="Target directory for extracted models (default: %(default)s)"
    )
    parser.add_argument(
        "--url", type=str, default=DEFAULT_MODELS_URL,
        help="Direct download URL for nlproxy_models.zip"
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Force re-download and overwrite existing ZIP/extracted models"
    )
    parser.add_argument(
        "--verify-only", action="store_true",
        help="Only validate existing models, skip download/extraction"
    )
    parser.add_argument(
        "--keep-zip", action="store_true",
        help="Keep the downloaded ZIP archive after successful extraction"
    )
    parser.add_argument(
        "--sha256", type=str, default=os.getenv("NLPROXY_MODELS_SHA256", ""),
        help="Expected SHA256 checksum for the downloaded ZIP archive"
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true",
        help="Enable debug-level logging"
    )
    return parser
def main(argv: Optional[List[str]] = None) -> int:
    """Run the model download / extraction workflow.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read ``sys.argv``.

    Returns:
        0 when the models are already in place. Every other outcome ends the
        process through ``success_exit`` / ``error_exit`` with these exit
        codes: 0 success, 1 verification failed, 2 download failed,
        3 extraction or validation failed, 4 checksum verification failed.
    """
    parser = create_parser()
    args = parser.parse_args(argv)
    setup_logging("DEBUG" if args.verbose else os.getenv("NLPROXY_LOG_LEVEL", "INFO"))

    models_dir = args.models_dir.resolve()

    def check_valid() -> bool:
        """Return True when every expected model directory exists and is non-empty."""
        for name in EXPECTED_MODEL_DIRS:
            model_path = models_dir / name
            try:
                if not model_path.is_dir() or not any(model_path.iterdir()):
                    return False
            except OSError:
                # An unreadable directory cannot be considered valid.
                return False
        return True

    # Verify-only mode is read-only: it neither creates the target directory
    # nor depends on --force.
    if args.verify_only:
        logger.info(f"Target directory: {models_dir}")
        if check_valid():
            success_exit("β All required models are present and valid.")
            return 0
        error_exit("β Model validation failed. Run without --verify-only to download.", code=1)
        return 1

    # Step 0: Ensure the target directory exists.
    ensure_models_dir(models_dir)
    logger.info(f"Target directory: {models_dir}")

    # Step 1: Nothing to do when the models are already in place.
    if check_valid() and not args.force:
        logger.info("β Models already extracted. Use --force to re-download/extract.")
        return 0

    # Step 2: Resolve ZIP path & download if missing (or forced).
    zip_path = models_dir / ZIP_FILENAME
    if args.force and zip_path.exists():
        logger.info("--force flag set. Removing old ZIP...")
        zip_path.unlink(missing_ok=True)

    if zip_path.exists():
        logger.info("β ZIP archive already present. Skipping download.")
    else:
        logger.info("ZIP archive not found. Starting download...")
        if not args.url:
            # download_with_retries() raises ValueError for an empty URL;
            # report it as a regular CLI error instead of a traceback.
            error_exit(
                "Download URL is not set. Provide via --url flag or NLPROXY_MODELS_URL env var.",
                code=2,
            )
            return 2
        if not download_with_retries(args.url, zip_path):
            error_exit("β Download failed. Aborting.", code=2)
            return 2

    # Step 3: Verify ZIP checksum if configured.
    if args.sha256 or os.getenv("NLPROXY_MODELS_SHA256"):
        try:
            manager = ModelManager.get_instance(models_dir=str(models_dir))
            verified = manager.verify_zip_checksum(zip_path, expected_hash=args.sha256 or None)
        except Exception as e:
            error_exit(f"β ZIP checksum verification failed: {e}", code=4)
            return 4
        # NOTE(review): verify_zip_checksum is assumed to raise on a mismatch.
        # An explicit False result is treated as a failure as well, in case it
        # reports mismatches through its return value - confirm against
        # ModelManager.
        if verified is False:
            error_exit("β ZIP checksum verification failed: checksum mismatch", code=4)
            return 4

    # Step 4: Extract & Validate
    if not extract_and_validate(zip_path, models_dir, EXPECTED_MODEL_DIRS):
        error_exit("β Extraction or validation failed. Check logs.", code=3)
        return 3

    # Step 5: Cleanup
    cleanup_temp(zip_path, keep_zip=args.keep_zip)
    success_exit("π Model setup complete. NLProxy is ready for offline use.")
    return 0
# Script entry point: run the CLI and use main()'s result as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())