|
|
|
|
|
""" |
|
|
Startup script for the Marine Species Identification API. |
|
|
This script handles model downloading and API startup. |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import sys |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Put this script's own directory first on the import path so the local
# ``app`` package imported below is found regardless of the current working
# directory. This must stay above the ``app.*`` imports.
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
|
from app.core.config import settings |
|
|
from app.core.logging import setup_logging, get_logger |
|
|
from app.utils.model_utils import ( |
|
|
download_model_from_hf, |
|
|
verify_model_file, |
|
|
setup_model_directory, |
|
|
list_available_files |
|
|
) |
|
|
|
|
|
|
|
|
# Configure the project's logging (app.core.logging.setup_logging) at import
# time, before this module's logger is created and used.
setup_logging()

# Module-level logger used by every function in this startup script.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
async def ensure_model_available():
    """Ensure the model weights file is available locally, downloading it if needed.

    Steps:
      1. Resolve/create the local model directory via ``setup_model_directory()``.
      2. If ``verify_model_file(settings.MODEL_PATH)`` passes, return right away.
      3. Otherwise list the files in ``settings.HUGGINGFACE_REPO``. If
         ``f"{settings.MODEL_NAME}.pt"`` is among them, download it into the
         model directory and verify ``settings.MODEL_PATH`` again.

    Returns:
        bool: True when a verified model file is available, False otherwise.
        Failures are logged, not raised. Any exception during the listing or
        download phase is caught and reported as False.

    Note:
        Declared ``async`` for the caller, but every call in the body is
        synchronous. The repository listing and download therefore block the
        event loop while they run.
    """
    logger.info("🔍 Checking model availability...")

    model_dir = setup_model_directory()
    logger.info(f"Model directory: {model_dir}")

    # Fast path: a previously downloaded file that passes verification.
    if verify_model_file(settings.MODEL_PATH):
        logger.info("✅ Model file found and verified")
        return True

    logger.info("📥 Model not found locally, attempting to download...")

    try:
        logger.info(f"Checking repository: {settings.HUGGINGFACE_REPO}")
        # Normalize a falsy result (e.g. None) and any non-list iterable to a
        # list. The slicing and membership tests below then cannot raise TypeError.
        available_files = list(list_available_files(settings.HUGGINGFACE_REPO) or [])

        # Log a bounded preview of the repository contents for diagnostics.
        preview_limit = 10
        if available_files:
            logger.info("Available files in repository:")
            for name in available_files[:preview_limit]:
                logger.info(f" - {name}")
            if len(available_files) > preview_limit:
                logger.info(f" ... and {len(available_files) - preview_limit} more files")

        model_filename = f"{settings.MODEL_NAME}.pt"

        if model_filename not in available_files:
            logger.error(f"❌ Model file '{model_filename}' not found in repository")
            # Show the candidate weight files so a misconfigured MODEL_NAME is easy to spot.
            logger.info("Available .pt files:")
            for pt_file in (f for f in available_files if f.endswith(".pt")):
                logger.info(f" - {pt_file}")
            return False

        download_model_from_hf(
            repo_id=settings.HUGGINGFACE_REPO,
            model_filename=model_filename,
            local_dir=model_dir,
            force_download=False,
        )

        # NOTE(review): the download targets model_dir/model_filename but the
        # check below uses settings.MODEL_PATH; these are assumed to be the
        # same file. Confirm against app.core.config.
        if verify_model_file(settings.MODEL_PATH):
            logger.info("✅ Model downloaded and verified successfully")
            return True

        logger.error("❌ Downloaded model failed verification")
        return False

    except Exception as e:
        # Best-effort: a failed download must not prevent the API from starting.
        logger.error(f"❌ Failed to download model: {e}")
        return False
|
|
|
|
|
|
|
|
def start_api():
    """Run the FastAPI application under uvicorn, blocking until the server stops.

    Serves the ``app.main:app`` ASGI application on ``settings.HOST`` and
    ``settings.PORT``, with auto-reload disabled and access logging enabled.

    Note:
        ``uvicorn.run`` starts its own event loop, so call this from
        synchronous code, not from inside an already-running asyncio loop.
    """
    import uvicorn

    # A wildcard bind address (0.0.0.0 / ::) is not something a browser can
    # open, so point the docs link at localhost instead. IPv6 literals must be
    # wrapped in brackets inside a URL.
    url_host = str(settings.HOST)
    if url_host in ("0.0.0.0", "::"):
        url_host = "localhost"
    elif ":" in url_host and not url_host.startswith("["):
        url_host = f"[{url_host}]"

    logger.info("🚀 Starting Marine Species Identification API...")
    logger.info(f"Host: {settings.HOST}")
    logger.info(f"Port: {settings.PORT}")
    logger.info(f"Docs: http://{url_host}:{settings.PORT}/docs")

    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=False,
        log_level="info",
        access_log=True,
    )
|
|
|
|
|
|
|
|
async def main():
    """Run the startup sequence: ensure the model is available, then serve the API.

    A missing model is logged as a warning and does not stop the server from
    starting. The coroutine completes when the uvicorn server shuts down.
    """
    import uvicorn

    logger.info("🌊 Marine Species Identification API Startup")
    logger.info("=" * 50)

    model_available = await ensure_model_available()

    if not model_available:
        logger.warning("⚠️ Model not available - API will start but inference may fail")
        logger.info("The API will still start and you can check /health for status")

    logger.info("=" * 50)

    # Do NOT call start_api() here. It calls uvicorn.run(), which starts its own
    # event loop via asyncio.run() and raises RuntimeError ("cannot be called
    # from a running event loop") because main() already runs inside
    # asyncio.run(). Instead, serve on the current loop with the same settings
    # start_api() uses.
    url_host = str(settings.HOST)
    if url_host in ("0.0.0.0", "::"):
        url_host = "localhost"
    elif ":" in url_host and not url_host.startswith("["):
        url_host = f"[{url_host}]"

    logger.info("🚀 Starting Marine Species Identification API...")
    logger.info(f"Host: {settings.HOST}")
    logger.info(f"Port: {settings.PORT}")
    logger.info(f"Docs: http://{url_host}:{settings.PORT}/docs")

    config = uvicorn.Config(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=False,
        log_level="info",
        access_log=True,
    )
    await uvicorn.Server(config).serve()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async startup sequence. Ctrl+C is reported
    # as a normal interruption; any other failure is logged and the process
    # exits with status 1.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("🛑 API startup interrupted by user")
    except Exception as e:
        logger.error(f"❌ Failed to start API: {e}")
        sys.exit(1)
|
|
|