|
|
import torch |
|
|
import json |
|
|
from pathlib import Path |
|
|
from transformers import CLIPProcessor |
|
|
from src.model import DISCO |
|
|
from PIL import Image |
|
|
from typing import Tuple, Optional |
|
|
|
|
|
# Directory holding the trained model artifacts (weights, processor config,
# and model_metadata.json), resolved relative to this source file.
MODELS_DIR = Path(__file__).parent.parent / "models"
|
|
|
|
|
|
|
|
def load_model(
    device: Optional[str] = None,
    compile_model: bool = False
) -> Tuple[DISCO, CLIPProcessor, float, dict]:
    """
    Load trained DISCO model, processor, threshold, and metadata.

    Args:
        device: Device to load model on (None = auto-detect, preferring
            MPS, then CUDA, then CPU)
        compile_model: Whether to optimize the model with torch.compile
            (silently skipped on torch versions without ``torch.compile``)

    Returns:
        Tuple of (model, processor, threshold, metadata)

    Raises:
        FileNotFoundError: If model metadata is missing (training has not
            been run yet).
    """
    # Auto-detect the best available device: Apple MPS, then CUDA, then CPU.
    if device is None:
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"

    model = DISCO.from_pretrained(MODELS_DIR)
    model = model.to(device)
    model.eval()  # inference mode: disables dropout / batch-norm updates

    processor = CLIPProcessor.from_pretrained(MODELS_DIR)

    metadata_path = MODELS_DIR / "model_metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(
            f"Model metadata not found at {metadata_path}. "
            "Please run 'python src/train.py' first."
        )

    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    # Fall back to 0.5 when training did not record a tuned threshold.
    threshold = metadata.get("threshold", 0.5)

    # Remember the device so callers (e.g. run_DISCO) can move inputs to it.
    model._device = device

    if compile_model:
        # torch.compile (PyTorch >= 2.0) returns a wrapped, optimized module.
        # Guard with hasattr so older torch versions keep the previous
        # no-op behavior instead of crashing.
        if hasattr(torch, "compile"):
            model = torch.compile(model)
            # Re-attach the device hint: the compiled wrapper is a new object.
            model._device = device

    return model, processor, threshold, metadata
|
|
|
|
|
|
|
|
|
|
|
# Process-wide cache for the lazily-loaded default model, processor, and
# threshold. All three are populated together on the first call to
# _get_default_model() and reused afterwards.
_default_model = None


_default_processor = None


_default_threshold = None
|
|
|
|
|
|
|
|
def _get_default_model():
    """Return the cached (model, processor, threshold), loading them on first call."""
    global _default_model, _default_processor, _default_threshold
    # Only the model is checked: all three globals are always set together.
    if _default_model is None:
        loaded = load_model()
        _default_model = loaded[0]
        _default_processor = loaded[1]
        _default_threshold = loaded[2]
    return _default_model, _default_processor, _default_threshold
|
|
|
|
|
|
|
|
def run_DISCO(
    image_path: str,
    model: Optional[DISCO] = None,
    processor: Optional[CLIPProcessor] = None,
    threshold: Optional[float] = None
) -> float:
    """
    Run DISCO inference on a single image.

    Args:
        image_path: Path to image file
        model: DISCO model (uses default if None)
        processor: CLIPProcessor (uses default if None)
        threshold: Classification threshold (uses model default if None).
            NOTE: the threshold is not applied here — the raw probability is
            returned and callers compare against the threshold themselves.

    Returns:
        Probability of SUGGESTIVE class (0-1)
    """
    # Lazily fill in any missing pieces from the process-wide default model.
    # Use explicit `is None` checks (not `or`) so falsy-but-valid values —
    # e.g. an explicit threshold of 0.0 — are not silently replaced.
    if model is None or processor is None:
        default_model, default_processor, default_threshold = _get_default_model()
        if model is None:
            model = default_model
        if processor is None:
            processor = default_processor
        if threshold is None:
            threshold = default_threshold

    # Prefer the device recorded by load_model; otherwise infer it from the
    # model's own parameters.
    device = getattr(model, '_device', None)
    if device is None:
        device = next(model.parameters()).device

    # Context manager closes the underlying file handle promptly;
    # .convert() produces an independent in-memory copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs)
        # [0, 1]: first (only) item in the batch, probability of the
        # positive (SUGGESTIVE) class.
        proba = torch.softmax(logits, dim=-1)[0, 1].item()

    return proba