# DISCO-v0.1 — src/inference.py
# Uploaded via huggingface_hub (commit 9894d76, user: younissk)
import torch
import json
from pathlib import Path
from transformers import CLIPProcessor
from src.model import DISCO
from PIL import Image
from typing import Tuple, Optional
# Directory holding the trained model artifacts: <repo-root>/models.
MODELS_DIR = Path(__file__).parent.parent / "models"
def load_model(
    device: Optional[str] = None,
    compile_model: bool = False
) -> Tuple[DISCO, CLIPProcessor, float, dict]:
    """
    Load trained DISCO model, processor, threshold, and metadata.

    Args:
        device: Device to load model on (None = auto-detect; prefers MPS,
            then CUDA, then CPU).
        compile_model: Whether to compile model with torch.compile
            (currently a no-op placeholder).

    Returns:
        Tuple of (model, processor, threshold, metadata). ``threshold``
        falls back to 0.5 when absent from the metadata file.

    Raises:
        FileNotFoundError: If model_metadata.json is missing (i.e. the
            model has not been trained yet).
    """
    if device is None:
        # Prefer Apple Silicon (MPS), then CUDA, falling back to CPU.
        device = "mps" if torch.backends.mps.is_available() else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    # Load model weights, move to target device, switch to eval mode.
    model = DISCO.from_pretrained(MODELS_DIR)
    model = model.to(device)
    model.eval()

    # Load the matching CLIP preprocessing pipeline.
    processor = CLIPProcessor.from_pretrained(MODELS_DIR)

    # Load metadata for threshold and other info.
    metadata_path = MODELS_DIR / "model_metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(
            f"Model metadata not found at {metadata_path}. "
            "Please run 'python src/train.py' first."
        )
    # JSON is UTF-8 by spec; be explicit so the platform default
    # encoding can never corrupt the read.
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    threshold = metadata.get("threshold", 0.5)

    # Stash the device string on the model so callers (e.g. run_DISCO)
    # can recover it without inspecting parameters.
    model._device = device

    if compile_model:
        # Future: could use torch.compile here if needed
        pass

    return model, processor, threshold, metadata
# Lazy-loaded default model (loaded on first use, not at import time).
# Populated by _get_default_model(); all three stay None until then.
_default_model: Optional[DISCO] = None
_default_processor: Optional[CLIPProcessor] = None
_default_threshold: Optional[float] = None
def _get_default_model():
    """Return (model, processor, threshold), loading and caching them on first call."""
    global _default_model, _default_processor, _default_threshold
    if _default_model is None:
        # First use: populate the module-level cache (metadata is discarded).
        model, proc, thr, _meta = load_model()
        _default_model = model
        _default_processor = proc
        _default_threshold = thr
    return _default_model, _default_processor, _default_threshold
def run_DISCO(
    image_path: str,
    model: Optional[DISCO] = None,
    processor: Optional[CLIPProcessor] = None,
    threshold: Optional[float] = None
) -> float:
    """
    Run DISCO inference on a single image.

    Args:
        image_path: Path to image file.
        model: DISCO model (uses lazily-loaded default if None).
        processor: CLIPProcessor (uses lazily-loaded default if None).
        threshold: Classification threshold. NOTE(review): accepted and
            resolved for API compatibility, but not applied here — this
            function returns the raw probability and leaves thresholding
            to the caller.

    Returns:
        Probability of SUGGESTIVE class (0-1).
    """
    # Fill in missing pieces from the lazily-loaded defaults. Use
    # explicit `is None` checks rather than `x or default`: the old
    # `or` form silently discarded a caller-supplied threshold of 0.0
    # (falsy) and invoked __bool__ on module objects.
    if model is None or processor is None:
        default_model, default_processor, default_threshold = _get_default_model()
        if model is None:
            model = default_model
        if processor is None:
            processor = default_processor
        if threshold is None:
            threshold = default_threshold

    # Device was stashed on the model by load_model(); fall back to the
    # device of the first parameter for models loaded another way.
    device = getattr(model, '_device', next(model.parameters()).device)

    # Load and preprocess image. Image.open keeps the file handle open
    # lazily; the context manager guarantees it is closed. convert()
    # returns a new, fully-loaded image, so it outlives the handle.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Run inference without tracking gradients.
    with torch.no_grad():
        logits = model(**inputs)
        # Softmax over the 2 classes; index 1 is the SUGGESTIVE class.
        proba = torch.softmax(logits, dim=-1)[0, 1].item()

    return proba