# fishapi / start_api.py
# kamau1's picture
# Initial commit
# bcc2f7b verified
#!/usr/bin/env python3
"""
Startup script for the Marine Species Identification API.
This script handles model downloading and API startup.
"""
import asyncio
import sys
import os
from pathlib import Path
# Add the app directory to Python path
sys.path.insert(0, str(Path(__file__).parent))
from app.core.config import settings
from app.core.logging import setup_logging, get_logger
from app.utils.model_utils import (
download_model_from_hf,
verify_model_file,
setup_model_directory,
list_available_files
)
# Initialise logging once at import time, then create the module-level
# logger used by every function in this script.
setup_logging()
logger = get_logger(__name__)
async def ensure_model_available():
    """Ensure the model file exists locally, downloading it if necessary.

    The file at ``settings.MODEL_PATH`` is verified first. When that fails,
    the repository ``settings.HUGGINGFACE_REPO`` is listed and, if it
    contains ``<settings.MODEL_NAME>.pt``, that file is downloaded into the
    model directory and verified again.

    Returns:
        bool: True when a verified model file is available, False when the
        file is missing from the repository, fails verification, or any
        step raises. Exceptions are logged, never propagated.
    """
    logger.info("🔍 Checking model availability...")

    # Setup model directory
    model_dir = setup_model_directory()
    logger.info(f"Model directory: {model_dir}")

    # Fast path: a verified local copy means there is nothing to download.
    if verify_model_file(settings.MODEL_PATH):
        logger.info("✅ Model file found and verified")
        return True

    logger.info("📥 Model not found locally, attempting to download...")

    try:
        # List available files in the repository
        logger.info(f"Checking repository: {settings.HUGGINGFACE_REPO}")
        # Normalise a falsy listing (None or empty) to [] so the membership
        # test below cannot raise TypeError and be misreported as a
        # download failure.
        available_files = list_available_files(settings.HUGGINGFACE_REPO) or []

        if available_files:
            logger.info("Available files in repository:")
            for name in available_files[:10]:  # Show first 10 files
                logger.info(f" - {name}")
            if len(available_files) > 10:
                logger.info(f" ... and {len(available_files) - 10} more files")

        model_filename = f"{settings.MODEL_NAME}.pt"

        if model_filename not in available_files:
            logger.error(f"❌ Model file '{model_filename}' not found in repository")
            # Show the candidates so a misconfigured MODEL_NAME is easy to spot.
            logger.info("Available .pt files:")
            for pt_file in (f for f in available_files if f.endswith(".pt")):
                logger.info(f" - {pt_file}")
            return False

        # Download the model
        download_model_from_hf(
            repo_id=settings.HUGGINGFACE_REPO,
            model_filename=model_filename,
            local_dir=model_dir,
            force_download=False,
        )

        # Verify the downloaded model
        if verify_model_file(settings.MODEL_PATH):
            logger.info("✅ Model downloaded and verified successfully")
            return True

        logger.error("❌ Downloaded model failed verification")
        return False

    except Exception as e:
        # Deliberately broad: failure is reported through the return value
        # so the caller decides whether startup may continue.
        logger.error(f"❌ Failed to download model: {e}")
        return False
def _server_options():
    """Return the uvicorn keyword options shared by both server entry points."""
    return {
        "host": settings.HOST,
        "port": settings.PORT,
        "reload": False,
        "log_level": "info",
        "access_log": True,
    }


def _log_server_banner():
    """Log the host, port and docs URL the API is about to be served on."""
    logger.info("🚀 Starting Marine Species Identification API...")
    logger.info(f"Host: {settings.HOST}")
    logger.info(f"Port: {settings.PORT}")
    logger.info(f"Docs: http://{settings.HOST}:{settings.PORT}/docs")


def start_api():
    """Start the FastAPI application and block until the server exits.

    Synchronous entry point. ``uvicorn.run`` creates and drives its own
    event loop, so this function must not be called from inside a running
    event loop; coroutines should go through ``main()`` instead.
    """
    import uvicorn

    _log_server_banner()
    uvicorn.run("app.main:app", **_server_options())


async def main():
    """Main startup function: check the model, then serve the API.

    The API is started even when the model is unavailable, so that the
    service's /health endpoint can still be queried.

    Raises:
        RuntimeError: if the server shuts down without ever having started.
    """
    import uvicorn

    logger.info("🐟 Marine Species Identification API Startup")
    logger.info("=" * 50)

    # Check model availability
    model_available = await ensure_model_available()
    if not model_available:
        logger.warning("⚠️ Model not available - API will start but inference may fail")
        logger.info("The API will still start and you can check /health for status")

    logger.info("=" * 50)

    # Start the API on the loop that is already running this coroutine.
    # Calling uvicorn.run() here would fail, because it starts a second
    # event loop ("asyncio.run() cannot be called from a running event loop").
    _log_server_banner()
    server = uvicorn.Server(uvicorn.Config("app.main:app", **_server_options()))
    await server.serve()

    # uvicorn.run() exits non-zero when startup never completed; surface the
    # same condition as an error so the process does not exit with status 0.
    if not server.started:
        raise RuntimeError("API server failed to start")
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is an expected way to stop the script; fall through and
        # exit with status 0.
        logger.info("🛑 API startup interrupted by user")
    except Exception as e:
        # Last-resort boundary: report the failure and signal it to the
        # calling process through a non-zero exit status.
        logger.error(f"❌ Failed to start API: {e}")
        sys.exit(1)