# Author: lukhsaankumar
# Commit df4a21a: Deploy DeepFake Detector API - 2026-03-07 09:12:00
"""
Model registry for managing loaded models.
"""
import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Type
from app.core.config import settings
from app.core.errors import ModelNotFoundError, ModelNotLoadedError, ConfigurationError
from app.core.logging import get_logger
from app.models.wrappers.base_wrapper import BaseSubmodelWrapper, BaseFusionWrapper
from app.models.wrappers.dummy_random_wrapper import DummyRandomWrapper
from app.models.wrappers.dummy_majority_fusion_wrapper import DummyMajorityFusionWrapper
from app.models.wrappers.logreg_fusion_wrapper import LogRegFusionWrapper
# Real production wrappers
from app.models.wrappers.cnn_transfer_wrapper import CNNTransferWrapper
from app.models.wrappers.deit_distilled_wrapper import DeiTDistilledWrapper
from app.models.wrappers.vit_base_wrapper import ViTBaseWrapper
from app.models.wrappers.gradfield_cnn_wrapper import GradfieldCNNWrapper
from app.services.hf_hub_service import get_hf_hub_service
# Module-level logger obtained from the application's logging helper.
logger = get_logger(__name__)
def get_wrapper_class(config: Dict[str, Any]) -> Type[BaseSubmodelWrapper]:
    """
    Select the appropriate wrapper class based on model config.

    Matching is a case-insensitive substring test over the config fields
    "arch", "type", "model_class" and "model_name", checked in a fixed
    priority order (first match wins):

    1. EfficientNet / CNN transfer   -> CNNTransferWrapper
    2. DeiT distilled                -> DeiTDistilledWrapper
    3. ViT base (excluding DeiT)     -> ViTBaseWrapper
    4. Gradient-field CNN            -> GradfieldCNNWrapper

    Falls back to DummyRandomWrapper if no match is found (useful for testing).

    Args:
        config: Model configuration dictionary (the parsed config.json of a
            submodel repo). Missing keys, explicit nulls and a ``None``
            config are treated as empty strings / an empty dict.

    Returns:
        Wrapper class (not instance)
    """
    config = config or {}

    def _norm(key: str) -> str:
        # `config.get(key, "").lower()` breaks on explicit JSON nulls
        # (None has no .lower()) and on non-string values, so coerce first.
        value = config.get(key)
        return "" if value is None else str(value).lower()

    arch = _norm("arch")
    model_type = _norm("type")
    model_class = _norm("model_class")
    model_name = _norm("model_name")

    # EfficientNet / CNN Transfer
    if "efficientnet" in arch or "cnn-transfer" in model_type or "efficientnet" in model_name:
        return CNNTransferWrapper

    # DeiT Distilled
    if "deit" in arch or "deit-distilled" in model_type or "deit" in model_name:
        return DeiTDistilledWrapper

    # ViT Base: "deit" contains "vit", so exclude it explicitly. The explicit
    # check keeps this branch correct even if the branches are reordered.
    is_vit = ("vit" in arch or "vit" in model_name) and "deit" not in arch and "deit" not in model_name
    if is_vit or "vit-base" in model_type:
        return ViTBaseWrapper

    # Gradient Field CNN
    if "gradient" in arch or "gradientnet" in model_class or "gradfield" in model_type or "gradient" in model_name:
        return GradfieldCNNWrapper

    # Fallback to dummy wrapper
    logger.warning(f"No matching wrapper for config, using DummyRandomWrapper: {config}")
    return DummyRandomWrapper
def get_fusion_wrapper_class(config: Dict[str, Any]) -> Type[BaseFusionWrapper]:
    """
    Select the appropriate fusion wrapper class based on config.

    The "type" field is matched case-insensitively by substring:

    - "probability_stacking" or "logreg" -> LogRegFusionWrapper
    - "majority"                         -> DummyMajorityFusionWrapper

    Anything else falls back to DummyMajorityFusionWrapper and logs a warning.
    That includes a missing or null "type" and a ``None`` config.

    Args:
        config: Fusion model configuration dictionary

    Returns:
        Fusion wrapper class (not instance)
    """
    raw_type = (config or {}).get("type")
    # Coerce before lowercasing: an explicit JSON null or a non-string value
    # would otherwise raise AttributeError on `.lower()`.
    fusion_type = "" if raw_type is None else str(raw_type).lower()

    # Logistic regression stacking fusion
    if "probability_stacking" in fusion_type or "logreg" in fusion_type:
        return LogRegFusionWrapper

    # Majority vote fusion
    if "majority" in fusion_type:
        return DummyMajorityFusionWrapper

    # Default to majority fusion
    logger.warning(f"Unknown fusion type, using DummyMajorityFusionWrapper: {fusion_type}")
    return DummyMajorityFusionWrapper
class ModelRegistry:
    """
    Central registry for all loaded models.

    Manages downloading, loading, and accessing models from Hugging Face Hub.
    This is the single source of truth for model state.

    Loads are serialised by an ``asyncio.Lock``. A (re)load builds the new
    submodels and fusion model off to the side and publishes them only once
    every model has loaded. A failed reload therefore leaves the previously
    published models untouched rather than a half-populated registry.
    """

    def __init__(self):
        self._fusion: Optional[BaseFusionWrapper] = None
        self._submodels: Dict[str, BaseSubmodelWrapper] = {}
        self._is_loaded: bool = False
        # Serialises load_from_fusion_repo calls; the read accessors do not take it.
        self._load_lock = asyncio.Lock()
        self._hf_service = get_hf_hub_service()

    @property
    def is_loaded(self) -> bool:
        """Check if models are loaded."""
        return self._is_loaded

    async def load_from_fusion_repo(
        self,
        fusion_repo_id: str,
        force_reload: bool = False
    ) -> None:
        """
        Load fusion model and all submodels from a fusion repository.

        This is the main entry point for loading models. It:
        1. Downloads the fusion repo and reads its config.json
        2. Extracts submodel repo IDs from the config's "submodels" list
        3. Downloads and loads each submodel
        4. Loads the fusion model
        5. Replaces the registry's state with the new models in one step

        Args:
            fusion_repo_id: Hugging Face repository ID for fusion model
            force_reload: If True, reload even if already loaded

        Raises:
            ConfigurationError: If the fusion config lists no submodels, its
                "submodels" field is not a list of repo-ID strings, or a
                config.json does not contain a JSON object.
        """
        async with self._load_lock:
            if self._is_loaded and not force_reload:
                logger.info("Models already loaded, skipping")
                return

            logger.info(f"Loading models from fusion repo: {fusion_repo_id}")

            # Download fusion repo (blocking call, run off the event loop)
            fusion_path = await asyncio.to_thread(
                self._hf_service.download_repo, fusion_repo_id
            )

            # Read fusion config
            fusion_config = self._read_config(fusion_path)
            logger.info(f"Fusion config: {fusion_config}")

            submodel_repos = self._get_submodel_repo_ids(fusion_config, fusion_repo_id)

            # Build the new state in locals. Writing straight into
            # self._submodels would keep stale entries from a previous load
            # and leave a mix of old and new models if a load fails midway.
            new_submodels: Dict[str, BaseSubmodelWrapper] = {}
            for submodel_repo_id in submodel_repos:
                wrapper = await self._build_submodel(submodel_repo_id)
                if wrapper.name in new_submodels:
                    logger.warning(
                        f"Duplicate submodel name {wrapper.name!r} from {submodel_repo_id}; "
                        f"replacing model from {new_submodels[wrapper.name].repo_id}"
                    )
                new_submodels[wrapper.name] = wrapper

            # Create and load fusion wrapper
            fusion_wrapper_class = get_fusion_wrapper_class(fusion_config)
            logger.info(f"Using fusion wrapper class {fusion_wrapper_class.__name__}")
            fusion = fusion_wrapper_class(
                repo_id=fusion_repo_id,
                config=fusion_config,
                local_path=fusion_path
            )
            fusion.load()

            # Everything loaded: publish the new state.
            self._submodels = new_submodels
            self._fusion = fusion
            self._is_loaded = True
            logger.info(f"Successfully loaded {len(self._submodels)} submodels and fusion model")

    @staticmethod
    def _get_submodel_repo_ids(fusion_config: Dict[str, Any], fusion_repo_id: str) -> List[str]:
        """
        Extract and validate the "submodels" repo-ID list from a fusion config.

        Args:
            fusion_config: Parsed fusion config.json
            fusion_repo_id: Fusion repo ID (used only in error details)

        Returns:
            List of submodel repo IDs

        Raises:
            ConfigurationError: If the list is missing/empty or not a list of strings.
        """
        submodel_repos = fusion_config.get("submodels", [])
        if not submodel_repos:
            raise ConfigurationError(
                message="Fusion config does not specify any submodels",
                details={"repo_id": fusion_repo_id}
            )
        # A bare string would otherwise be iterated character by character.
        if not isinstance(submodel_repos, (list, tuple)) or not all(
            isinstance(repo, str) for repo in submodel_repos
        ):
            raise ConfigurationError(
                message="Fusion config 'submodels' must be a list of repo ID strings",
                details={"repo_id": fusion_repo_id, "submodels": submodel_repos}
            )
        return list(submodel_repos)

    async def _build_submodel(self, repo_id: str) -> BaseSubmodelWrapper:
        """
        Download and load a single submodel without registering it.

        Uses the repo's config.json to determine the correct wrapper class.

        Args:
            repo_id: Hugging Face repository ID for the submodel

        Returns:
            The loaded submodel wrapper
        """
        logger.info(f"Loading submodel: {repo_id}")

        # Download the repo (blocking call, run off the event loop)
        local_path = await asyncio.to_thread(
            self._hf_service.download_repo, repo_id
        )

        # Read config and select the wrapper class from it
        config = self._read_config(local_path)
        wrapper_class = get_wrapper_class(config)
        logger.info(f"Using wrapper class {wrapper_class.__name__} for {repo_id}")

        # Create and load wrapper
        wrapper = wrapper_class(
            repo_id=repo_id,
            config=config,
            local_path=local_path
        )
        wrapper.load()
        logger.info(f"Loaded submodel: {wrapper.name}")
        return wrapper

    async def _load_submodel(self, repo_id: str) -> None:
        """
        Download, load and register a single submodel under its short name.

        Kept for backward compatibility; load_from_fusion_repo uses
        _build_submodel so that it can publish all models in one step.

        Args:
            repo_id: Hugging Face repository ID for the submodel
        """
        wrapper = await self._build_submodel(repo_id)
        self._submodels[wrapper.name] = wrapper

    def _read_config(self, local_path: str) -> Dict[str, Any]:
        """
        Read config.json from a local model path.

        Args:
            local_path: Path to the downloaded model

        Returns:
            Configuration dictionary (empty if config.json does not exist)

        Raises:
            ConfigurationError: If config.json's top-level value is not a JSON object.
            json.JSONDecodeError: If config.json is not valid JSON.
        """
        config_path = Path(local_path) / "config.json"
        if not config_path.exists():
            logger.warning(f"config.json not found at {config_path}, using empty config")
            return {}
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Callers use dict methods on the result; fail clearly here instead
        # of with an AttributeError later.
        if not isinstance(config, dict):
            raise ConfigurationError(
                message=f"config.json must contain a JSON object: {config_path}",
                details={"path": str(config_path), "type": type(config).__name__}
            )
        return config

    def list_models(self) -> List[Dict[str, Any]]:
        """
        List all loaded models.

        Returns:
            List of model info dictionaries, fusion model first (if any),
            followed by every submodel.
        """
        models = []

        # Add fusion model
        if self._fusion:
            models.append({
                "repo_id": self._fusion.repo_id,
                "name": self._fusion.name,
                "model_type": "fusion",
                "config": self._fusion.config
            })

        # Add submodels
        for name, wrapper in self._submodels.items():
            models.append({
                "repo_id": wrapper.repo_id,
                "name": name,
                "model_type": "submodel",
                "config": wrapper.config
            })

        return models

    def get_submodel(self, key: str) -> BaseSubmodelWrapper:
        """
        Get a submodel by name or repo_id.

        Args:
            key: Submodel name or full repo_id

        Returns:
            Submodel wrapper

        Raises:
            ModelNotFoundError: If submodel not found
            ModelNotLoadedError: If models not loaded
        """
        if not self._is_loaded:
            raise ModelNotLoadedError(
                message="Models not loaded yet",
                details={"requested_model": key}
            )

        # Try by name first
        if key in self._submodels:
            return self._submodels[key]

        # Try by repo_id
        for wrapper in self._submodels.values():
            if wrapper.repo_id == key:
                return wrapper

        raise ModelNotFoundError(
            message=f"Submodel not found: {key}",
            details={
                "requested_model": key,
                "available_models": list(self._submodels.keys())
            }
        )

    def get_all_submodels(self) -> Dict[str, BaseSubmodelWrapper]:
        """
        Get all loaded submodels.

        Returns:
            Shallow copy of the mapping from name to submodel wrapper

        Raises:
            ModelNotLoadedError: If models not loaded
        """
        if not self._is_loaded:
            raise ModelNotLoadedError(message="Models not loaded yet")
        return self._submodels.copy()

    def get_fusion(self) -> BaseFusionWrapper:
        """
        Get the fusion model.

        Returns:
            Fusion model wrapper

        Raises:
            ModelNotLoadedError: If models not loaded
        """
        if not self._is_loaded or self._fusion is None:
            raise ModelNotLoadedError(message="Fusion model not loaded yet")
        return self._fusion

    def get_submodel_names(self) -> List[str]:
        """Get list of loaded submodel names."""
        return list(self._submodels.keys())

    def get_fusion_repo_id(self) -> Optional[str]:
        """Get the fusion repo ID if loaded."""
        return self._fusion.repo_id if self._fusion else None
# Process-wide registry, created lazily on first access.
_model_registry: Optional[ModelRegistry] = None


def get_model_registry() -> ModelRegistry:
    """
    Return the shared ModelRegistry, constructing it on first call.

    Returns:
        The module-level ModelRegistry instance (the same object on every call).
    """
    global _model_registry
    registry = _model_registry
    if registry is None:
        registry = ModelRegistry()
        _model_registry = registry
    return registry