#!/usr/bin/env python3
"""
Startup script for the Marine Species Identification API.

This script handles model downloading and API startup.
"""

import asyncio
import sys
import os
from pathlib import Path

# Add the app directory to Python path
sys.path.insert(0, str(Path(__file__).parent))

from app.core.config import settings
from app.core.logging import setup_logging, get_logger
from app.utils.model_utils import (
    download_model_from_hf,
    verify_model_file,
    setup_model_directory,
    list_available_files
)

# Setup logging
setup_logging()
logger = get_logger(__name__)

# Import string handed to uvicorn so it can locate the ASGI application.
_APP_IMPORT_STRING = "app.main:app"


async def ensure_model_available():
    """Ensure the model is downloaded and available.

    Checks ``settings.MODEL_PATH`` first; if the file does not verify, looks
    for ``<settings.MODEL_NAME>.pt`` in ``settings.HUGGINGFACE_REPO`` and
    downloads it into the model directory.

    Returns:
        True when a verified model file is present (already there or freshly
        downloaded), False otherwise. Any exception raised while listing or
        downloading is logged and reported as False rather than propagated.
    """
    logger.info("🔍 Checking model availability...")

    # Setup model directory
    model_dir = setup_model_directory()
    logger.info(f"Model directory: {model_dir}")

    # Check if model file exists
    if verify_model_file(settings.MODEL_PATH):
        logger.info("✅ Model file found and verified")
        return True

    logger.info("📥 Model not found locally, attempting to download...")

    try:
        # List available files in the repository
        logger.info(f"Checking repository: {settings.HUGGINGFACE_REPO}")
        # Normalize a falsy result (None / empty) to a list so the membership
        # test below reports "not found in repository" instead of raising.
        available_files = list_available_files(settings.HUGGINGFACE_REPO) or []

        if available_files:
            logger.info("Available files in repository:")
            for name in available_files[:10]:  # Show first 10 files
                logger.info(f" - {name}")
            if len(available_files) > 10:
                logger.info(f" ... and {len(available_files) - 10} more files")

        # Download the model
        model_filename = f"{settings.MODEL_NAME}.pt"

        if model_filename not in available_files:
            logger.error(f"❌ Model file '{model_filename}' not found in repository")
            logger.info("Available .pt files:")
            for pt_file in (f for f in available_files if f.endswith('.pt')):
                logger.info(f" - {pt_file}")
            return False

        download_model_from_hf(
            repo_id=settings.HUGGINGFACE_REPO,
            model_filename=model_filename,
            local_dir=model_dir,
            force_download=False
        )

        # Verify the downloaded model
        if verify_model_file(settings.MODEL_PATH):
            logger.info("✅ Model downloaded and verified successfully")
            return True

        logger.error("❌ Downloaded model failed verification")
        return False

    except Exception as e:
        # Top-level boundary for the download step: startup continues without
        # a model, so report the failure instead of crashing.
        logger.error(f"❌ Failed to download model: {str(e)}")
        return False


def _uvicorn_options():
    """Return the uvicorn keyword options shared by every way of launching the server."""
    return {
        "host": settings.HOST,
        "port": settings.PORT,
        "reload": False,
        "log_level": "info",
        "access_log": True,
    }


def _log_server_banner():
    """Log the host, port and docs URL the API is about to be served on."""
    logger.info("🚀 Starting Marine Species Identification API...")
    logger.info(f"Host: {settings.HOST}")
    logger.info(f"Port: {settings.PORT}")
    logger.info(f"Docs: http://{settings.HOST}:{settings.PORT}/docs")


def start_api():
    """Start the FastAPI application (blocking, synchronous entry point).

    ``uvicorn.run`` starts its own event loop, so this must only be called
    from code that is NOT already running inside an asyncio event loop.
    Coroutines should use ``_serve_api`` instead.
    """
    import uvicorn

    _log_server_banner()
    uvicorn.run(_APP_IMPORT_STRING, **_uvicorn_options())


async def _serve_api():
    """Serve the FastAPI application inside the already-running event loop.

    Raises:
        RuntimeError: if the server returned without ever reaching the
            started state, so the caller can exit with a failure status.
    """
    import uvicorn

    _log_server_banner()
    server = uvicorn.Server(uvicorn.Config(_APP_IMPORT_STRING, **_uvicorn_options()))
    await server.serve()

    if not server.started:
        raise RuntimeError("API server exited before completing startup")


async def main():
    """Main startup function.

    Makes a best-effort attempt to have the model available, then serves the
    API until shutdown. A missing model is only a warning: the API is started
    regardless.
    """
    logger.info("🐟 Marine Species Identification API Startup")
    logger.info("=" * 50)

    # Check model availability
    model_available = await ensure_model_available()

    if not model_available:
        logger.warning("⚠️ Model not available - API will start but inference may fail")
        logger.info("The API will still start and you can check /health for status")

    logger.info("=" * 50)

    # Start the API. This coroutine runs under asyncio.run(), so the blocking
    # start_api()/uvicorn.run() path cannot be used here: it would try to start
    # a second event loop and fail with RuntimeError.
    await _serve_api()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("🛑 API startup interrupted by user")
    except Exception as e:
        logger.error(f"❌ Failed to start API: {str(e)}")
        sys.exit(1)