| """ |
| Model service for managing YOLOv5 model lifecycle and operations. |
| """ |
|
|
| import os |
| from pathlib import Path |
| from typing import Dict, Optional |
| from huggingface_hub import hf_hub_download |
|
|
| from app.core.config import settings |
| from app.core.logging import get_logger |
| from app.models.yolo import MarineSpeciesYOLO, get_model |
|
|
# Module-level logger, obtained from the project's get_logger helper with this module's __name__.
| logger = get_logger(__name__) |
|
|
|
|
class ModelService:
    """Service for managing the marine species detection model.

    Caches the detector returned by ``get_model()`` and the class-id -> name
    mapping, and can fetch the weights from the HuggingFace Hub when the file
    at ``settings.MODEL_PATH`` is missing.
    """

    def __init__(self):
        # Both are filled lazily: the model on the first get_model() call, the
        # names by _load_class_names() or, failing that, by get_class_names().
        self._model: Optional[MarineSpeciesYOLO] = None
        self._class_names: Optional[Dict[int, str]] = None

    async def ensure_model_available(self) -> None:
        """
        Ensure the model is downloaded and available.
        Downloads from HuggingFace Hub if not present locally.

        Download and load failures are logged rather than raised, so the
        caller can keep going and let ``health_check`` report the problem.
        """
        model_path = Path(settings.MODEL_PATH)

        if not model_path.exists():
            logger.info(f"Model not found at {model_path}, downloading from HuggingFace Hub...")
            try:
                await self._download_model()
            except Exception as e:
                logger.error(f"Failed to download model: {e}")

        await self._load_class_names()

        # Warm the model cache now so the first request does not pay for it.
        try:
            self.get_model()
            logger.info("Model loaded successfully during startup")
        except Exception as e:
            logger.error(f"Model failed to load during startup: {e}")

    async def _download_model(self) -> None:
        """Download model from HuggingFace Hub.

        Fetches ``<MODEL_NAME>.pt`` into the directory containing
        ``settings.MODEL_PATH`` and then, best-effort, ``<MODEL_NAME>.names``.
        The blocking hub calls run in the default executor so the event loop
        stays responsive while the files are transferred.

        Raises:
            RuntimeError: If the weights cannot be downloaded. The original
                exception is chained as ``__cause__``.
        """
        # Imported here so this method carries its own dependency.
        import asyncio

        try:
            model_dir = Path(settings.MODEL_PATH).parent
            model_dir.mkdir(parents=True, exist_ok=True)

            logger.info(f"Downloading model from {settings.HUGGINGFACE_REPO}")

            def fetch(filename):
                # Synchronous hub download; only ever called via the executor.
                return hf_hub_download(
                    repo_id=settings.HUGGINGFACE_REPO,
                    filename=filename,
                    cache_dir=str(model_dir.parent / ".cache"),
                    local_dir=str(model_dir),
                    local_dir_use_symlinks=False,
                )

            loop = asyncio.get_running_loop()

            downloaded_path = await loop.run_in_executor(
                None, fetch, f"{settings.MODEL_NAME}.pt"
            )
            logger.info(f"Model downloaded successfully to: {downloaded_path}")

            # The names file is optional: get_class_names() can fall back to
            # the names reported by the model itself.
            try:
                names_path = await loop.run_in_executor(
                    None, fetch, f"{settings.MODEL_NAME}.names"
                )
                logger.info(f"Class names file downloaded to: {names_path}")
            except Exception as e:
                logger.warning(f"Could not download .names file: {str(e)}")

        except Exception as e:
            # No logging here: the caller logs the failure, and logging in
            # both places reported every error twice.
            raise RuntimeError(f"Model download failed: {str(e)}") from e

    @staticmethod
    def _parse_class_names(lines) -> Dict[int, str]:
        """Map zero-based line numbers to stripped class names.

        Args:
            lines: Iterable of text lines (an open text file works).

        Returns:
            ``{line_index: name}``. Trailing blank lines are dropped so they
            do not count as classes; blank lines in the middle are kept so the
            ids of the following classes do not shift.
        """
        names = [line.strip() for line in lines]
        while names and not names[-1]:
            names.pop()
        return dict(enumerate(names))

    async def _load_class_names(self) -> None:
        """Load class names from .names file.

        Reads the file next to ``settings.MODEL_PATH`` (same stem, ``.names``
        suffix). A missing, unreadable or empty file is logged and leaves
        ``self._class_names`` untouched, so ``get_class_names`` can still fall
        back to the model.
        """
        names_file = Path(settings.MODEL_PATH).with_suffix('.names')

        if not names_file.exists():
            logger.warning(f"Class names file not found: {names_file}")
            return

        try:
            # Explicit encoding: do not depend on the process locale.
            with open(names_file, 'r', encoding='utf-8') as f:
                class_names = self._parse_class_names(f)
        except (OSError, UnicodeError) as e:
            logger.error(f"Failed to load class names: {str(e)}")
            return

        if not class_names:
            # Storing {} would block the model fallback in get_class_names().
            logger.warning(f"Class names file is empty: {names_file}")
            return

        self._class_names = class_names
        logger.info(f"Loaded {len(class_names)} class names")

    def get_model(self) -> "MarineSpeciesYOLO":
        """
        Get the model instance.

        The instance is created through the module-level ``get_model()``
        factory on first use and cached afterwards.

        Returns:
            MarineSpeciesYOLO instance
        """
        if self._model is None:
            self._model = get_model()
        return self._model

    def get_class_names(self) -> Optional[Dict[int, str]]:
        """
        Get class names mapping.

        Uses the mapping loaded from the .names file when there is one;
        otherwise asks the model (loading it if needed) and caches the answer.

        Returns:
            Dictionary mapping class IDs to names
        """
        if self._class_names is None:
            self._class_names = self.get_model().get_class_names()
        return self._class_names

    def get_model_info(self) -> Dict:
        """
        Get comprehensive model information.

        Never raises because of a model failure: if the model or its class
        names cannot be obtained, the error is logged and the defaults
        (``total_classes`` 0, ``device`` "unknown") are returned.

        Returns:
            Dictionary with model information
        """
        info = {
            "model_name": settings.MODEL_NAME,
            "total_classes": 0,
            "device": "unknown",
            "model_path": settings.MODEL_PATH,
            "huggingface_repo": settings.HUGGINGFACE_REPO,
        }

        try:
            model = self.get_model()
            class_names = self.get_class_names()
            info["total_classes"] = len(class_names) if class_names else 0
        except Exception as e:
            logger.error(f"Failed to get model info: {str(e)}")
            return info

        # The device is optional information; a failure here only costs the field.
        try:
            info["device"] = str(getattr(model, "device", "unknown"))
        except Exception as e:
            logger.warning(f"Could not get device info: {e}")

        return info

    async def health_check(self) -> Dict:
        """
        Perform a health check on the model.

        Returns:
            Dictionary with health status: "healthy" plus the model info when
            the model can be obtained, otherwise "unhealthy" plus the error text.
        """
        try:
            # Called only for its side effect: raises when the model cannot load.
            self.get_model()
            return {
                "status": "healthy",
                "model_loaded": True,
                "model_info": self.get_model_info(),
            }
        except Exception as e:
            logger.error(f"Model health check failed: {str(e)}")
            return {
                "status": "unhealthy",
                "model_loaded": False,
                "error": str(e),
            }


# Shared module-level instance of the service.
model_service = ModelService()
|
|