File size: 3,560 Bytes
7b474fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import annotations

from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from hydra import initialize_config_dir
from hydra.core.global_hydra import GlobalHydra

from detectivesam_inference.checkpoint import (
    InferenceConfig,
    load_inference_config,
    resolve_checkpoint_path,
    resolve_repo_path,
)
from detectivesam_inference.dataset import PreparedSample
from detectivesam_inference.models.forgerylocalizer import ForgeryLocalizer


@dataclass(frozen=True)
class PredictionResult:
    """Immutable pair of arrays produced for one predicted sample."""

    # Per-pixel sigmoid output of the model, converted to a NumPy array.
    probability: np.ndarray
    # Binary mask obtained by thresholding ``probability``.
    pred_mask: np.ndarray


def get_repo_root() -> Path:
    """Return the directory two levels above this module's resolved file path."""
    module_file = Path(__file__).resolve()
    package_dir = module_file.parent
    return package_dir.parent


def select_device(device: str | None = None) -> torch.device:
    """Turn an optional device string into a ``torch.device``.

    Args:
        device: Explicit device string (e.g. ``"cpu"``). When ``None``, CUDA is
            chosen if ``torch.cuda.is_available()`` and CPU otherwise.

    Returns:
        The requested or auto-selected ``torch.device``.
    """
    if device is None:
        # No explicit choice: prefer the GPU whenever one is usable.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device)


def initialize_sam2_config(config_dir: str | Path) -> None:
    """Point the global Hydra instance at ``config_dir``.

    The directory used by the last successful call is remembered on the
    function object itself (``_current_dir``). If Hydra is already initialized
    and that remembered directory matches, nothing happens; if Hydra is
    initialized with anything else, it is cleared and re-initialized.

    Args:
        config_dir: Directory holding the config files; it is resolved to an
            absolute path before use.
    """
    resolved_dir = str(Path(config_dir).resolve())
    global_hydra = GlobalHydra.instance()

    if global_hydra.is_initialized():
        remembered_dir = getattr(initialize_sam2_config, "_current_dir", None)
        if remembered_dir == resolved_dir:
            return
        # Initialized for a different (or unknown) directory: start over.
        global_hydra.clear()

    initialize_config_dir(config_dir=resolved_dir, version_base=None)
    initialize_sam2_config._current_dir = resolved_dir


class DetectiveSAMRunner:
    """Load a checkpoint once and run ``ForgeryLocalizer`` inference with it.

    Construction resolves the checkpoint path, selects the device, loads the
    inference config for that checkpoint and builds the model in eval mode.
    ``predict_sample`` then maps one prepared sample to a probability map and
    a thresholded binary mask.
    """

    def __init__(
        self,
        checkpoint_path: str | Path | None = None,
        device: str | None = None,
    ) -> None:
        """Resolve paths, pick the device and load the model.

        Args:
            checkpoint_path: Checkpoint location, handed to
                ``resolve_checkpoint_path`` together with the repo root.
                ``None`` leaves the choice to that helper.
            device: Device string such as ``"cpu"`` or ``"cuda"``. ``None``
                selects CUDA when available and CPU otherwise.
        """
        self.repo_root = get_repo_root()
        self.checkpoint_path = resolve_checkpoint_path(checkpoint_path, self.repo_root)
        self.device = select_device(device)
        self.config = load_inference_config(self.checkpoint_path)
        self.model = self._load_model()

    def _load_model(self) -> ForgeryLocalizer:
        """Build the ``ForgeryLocalizer`` and load the checkpoint weights.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.

        Raises:
            RuntimeError: If the checkpoint's state dict does not match the
                model (``load_state_dict`` is called with its strict default).
        """
        sam_config_path = resolve_repo_path(self.config.sam_config_file, self.repo_root)
        sam_checkpoint_path = resolve_repo_path(self.config.sam_checkpoint, self.repo_root)
        # The model is given only the config file *name*, so Hydra has to be
        # initialized on the directory that contains it first.
        initialize_sam2_config(sam_config_path.parent)

        model = ForgeryLocalizer(
            sam_config=sam_config_path.name,
            sam_checkpoint=str(sam_checkpoint_path),
            prompt_dim=self.config.prompt_dim,
            downscale=self.config.downscale,
            dropout_rate=self.config.dropout_rate,
            max_streams=self.config.max_streams,
            device=str(self.device),
        ).to(self.device)

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints from a
        # trusted source. NOTE(review): switch to weights_only=True if the
        # checkpoint holds nothing but tensors/plain containers - confirm.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=False)
        # Accept both a wrapped checkpoint ({"model": state_dict, ...}) and a
        # bare state dict.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def autocast_context(self):
        """Return a CUDA autocast context on CUDA devices, else a no-op context."""
        if self.device.type == "cuda":
            return torch.amp.autocast(device_type="cuda")
        return nullcontext()

    def predict_sample(
        self,
        sample: PreparedSample,
        threshold: float = 0.5,
    ) -> PredictionResult:
        """Run the model on one sample and threshold its probability map.

        Args:
            sample: Object exposing ``orig`` (a tensor) and ``streams`` (an
                iterable of tensors); each gets a leading batch dimension of 1
                and is moved to ``self.device``.
            threshold: Pixels with probability strictly greater than this value
                are set to 1 in the mask.

        Returns:
            ``PredictionResult`` with a float32 ``probability`` array (all
            size-1 dimensions squeezed away) and a ``uint8`` ``pred_mask`` of
            the same shape.
        """
        orig = sample.orig.unsqueeze(0).to(self.device)
        streams = [stream.unsqueeze(0).to(self.device) for stream in sample.streams]

        with torch.inference_mode(), self.autocast_context():
            logits = self.model(orig, streams)

        # Under CUDA autocast the logits may come back in half precision.
        # Upcasting first keeps the result float32 on every device and avoids
        # handing NumPy a dtype it cannot represent (e.g. bfloat16).
        probability = torch.sigmoid(logits.float()).squeeze().detach().cpu().numpy()
        pred_mask = (probability > threshold).astype("uint8")
        return PredictionResult(probability=probability, pred_mask=pred_mask)