""" Download EfficientNet model from cloud storage if not present. This script runs at application startup to download the model if needed. """ import os import sys import logging from pathlib import Path import requests logger = logging.getLogger(__name__) def download_efficientnet_model(): """ Download EfficientNet optimized model if it doesn't exist. Supports two methods: 1. Hugging Face Hub (set HF_MODEL_REPO environment variable) 2. Direct URL download (set EFFICIENTNET_MODEL_URL environment variable) """ # Get base directory base_dir = Path(__file__).resolve().parent.parent models_dir = base_dir / "models" / "optimized_models" models_dir.mkdir(parents=True, exist_ok=True) model_path = models_dir / "efficientnet_efficient_best_model_quantized.pth" # Check if model already exists if model_path.exists(): size_mb = model_path.stat().st_size / (1024 * 1024) logger.info(f"EfficientNet model already exists ({size_mb:.1f}MB)") return True # Try Hugging Face Hub first hf_repo = os.environ.get("HF_MODEL_REPO") if hf_repo: try: from huggingface_hub import hf_hub_download logger.info(f"Downloading model from Hugging Face Hub: {hf_repo}") downloaded_path = hf_hub_download( repo_id=hf_repo, filename="efficientnet_efficient_best_model_quantized.pth", cache_dir=str(models_dir), local_dir=str(models_dir), local_dir_use_symlinks=False ) # Copy to expected location if needed (HF Hub creates nested structure) downloaded_path = Path(downloaded_path) if downloaded_path != model_path: import shutil logger.info(f"Copying model from {downloaded_path} to {model_path}") shutil.copy2(downloaded_path, model_path) logger.info(f"Model copied to expected location: {model_path}") size_mb = model_path.stat().st_size / (1024 * 1024) logger.info(f"Model downloaded from HF Hub successfully ({size_mb:.1f}MB) at {model_path}") return True except ImportError: logger.warning("huggingface_hub not installed. Install with: pip install huggingface_hub") except Exception as e: logger.warning(f"Failed to download from HF Hub: {e}. Trying direct URL...") # Fallback to direct URL download model_url = os.environ.get("EFFICIENTNET_MODEL_URL") if not model_url: logger.warning("Neither HF_MODEL_REPO nor EFFICIENTNET_MODEL_URL is set.") logger.warning("Model will not be downloaded. Set one of these environment variables.") return False try: logger.info(f"Downloading EfficientNet model from {model_url}...") logger.info("This may take a few minutes (model is ~245MB)...") # Download with progress response = requests.get(model_url, stream=True, timeout=300) response.raise_for_status() total_size = int(response.headers.get('content-length', 0)) downloaded = 0 with open(model_path, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): if chunk: f.write(chunk) downloaded += len(chunk) if total_size > 0: percent = (downloaded / total_size) * 100 if downloaded % (10 * 1024 * 1024) == 0: # Log every 10MB logger.info(f"Downloaded {downloaded / (1024 * 1024):.1f}MB / {total_size / (1024 * 1024):.1f}MB ({percent:.1f}%)") size_mb = model_path.stat().st_size / (1024 * 1024) logger.info(f"EfficientNet model downloaded successfully ({size_mb:.1f}MB)") return True except requests.exceptions.RequestException as e: logger.error(f"Failed to download model: {e}") # Clean up partial download if model_path.exists(): model_path.unlink() return False except Exception as e: logger.error(f"Error downloading model: {e}", exc_info=True) # Clean up partial download if model_path.exists(): model_path.unlink() return False if __name__ == "__main__": # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) download_efficientnet_model()