"""Inference utilities for the DISCO classifier.

Provides model/processor loading from the local ``models`` directory and a
single-image scoring helper, with a lazily initialized module-level default
model so importing this file stays cheap.
"""

import json
from pathlib import Path
from typing import Optional, Tuple

import torch
from PIL import Image
from transformers import CLIPProcessor

from src.model import DISCO

# Trained artifacts (weights, processor config, metadata) live in <repo>/models.
MODELS_DIR = Path(__file__).parent.parent / "models"


def load_model(
    device: Optional[str] = None,
    compile_model: bool = False
) -> Tuple[DISCO, CLIPProcessor, float, dict]:
    """
    Load trained DISCO model, processor, threshold, and metadata.

    Args:
        device: Device to load model on (None = auto-detect: prefers MPS,
            then CUDA, then CPU)
        compile_model: Whether to compile model with torch.compile
            (accepted for forward compatibility; currently a no-op)

    Returns:
        Tuple of (model, processor, threshold, metadata)

    Raises:
        FileNotFoundError: If model_metadata.json is missing (i.e. the
            model has not been trained/exported yet).
    """
    if device is None:
        device = "mps" if torch.backends.mps.is_available() else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    # Load model weights and move to the target device in eval mode.
    model = DISCO.from_pretrained(MODELS_DIR)
    model = model.to(device)
    model.eval()

    # Load the paired image processor (resize/normalize config).
    processor = CLIPProcessor.from_pretrained(MODELS_DIR)

    # Load metadata for threshold and other info.
    metadata_path = MODELS_DIR / "model_metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(
            f"Model metadata not found at {metadata_path}. "
            "Please run 'python src/train.py' first."
        )
    with open(metadata_path, "r") as f:
        metadata = json.load(f)

    # Fall back to 0.5 if training did not record a calibrated threshold.
    threshold = metadata.get("threshold", 0.5)

    # Store device for easy access by callers (e.g. run_DISCO).
    model._device = device

    if compile_model:
        # Future: could use torch.compile here if needed
        pass

    return model, processor, threshold, metadata


# Lazy-loaded default model (loaded on first use, not at import time).
_default_model = None
_default_processor = None
_default_threshold = None


def _get_default_model():
    """Lazy-load and cache the default model triple on first use."""
    global _default_model, _default_processor, _default_threshold
    if _default_model is None:
        _default_model, _default_processor, _default_threshold, _ = load_model()
    return _default_model, _default_processor, _default_threshold


def run_DISCO(
    image_path: str,
    model: Optional[DISCO] = None,
    processor: Optional[CLIPProcessor] = None,
    threshold: Optional[float] = None
) -> float:
    """
    Run DISCO inference on a single image.

    Args:
        image_path: Path to image file
        model: DISCO model (uses default if None)
        processor: CLIPProcessor (uses default if None)
        threshold: Classification threshold. NOTE: accepted for API
            compatibility but not applied here — this function always
            returns the raw probability; thresholding is the caller's job.

    Returns:
        Probability of SUGGESTIVE class (0-1)
    """
    # Fill in any missing pieces from the lazily loaded defaults.
    # Use explicit `is None` checks so legitimate falsy values
    # (e.g. threshold=0.0) are not silently replaced by the defaults.
    if model is None or processor is None:
        default_model, default_processor, default_threshold = _get_default_model()
        if model is None:
            model = default_model
        if processor is None:
            processor = default_processor
        if threshold is None:
            threshold = default_threshold

    # Prefer the device recorded by load_model; otherwise infer from weights.
    device = getattr(model, '_device', next(model.parameters()).device)

    # Load and preprocess image.
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Run inference; index [0, 1] = probability of the positive class
    # for the single image in the batch.
    with torch.no_grad():
        logits = model(**inputs)
        proba = torch.softmax(logits, dim=-1)[0, 1].item()

    return proba