Spaces:
Running
Running
"""
Download the EfficientNet model at application startup if it is not already present.

The model is fetched from the Hugging Face Hub (when HF_MODEL_REPO is set) or,
as a fallback, from a direct URL (when EFFICIENTNET_MODEL_URL is set).
"""
import logging
import os
import shutil
import sys
from pathlib import Path

import requests
logger = logging.getLogger(__name__)

# File name of the quantized EfficientNet checkpoint, both in the HF repo and on disk.
_MODEL_FILENAME = "efficientnet_efficient_best_model_quantized.pth"
# During a direct-URL download, emit a progress line roughly every this many bytes.
_PROGRESS_LOG_INTERVAL = 10 * 1024 * 1024


def _size_mb(path):
    """Return the size of the file at *path* in MiB."""
    return path.stat().st_size / (1024 * 1024)


def _remove_if_exists(path):
    """Delete *path*, ignoring the case where it is already gone."""
    try:
        path.unlink()
    except FileNotFoundError:
        pass


def _partial_path(model_path):
    """Return the staging path a download is written to before being moved into place."""
    return model_path.with_name(model_path.name + ".part")


def _download_from_hf_hub(hf_repo, models_dir, model_path):
    """Fetch the model from the Hugging Face Hub repository *hf_repo*.

    If ``hf_hub_download`` returns a path other than *model_path*, the file is
    copied to a staging file next to *model_path* and then renamed over it, so
    *model_path* is never left half-written.

    Returns:
        True on success; False if huggingface_hub is not installed or the
        download/copy fails (the caller then falls back to the direct URL).
    """
    partial = _partial_path(model_path)
    try:
        from huggingface_hub import hf_hub_download

        logger.info(f"Downloading model from Hugging Face Hub: {hf_repo}")
        downloaded_path = Path(
            hf_hub_download(
                repo_id=hf_repo,
                filename=_MODEL_FILENAME,
                cache_dir=str(models_dir),
                local_dir=str(models_dir),
                local_dir_use_symlinks=False,
            )
        )
        # Copy to expected location if needed (HF Hub may create a nested structure).
        # Compare resolved paths so a lexically different spelling of the same
        # file does not trigger a pointless self-copy.
        if downloaded_path.resolve() != model_path.resolve():
            logger.info(f"Copying model from {downloaded_path} to {model_path}")
            shutil.copy2(downloaded_path, partial)
            os.replace(partial, model_path)
            logger.info(f"Model copied to expected location: {model_path}")
        logger.info(
            f"Model downloaded from HF Hub successfully ({_size_mb(model_path):.1f}MB) at {model_path}"
        )
        return True
    except ImportError:
        logger.warning("huggingface_hub not installed. Install with: pip install huggingface_hub")
        return False
    except Exception as e:
        logger.warning(f"Failed to download from HF Hub: {e}. Trying direct URL...")
        _remove_if_exists(partial)
        return False


def _download_from_url(model_url, model_path):
    """Stream the model from *model_url* into *model_path*.

    Bytes are written to a ``.part`` staging file in the same directory and
    only renamed onto *model_path* once the body has been fully received, so an
    interrupted download never leaves a truncated model behind. When the server
    sends ``Content-Length`` for an un-encoded body, the received byte count must
    match it.

    Returns:
        True on success, False on any failure (the staging file is removed).
    """
    partial = _partial_path(model_path)
    try:
        logger.info(f"Downloading EfficientNet model from {model_url}...")
        logger.info("This may take a few minutes (model is ~245MB)...")
        # The context manager releases the connection even if writing fails.
        with requests.get(model_url, stream=True, timeout=300) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            next_log_at = _PROGRESS_LOG_INTERVAL
            with open(partial, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    # Threshold-based rather than `downloaded % interval == 0`,
                    # which only fires if every chunk has exactly 8192 bytes.
                    if total_size > 0 and downloaded >= next_log_at:
                        percent = (downloaded / total_size) * 100
                        logger.info(f"Downloaded {downloaded / (1024 * 1024):.1f}MB / {total_size / (1024 * 1024):.1f}MB ({percent:.1f}%)")
                        next_log_at = (downloaded // _PROGRESS_LOG_INTERVAL + 1) * _PROGRESS_LOG_INTERVAL
            # Content-Length counts encoded bytes, so it is only comparable with
            # what iter_content yields when the body is not content-encoded.
            encoding = response.headers.get('content-encoding', 'identity').lower()
            if total_size > 0 and encoding == 'identity' and downloaded != total_size:
                raise OSError(
                    f"Incomplete download: received {downloaded} of {total_size} bytes"
                )
        # Same directory as model_path, so the rename stays on one filesystem.
        os.replace(partial, model_path)
        logger.info(f"EfficientNet model downloaded successfully ({_size_mb(model_path):.1f}MB)")
        return True
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download model: {e}")
        # Clean up partial download
        _remove_if_exists(partial)
        return False
    except Exception as e:
        logger.error(f"Error downloading model: {e}", exc_info=True)
        # Clean up partial download
        _remove_if_exists(partial)
        return False


def download_efficientnet_model():
    """
    Download EfficientNet optimized model if it doesn't exist.

    Supports two methods, tried in this order:
    1. Hugging Face Hub (set HF_MODEL_REPO environment variable)
    2. Direct URL download (set EFFICIENTNET_MODEL_URL environment variable)

    The model lives at ``models/optimized_models/<model file>`` under the parent
    of this file's directory (the directory is created if missing). A non-empty
    existing file is reused as-is. An empty file is deleted and re-downloaded,
    because it cannot be a usable checkpoint.

    Returns:
        True if the model is available (already present or downloaded
        successfully), False otherwise.
    """
    base_dir = Path(__file__).resolve().parent.parent
    models_dir = base_dir / "models" / "optimized_models"
    models_dir.mkdir(parents=True, exist_ok=True)
    model_path = models_dir / _MODEL_FILENAME

    # Check if model already exists
    if model_path.exists():
        if model_path.stat().st_size > 0:
            logger.info(f"EfficientNet model already exists ({_size_mb(model_path):.1f}MB)")
            return True
        # e.g. debris from an interrupted run of an older, non-atomic downloader.
        logger.warning(f"Removing empty model file at {model_path}; it will be re-downloaded")
        _remove_if_exists(model_path)

    # Try Hugging Face Hub first
    hf_repo = os.environ.get("HF_MODEL_REPO")
    if hf_repo and _download_from_hf_hub(hf_repo, models_dir, model_path):
        return True

    # Fallback to direct URL download
    model_url = os.environ.get("EFFICIENTNET_MODEL_URL")
    if not model_url:
        logger.warning("Neither HF_MODEL_REPO nor EFFICIENTNET_MODEL_URL is set.")
        logger.warning("Model will not be downloaded. Set one of these environment variables.")
        return False
    return _download_from_url(model_url, model_path)
if __name__ == "__main__":
    # Standalone invocation: show timestamped INFO-level progress, then fetch the model.
    # NOTE(review): the boolean result is discarded, so the process exit status is
    # the same whether or not the download succeeded (`sys` is imported but unused)
    # — confirm whether startup scripts should see a non-zero status on failure.
    _log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=_log_format, level=logging.INFO)
    download_efficientnet_model()