File size: 4,560 Bytes
d581b00
 
 
 
dda3973
d581b00
3ce21dc
d581b00
3ce21dc
 
 
 
d581b00
 
3ce21dc
 
 
 
 
 
d581b00
 
 
 
 
3ce21dc
d581b00
3ce21dc
 
 
 
 
 
 
 
 
 
 
 
 
d581b00
 
 
 
3ce21dc
d581b00
3ce21dc
d581b00
3ce21dc
d581b00
 
3ce21dc
d581b00
 
 
 
 
 
 
 
3ce21dc
 
dda3973
 
3ce21dc
 
 
 
 
dda3973
 
3ce21dc
 
 
 
 
 
 
 
dda3973
 
 
3ce21dc
dda3973
3ce21dc
 
 
 
 
 
 
 
 
 
 
 
 
dda3973
 
 
 
 
 
 
 
3ce21dc
 
dda3973
 
 
 
 
3ce21dc
dda3973
 
 
 
 
 
 
d581b00
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from PIL import Image
from src.models.model import build_model
from src.data.transforms import val_transforms
from src.data.generator_loader import CLASS_NAMES

# Checkpoint locations used as the default ``model_path`` of the loaders below.
# They are relative paths, so they resolve against the current working directory.
MODEL_PATH: str = "saved_models/best_model.pth"  # binary real-vs-fake classifier
GENERATOR_MODEL_PATH: str = "saved_models/generator_model.pth"  # generator-type classifier


# --- Binary Classifier (Real vs Fake) ---

def load_model(model_path=MODEL_PATH):
    """
    Load the binary (real vs. AI-generated) classifier from a saved checkpoint.

    build_model(pretrained=False) is used because the weights come from our own
    checkpoint, not from a pretrained initialisation (the architecture is
    presumably ResNet18 — confirm in src.models.model.build_model).

    map_location="cpu" lets the checkpoint load on machines without a GPU even
    if it was saved from GPU tensors.

    Checkpoints saved from a model wrapped in torch.nn.DataParallel have every
    key prefixed with "module."; that leading prefix is stripped before loading.

    Args:
        model_path: Path to the saved state_dict. Defaults to MODEL_PATH.

    Returns:
        The model switched to eval mode (e.g. dropout off, batch-norm layers
        using their running statistics) for deterministic inference.
    """
    model = build_model(pretrained=False)

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location="cpu")

    # Strip only the *leading* "module." prefix added by DataParallel. Slicing is
    # used instead of str.replace so that "module." appearing inside a key is
    # left intact. The guard leaves an unprefixed state_dict object untouched.
    prefix = "module."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict(image_path: str, model=None) -> dict:
    """
    Predict whether an image is real or AI-generated.

    Flow:
    1. Load the image, convert it to RGB and apply val_transforms
       (presumably resize + normalize — see src.data.transforms).
    2. unsqueeze(0) adds a leading batch dimension of size 1.
    3. The forward pass returns a single raw logit.
    4. sigmoid maps the logit to a probability in [0, 1].
    5. prob >= 0.5 -> "AI-Generated", otherwise "Real".
    6. Confidence is the probability assigned to the predicted label,
       i.e. max(prob, 1 - prob), so it always lies in [0.5, 1.0].

    Args:
        image_path: Path to the image file to classify.
        model: An already-loaded binary classifier. When None, the default
            checkpoint is loaded via load_model() on every call, so pass a
            model explicitly when classifying many images.

    Returns:
        dict with keys:
            "label": "AI-Generated" or "Real".
            "confidence": probability of the predicted label, as a percentage
                rounded to 2 decimals.
            "raw_score": sigmoid output (probability of "AI-Generated"),
                rounded to 4 decimals.
    """
    if model is None:
        model = load_model()

    # Use a context manager so the underlying file handle is closed promptly.
    # convert() returns a new, fully loaded image that outlives the `with`.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    tensor = val_transforms(image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():  # no autograd bookkeeping needed for inference
        output = model(tensor)
        prob = torch.sigmoid(output).item()  # logit -> probability

    label = "AI-Generated" if prob >= 0.5 else "Real"
    # Probability of the predicted label (not its distance from 0.5).
    confidence = max(prob, 1 - prob)

    return {
        "label": label,
        "confidence": round(confidence * 100, 2),
        "raw_score": round(prob, 4),
    }


# --- Generator Type Classifier (4-class) ---

def load_generator_model(model_path=GENERATOR_MODEL_PATH):
    """
    Load the multi-class generator-type classifier from a saved checkpoint.

    The class list comes from generator_loader.CLASS_NAMES (documented as
    Real, GAN, Diffusion, Other). The classifier head is sized from it so the
    model and the label lookup in predict_generator cannot drift apart.

    Checkpoints saved from a model wrapped in torch.nn.DataParallel (multi-GPU
    training) have every key prefixed with "module."; that leading prefix is
    stripped before loading.

    Args:
        model_path: Path to the saved state_dict. Defaults to GENERATOR_MODEL_PATH.

    Returns:
        The model switched to eval mode for inference.
    """
    # Imported lazily, presumably to avoid pulling in training code (or an
    # import cycle) when only the binary model is used — TODO confirm.
    from src.models.train_generator import build_multiclass_model
    model = build_multiclass_model(num_classes=len(CLASS_NAMES), pretrained=False)

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location="cpu")

    # Strip only the *leading* "module." prefix. Slicing is used instead of
    # str.replace so that "module." occurring inside a key (e.g. a submodule
    # named "attn_module") is not removed, which would corrupt that key.
    prefix = "module."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict_generator(image_path: str, model=None) -> dict:
    """
    Predict which generator type produced an image (classes from CLASS_NAMES).

    Flow:
    1. Same preprocessing as the binary classifier (RGB + val_transforms,
       plus a batch dimension of size 1).
    2. The forward pass returns one raw logit per class.
    3. softmax over the class axis turns the logits into probabilities that
       sum to 1.0.
    4. argmax picks the most probable class.

    Unlike the binary classifier, which uses sigmoid on a single output, this
    uses softmax over all class outputs.

    Args:
        image_path: Path to the image file to classify.
        model: An already-loaded generator classifier. When None, the default
            checkpoint is loaded via load_generator_model() on every call.

    Returns:
        dict with keys:
            "generator_type": CLASS_NAMES entry of the predicted class.
            "confidence": probability of that class, as a percentage rounded
                to 2 decimals.
            "class_probabilities": {class name: percentage} for every model
                output (rounded to 2 decimals, so the values may not sum to
                exactly 100).
    """
    if model is None:
        model = load_generator_model()

    # Close the file handle promptly; convert() yields an independent image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    tensor = val_transforms(image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output = model(tensor)
        # Softmax over the class axis; [0] selects the single image in the batch.
        probs = torch.softmax(output, dim=1)[0]
        pred_class = probs.argmax().item()  # index of the most probable class
        confidence = probs[pred_class].item()

    # Convert once rather than calling .item() per class. Iterate over the
    # model's actual outputs instead of assuming exactly 4 classes.
    prob_values = probs.tolist()

    return {
        "generator_type": CLASS_NAMES[pred_class],
        "confidence": round(confidence * 100, 2),
        # Per-class probabilities (percent) for display in the UI.
        "class_probabilities": {
            CLASS_NAMES[i]: round(p * 100, 2)
            for i, p in enumerate(prob_values)
        },
    }


if __name__ == "__main__":
    import sys

    # Command-line entry point: classify one image as real vs. AI-generated.
    if len(sys.argv) < 2:
        # Report usage errors on stderr with a non-zero exit status so shell
        # scripts and CI can detect the failure.
        print("Usage: python src/models/inference.py <image_path>", file=sys.stderr)
        sys.exit(2)

    result = predict(sys.argv[1])
    print(f"Label: {result['label']}")
    print(f"Confidence: {result['confidence']}%")