# File: fishapi/app/services/model_service.py
# Last commit 0c8b8e3 (verified) by kamau1:
#   Fix Pydantic field conflicts and model loading issues; add test and
#   upload scripts for HF deployment
"""
Model service for managing YOLOv5 model lifecycle and operations.
"""
import os
from pathlib import Path
from typing import Dict, Optional
from huggingface_hub import hf_hub_download
from app.core.config import settings
from app.core.logging import get_logger
from app.models.yolo import MarineSpeciesYOLO, get_model
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
class ModelService:
    """Service for managing the marine species detection model.

    Keeps a lazily created model instance and a class-ID -> class-name
    mapping, and can fetch the model files from the HuggingFace Hub when the
    configured model path does not exist locally.
    """

    def __init__(self):
        # Both attributes start empty and are filled on demand: the model by
        # get_model(), the names by _load_class_names() or get_class_names().
        self._model: Optional[MarineSpeciesYOLO] = None
        self._class_names: Optional[Dict[int, str]] = None

    async def ensure_model_available(self) -> None:
        """Make a best-effort attempt to have the model ready.

        Downloads the model from the HuggingFace Hub when
        ``settings.MODEL_PATH`` does not exist, loads the class names file if
        there is one, and then tries to instantiate the model once so that
        loading problems show up in the logs early.

        No exception escapes from the download or the model load: failures
        are logged and the method returns normally.
        """
        model_path = Path(settings.MODEL_PATH)

        if not model_path.exists():
            logger.info(f"Model not found at {model_path}, downloading from HuggingFace Hub...")
            try:
                await self._download_model()
            except Exception as e:
                # Continue anyway - the API can still run without the model.
                logger.error(f"Failed to download model: {e}")

        # Load class names if available.
        await self._load_class_names()

        # Try to initialize the model to catch loading errors early.
        try:
            self.get_model()
            logger.info("Model loaded successfully during startup")
        except Exception as e:
            # Don't fail startup - let health checks handle this.
            logger.error(f"Model failed to load during startup: {e}")

    @staticmethod
    def _fetch_from_hub(filename: str, model_dir: Path) -> str:
        """Download one file of ``settings.HUGGINGFACE_REPO`` into ``model_dir``.

        Args:
            filename: Name of the file inside the Hub repository.
            model_dir: Directory passed to ``hf_hub_download`` as ``local_dir``;
                a sibling ``.cache`` directory is passed as ``cache_dir``.

        Returns:
            The path reported by ``hf_hub_download``.
        """
        # NOTE(review): hf_hub_download is invoked without ``await``, so it runs
        # on the calling thread; confirm that blocking here is acceptable.
        return hf_hub_download(
            repo_id=settings.HUGGINGFACE_REPO,
            filename=filename,
            cache_dir=str(model_dir.parent / ".cache"),
            local_dir=str(model_dir),
            local_dir_use_symlinks=False,
        )

    async def _download_model(self) -> None:
        """Download the model file, and the class names file if possible.

        The ``.pt`` file is required; the ``.names`` file is optional and a
        failure to fetch it is only logged as a warning.

        Raises:
            RuntimeError: If the model directory cannot be created or the
                ``.pt`` download fails. The original exception is chained as
                ``__cause__``.
        """
        try:
            # Create models directory if it doesn't exist.
            model_dir = Path(settings.MODEL_PATH).parent
            model_dir.mkdir(parents=True, exist_ok=True)

            logger.info(f"Downloading model from {settings.HUGGINGFACE_REPO}")
            downloaded_path = self._fetch_from_hub(f"{settings.MODEL_NAME}.pt", model_dir)
            logger.info(f"Model downloaded successfully to: {downloaded_path}")

            # Also download the .names file if available.
            try:
                names_path = self._fetch_from_hub(f"{settings.MODEL_NAME}.names", model_dir)
                logger.info(f"Class names file downloaded to: {names_path}")
            except Exception as e:
                logger.warning(f"Could not download .names file: {str(e)}")
        except Exception as e:
            logger.error(f"Failed to download model: {str(e)}")
            # Chain the cause so the underlying error stays visible to callers.
            raise RuntimeError(f"Model download failed: {str(e)}") from e

    async def _load_class_names(self) -> None:
        """Load class names from the ``.names`` file next to the model.

        The file is expected at ``settings.MODEL_PATH`` with its suffix
        replaced by ``.names``. Each line is one class; the zero-based line
        index is used as the class ID. Trailing blank lines are ignored so
        they are not counted as (empty) classes.

        A missing file is logged as a warning and any read/decode error as an
        error; in both cases the current mapping is left untouched.
        """
        names_file = Path(settings.MODEL_PATH).with_suffix('.names')

        try:
            # Decode explicitly as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            with names_file.open('r', encoding='utf-8') as f:
                names = [line.strip() for line in f]
        except FileNotFoundError:
            logger.warning(f"Class names file not found: {names_file}")
            return
        except (OSError, UnicodeError) as e:
            logger.error(f"Failed to load class names: {str(e)}")
            return

        # Drop blank lines at the end of the file only; interior lines keep
        # their position so the line-index -> class-ID mapping is preserved.
        while names and not names[-1]:
            names.pop()

        self._class_names = dict(enumerate(names))
        logger.info(f"Loaded {len(names)} class names")

    def get_model(self) -> MarineSpeciesYOLO:
        """
        Get the model instance, creating it on first use.

        Returns:
            MarineSpeciesYOLO instance
        """
        if self._model is None:
            self._model = get_model()
        return self._model

    def get_class_names(self) -> Optional[Dict[int, str]]:
        """
        Get class names mapping.

        When no mapping has been loaded yet, it is requested from the model
        (which may trigger model creation) and remembered.

        Returns:
            Dictionary mapping class IDs to names
        """
        if self._class_names is None:
            # Try to get from model.
            self._class_names = self.get_model().get_class_names()
        return self._class_names

    def get_model_info(self) -> Dict:
        """
        Get comprehensive model information.

        Never raises because of the model itself: if the model or its class
        names cannot be obtained, the error is logged and the returned
        dictionary carries the fallback values ``total_classes=0`` and
        ``device="unknown"``.

        Returns:
            Dictionary with the keys ``model_name``, ``total_classes``,
            ``device``, ``model_path`` and ``huggingface_repo``
        """
        # Start from the fallback values; keys are declared once so both the
        # success and the failure path return the same shape.
        info = {
            "model_name": settings.MODEL_NAME,
            "total_classes": 0,
            "device": "unknown",
            "model_path": settings.MODEL_PATH,
            "huggingface_repo": settings.HUGGINGFACE_REPO,
        }

        try:
            model = self.get_model()
            class_names = self.get_class_names()
            total_classes = len(class_names) if class_names else 0
        except Exception as e:
            logger.error(f"Failed to get model info: {str(e)}")
            # Return basic info even if model fails.
            return info

        # Safely get device info; a model without a ``device`` attribute
        # reports "unknown".
        device_info = "unknown"
        try:
            device_info = str(getattr(model, 'device', "unknown"))
        except Exception as e:
            logger.warning(f"Could not get device info: {e}")

        info["total_classes"] = total_classes
        info["device"] = device_info
        return info

    async def health_check(self) -> Dict:
        """
        Perform a health check on the model.

        The model counts as healthy when get_model() returns without raising.

        Returns:
            Dictionary with ``status`` and ``model_loaded``, plus
            ``model_info`` when healthy or ``error`` when unhealthy
        """
        try:
            # Only the side effect matters here: this raises if the model
            # cannot be created.
            self.get_model()
            model_info = self.get_model_info()
            return {
                "status": "healthy",
                "model_loaded": True,
                "model_info": model_info
            }
        except Exception as e:
            logger.error(f"Model health check failed: {str(e)}")
            return {
                "status": "unhealthy",
                "model_loaded": False,
                "error": str(e)
            }
# Global service instance, created once when this module is imported.
# Construction only initialises empty attributes; no model is loaded here.
model_service = ModelService()