ImageCaptionner / scripts /download_model.py
AOUNZakaria's picture
Update scripts/download_model.py
c8cbd99 verified
"""
Download EfficientNet model from cloud storage if not present.
This script runs at application startup to download the model if needed.
"""
import os
import sys
import logging
from pathlib import Path
import requests
# Module-level logger named after this module. Level and format are not set
# here; the __main__ guard at the bottom of the file calls logging.basicConfig
# when the file is run as a script.
logger = logging.getLogger(__name__)
def _fetch_from_hf_hub(hf_repo, filename, models_dir, model_path):
    """
    Try to obtain the model file from a Hugging Face Hub repository.

    Args:
        hf_repo: Repository id passed to ``hf_hub_download`` as ``repo_id``.
        filename: Name of the file to fetch from the repository.
        models_dir: Directory used as both cache and local download directory.
        model_path: Final location where the model file must end up.

    Returns:
        bool: True if the file is available at ``model_path`` afterwards,
        False if ``huggingface_hub`` is missing or the download/copy failed.
        Failures are logged as warnings and never raised.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        logger.warning("huggingface_hub not installed. Install with: pip install huggingface_hub")
        return False

    try:
        import shutil

        logger.info(f"Downloading model from Hugging Face Hub: {hf_repo}")
        downloaded_path = Path(hf_hub_download(
            repo_id=hf_repo,
            filename=filename,
            cache_dir=str(models_dir),
            local_dir=str(models_dir),
            local_dir_use_symlinks=False
        ))

        # The Hub client may place the file in a nested structure. Compare by
        # file identity (not by path spelling) so that two spellings of the
        # same file do not trigger a copy that fails with SameFileError.
        already_in_place = model_path.exists() and downloaded_path.samefile(model_path)
        if not already_in_place:
            logger.info(f"Copying model from {downloaded_path} to {model_path}")
            # Copy to a sibling temp file and rename, so an interrupted copy
            # never leaves a truncated file at the final location.
            partial_path = model_path.with_name(model_path.name + ".part")
            try:
                shutil.copy2(downloaded_path, partial_path)
                os.replace(partial_path, model_path)
            finally:
                if partial_path.exists():
                    partial_path.unlink()
            logger.info(f"Model copied to expected location: {model_path}")

        size_mb = model_path.stat().st_size / (1024 * 1024)
        logger.info(f"Model downloaded from HF Hub successfully ({size_mb:.1f}MB) at {model_path}")
        return True
    except Exception as e:
        # Best effort: any Hub failure falls through to the direct URL method.
        logger.warning(f"Failed to download from HF Hub: {e}. Trying direct URL...")
        return False


def _fetch_from_url(model_url, model_path):
    """
    Stream the model file from ``model_url`` into ``model_path``.

    The data is written to a ``.part`` sibling file first and only renamed to
    ``model_path`` once the transfer has finished, so a crash or kill during
    the download cannot leave a truncated file that a later run would mistake
    for a complete model.

    Args:
        model_url: HTTP(S) URL of the model file.
        model_path: Final location of the model file.

    Returns:
        bool: True on success, False on any failure (logged, never raised).
    """
    progress_step = 10 * 1024 * 1024  # Log every 10MB
    partial_path = model_path.with_name(model_path.name + ".part")

    try:
        logger.info(f"Downloading EfficientNet model from {model_url}...")
        logger.info("This may take a few minutes (model is ~245MB)...")

        # The context manager releases the connection even if streaming fails.
        with requests.get(model_url, stream=True, timeout=300) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            # With a Content-Encoding, iter_content yields decoded bytes, so
            # their count is not comparable to Content-Length.
            is_encoded = bool(response.headers.get('content-encoding'))

            downloaded = 0
            next_report = progress_step
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    # Threshold-based reporting works for any chunk size; an
                    # exact-multiple test is skipped when a short chunk arrives.
                    if total_size > 0 and downloaded >= next_report:
                        percent = (downloaded / total_size) * 100
                        logger.info(f"Downloaded {downloaded / (1024 * 1024):.1f}MB / {total_size / (1024 * 1024):.1f}MB ({percent:.1f}%)")
                        while next_report <= downloaded:
                            next_report += progress_step

        if total_size > 0 and not is_encoded and downloaded != total_size:
            raise IOError(f"Incomplete download: received {downloaded} of {total_size} bytes")

        # Publish the finished file in a single rename.
        os.replace(partial_path, model_path)

        size_mb = model_path.stat().st_size / (1024 * 1024)
        logger.info(f"EfficientNet model downloaded successfully ({size_mb:.1f}MB)")
        return True
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download model: {e}")
    except Exception as e:
        logger.error(f"Error downloading model: {e}", exc_info=True)

    # Only reached on failure: clean up partial download.
    if partial_path.exists():
        partial_path.unlink()
    return False


def download_efficientnet_model():
    """
    Download EfficientNet optimized model if it doesn't exist.

    The model is stored as
    ``models/optimized_models/efficientnet_efficient_best_model_quantized.pth``
    under the parent of this script's directory; the directory is created if
    it is missing.

    Supports two methods, tried in this order:
    1. Hugging Face Hub (set HF_MODEL_REPO environment variable)
    2. Direct URL download (set EFFICIENTNET_MODEL_URL environment variable),
       also used as a fallback when the Hub download fails.

    Returns:
        bool: True if the model file already existed or was downloaded
        successfully; False if neither environment variable yields a model
        (none set, or every configured download failed).
    """
    model_filename = "efficientnet_efficient_best_model_quantized.pth"

    # Get base directory
    base_dir = Path(__file__).resolve().parent.parent
    models_dir = base_dir / "models" / "optimized_models"
    models_dir.mkdir(parents=True, exist_ok=True)
    model_path = models_dir / model_filename

    # Check if model already exists
    if model_path.exists():
        size_mb = model_path.stat().st_size / (1024 * 1024)
        logger.info(f"EfficientNet model already exists ({size_mb:.1f}MB)")
        return True

    # Try Hugging Face Hub first
    hf_repo = os.environ.get("HF_MODEL_REPO")
    if hf_repo and _fetch_from_hf_hub(hf_repo, model_filename, models_dir, model_path):
        return True

    # Fallback to direct URL download
    model_url = os.environ.get("EFFICIENTNET_MODEL_URL")
    if not model_url:
        logger.warning("Neither HF_MODEL_REPO nor EFFICIENTNET_MODEL_URL is set.")
        logger.warning("Model will not be downloaded. Set one of these environment variables.")
        return False

    return _fetch_from_url(model_url, model_path)
if __name__ == "__main__":
    # Script entry point: emit INFO-and-above records with a timestamp,
    # logger name and level prefix, then run the download once. The boolean
    # result of the download is not used here.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    download_efficientnet_model()