File size: 3,445 Bytes
9894d76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import json
from pathlib import Path
from transformers import CLIPProcessor
from src.model import DISCO
from PIL import Image
from typing import Tuple, Optional

# Directory containing the trained model artifacts (weights, processor config,
# model_metadata.json), resolved relative to this file: <repo>/models.
MODELS_DIR = Path(__file__).parent.parent / "models"


def load_model(
    device: Optional[str] = None,
    compile_model: bool = False
) -> Tuple[DISCO, CLIPProcessor, float, dict]:
    """
    Load the trained DISCO model plus its processor, decision threshold,
    and training metadata from ``MODELS_DIR``.

    Args:
        device: Torch device string to place the model on. When None, the
            best available backend is picked in order: mps, cuda, cpu.
        compile_model: Reserved flag for torch.compile; currently a no-op.

    Returns:
        Tuple of (model, processor, threshold, metadata).

    Raises:
        FileNotFoundError: If model_metadata.json is missing (model has
            not been trained/exported yet).
    """
    # Auto-detect the best available device when none was requested.
    if device is None:
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"

    # Model weights, moved to the target device and put in inference mode.
    model = DISCO.from_pretrained(MODELS_DIR).to(device)
    model.eval()

    # Matching CLIP preprocessor saved alongside the weights.
    processor = CLIPProcessor.from_pretrained(MODELS_DIR)

    # Metadata carries the tuned classification threshold and training info.
    metadata_path = MODELS_DIR / "model_metadata.json"
    if not metadata_path.exists():
        raise FileNotFoundError(
            f"Model metadata not found at {metadata_path}. "
            "Please run 'python src/train.py' first."
        )
    metadata = json.loads(metadata_path.read_text())

    # Fall back to 0.5 if no tuned threshold was recorded.
    threshold = metadata.get("threshold", 0.5)

    # Stash the device on the model so callers can retrieve it cheaply.
    model._device = device

    if compile_model:
        # Future: could use torch.compile here if needed
        pass

    return model, processor, threshold, metadata


# Lazy-loaded default model (loaded on first use, not at import time).
# All three are populated together by _get_default_model() on first call.
_default_model = None       # DISCO instance once loaded
_default_processor = None   # CLIPProcessor once loaded
_default_threshold = None   # float decision threshold once loaded


def _get_default_model():
    """Return the shared (model, processor, threshold) triple, loading it
    from disk on the first call and reusing it afterwards."""
    global _default_model, _default_processor, _default_threshold
    if _default_model is None:
        # First use: populate the module-level cache (metadata discarded).
        loaded = load_model()
        _default_model, _default_processor, _default_threshold = loaded[:3]
    return _default_model, _default_processor, _default_threshold


def run_DISCO(
    image_path: str,
    model: Optional[DISCO] = None,
    processor: Optional[CLIPProcessor] = None,
    threshold: Optional[float] = None
) -> float:
    """
    Run DISCO inference on a single image.

    Args:
        image_path: Path to image file
        model: DISCO model (uses lazily-loaded default if None)
        processor: CLIPProcessor (uses lazily-loaded default if None)
        threshold: Classification threshold (uses model default if None).
            NOTE: the threshold is not applied here — this function returns
            the raw probability and leaves thresholding to the caller.

    Returns:
        Probability of SUGGESTIVE class (0-1)
    """
    # Fill in missing arguments from the lazily-loaded defaults.
    if model is None or processor is None:
        default_model, default_processor, default_threshold = _get_default_model()
        if model is None:
            model = default_model
        if processor is None:
            processor = default_processor
        # BUGFIX: use `is None`, not `threshold or default_threshold` —
        # a caller-supplied threshold of 0.0 is falsy and was silently
        # replaced by the default.
        if threshold is None:
            threshold = default_threshold

    # Prefer the device stashed by load_model(); fall back to wherever the
    # model's parameters actually live.
    device = getattr(model, '_device', next(model.parameters()).device)

    # Load and preprocess image (convert to RGB so grayscale/RGBA inputs work).
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Run inference; index [0, 1] picks the SUGGESTIVE-class probability
    # for the single image in the batch.
    # NOTE(review): assumes the model's forward returns raw logits of shape
    # (batch, 2) — confirm against src.model.DISCO.
    with torch.no_grad():
        logits = model(**inputs)
        proba = torch.softmax(logits, dim=-1)[0, 1].item()

    return proba