File size: 5,736 Bytes
aa30915
 
65d7391
 
 
 
aa30915
 
 
 
65d7391
 
 
aa30915
 
 
65d7391
aa30915
 
 
65d7391
 
aa30915
 
 
 
 
 
 
 
 
 
 
65d7391
aa30915
 
 
 
 
 
 
 
 
 
 
65d7391
aa30915
 
 
 
 
65d7391
 
 
 
 
 
 
 
 
 
aa30915
 
 
 
 
 
65d7391
 
 
 
aa30915
65d7391
 
aa30915
65d7391
aa30915
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65d7391
 
 
 
aa30915
65d7391
 
aa30915
65d7391
 
aa30915
65d7391
 
 
 
aa30915
65d7391
aa30915
 
 
 
 
 
 
 
 
 
 
65d7391
aa30915
65d7391
 
aa30915
65d7391
aa30915
 
65d7391
 
 
 
 
 
 
aa30915
65d7391
 
 
 
 
 
 
aa30915
65d7391
 
aa30915
65d7391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os

import gradio as gr
import torch
from PIL import Image
from typing import Dict, Optional, Tuple
from torchvision import transforms
from timm import create_model
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import GatedRepoError, HfHubHTTPError

from trufor_runner import TruForEngine, TruForResult, TruForUnavailableError

# Side length, in pixels, of the square crop fed to the classifier.
IMG_SIZE = 380
# Classifier output index -> human-readable label; also sizes the model head.
LABEL_MAPPING = dict(enumerate(("human", "ai")))
# Run the classifier on the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

transform: Optional[transforms.Compose]
model: Optional[torch.nn.Module]
MODEL_STATUS: str

# Load the AI-vs-human detector once at import time. On any failure `transform`
# and `model` stay None and MODEL_STATUS carries a readable reason, which
# analyze_ai_vs_human and the UI report instead of raising.
try:
    # Treat a set-but-blank HF_TOKEN like an unset one so an empty string is not
    # forwarded to the Hub as a credential. NOTE(review): with token=None the hub
    # client presumably falls back to its own cached login — confirm.
    token = os.getenv("HF_TOKEN") or None
    model_path = hf_hub_download(repo_id="Dafilab/ai-image-detector", filename="pytorch_model.pth", token=token)
    transform = transforms.Compose([
        # Resize slightly larger, then center-crop to an IMG_SIZE x IMG_SIZE input.
        transforms.Resize(IMG_SIZE + 20),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Architecture only; the weights come from the downloaded checkpoint below.
    model = create_model("efficientnet_b4", pretrained=False, num_classes=len(LABEL_MAPPING))
    # The checkpoint is fetched from a remote repository, so unpickling is
    # restricted to tensors and plain containers (weights_only=True): a tampered
    # file cannot run code on load. A plain state_dict loads unchanged, and this
    # is also the default in recent PyTorch releases.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=True))
    model.to(DEVICE)
    model.eval()  # switch dropout / batch-norm layers to inference behaviour
    MODEL_STATUS = "AI detector ready."
except GatedRepoError:
    # Handled before the broader HfHubHTTPError clause (presumably its base class
    # in huggingface_hub) so gated-repo failures get the specific access hint.
    transform = None
    model = None
    MODEL_STATUS = (
        "AI detector requires approved Hugging Face access. Configure HF_TOKEN with a permitted token."
    )
except (HfHubHTTPError, OSError) as exc:
    # Hub HTTP errors and OS-level (network / filesystem) errors during download or load.
    transform = None
    model = None
    MODEL_STATUS = f"AI detector unavailable: {exc}"
except Exception as exc:  # pragma: no cover - surface loading issues early.
    transform = None
    model = None
    MODEL_STATUS = f"AI detector failed to initialize: {exc}"

# Initial text for the AI tab: the failure reason, or an upload prompt when ready.
AI_INTRO_SUMMARY = MODEL_STATUS if model is None else "Upload an image to view the prediction."

# Optional TruFor tamper-detection backend, constructed on the CPU. If the engine
# cannot be built, TRUFOR_ENGINE stays None and TRUFOR_STATUS holds the reason so
# the status banner and analyze_trufor can surface it.
# NOTE(review): always CPU even when DEVICE is cuda — confirm this is intentional.
TRUFOR_ENGINE: Optional[TruForEngine]
try:
    TRUFOR_ENGINE = TruForEngine(device="cpu")
    TRUFOR_STATUS = TRUFOR_ENGINE.status_message
except TruForUnavailableError as unavailable:
    TRUFOR_ENGINE, TRUFOR_STATUS = None, str(unavailable)


def analyze_ai_vs_human(image: Image.Image) -> Tuple[Dict[str, float], str]:
    """Run the EfficientNet-based AI-vs-human detector on a single image.

    Args:
        image: The uploaded PIL image, or None when the input is empty.

    Returns:
        A ``(scores, summary)`` pair. ``scores`` maps every label in
        LABEL_MAPPING to its softmax probability, ordered from most to least
        likely (all 0.0 when no prediction is made). ``summary`` is Markdown
        describing the top prediction, or a status message explaining why no
        prediction was produced.
    """
    empty_scores = {label: 0.0 for label in LABEL_MAPPING.values()}

    # A load failure is reported before a missing image, so the user learns the
    # detector itself is unavailable.
    if model is None or transform is None:
        return empty_scores, MODEL_STATUS

    if image is None:
        return empty_scores, "No image provided."

    # Force 3 channels (palette / alpha / grayscale inputs) to match the
    # 3-channel Normalize step in `transform`; add a batch dimension of 1.
    inputs = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(inputs)

    # Row 0 of the single-image batch, copied to the CPU once rather than
    # synchronizing the device for every element read below.
    probabilities = torch.softmax(logits, dim=1)[0].cpu()
    ordered_scores = sorted(
        ((LABEL_MAPPING[idx], float(probabilities[idx])) for idx in LABEL_MAPPING),
        key=lambda item: item[1],
        reverse=True,
    )
    scores = dict(ordered_scores)

    top_label, top_score = ordered_scores[0]
    second_label, second_score = ordered_scores[1]
    # Two trailing spaces before "\n" make a Markdown hard line break; a bare
    # newline would be rendered as a space, merging the lines.
    summary = (
        f"**Predicted Label:** {top_label}  \n"
        f"**Confidence:** {top_score:.2%}  \n"
        f"`{top_label}: {top_score:.2%} | {second_label}: {second_score:.2%}`"
    )

    return scores, summary


def analyze_trufor(
    image: Image.Image, threshold: float = 0.5
) -> Tuple[str, Optional[Image.Image]]:
    """Run TruFor forgery detection when available, otherwise return diagnostics.

    Args:
        image: The uploaded PIL image, or None when the input is empty.
        threshold: Tamper score at or above which the image is reported as
            "Altered". Defaults to 0.5.

    Returns:
        A ``(summary, overlay)`` pair. ``summary`` is Markdown with the verdict
        and confidence, or a status/error message. ``overlay`` is the engine's
        ``map_overlay`` when inference ran, otherwise None.
    """
    if TRUFOR_ENGINE is None:
        return TRUFOR_STATUS, None

    if image is None:
        return "Upload an image to run TruFor.", None

    try:
        result: TruForResult = TRUFOR_ENGINE.infer(image)
    except TruForUnavailableError as exc:
        return str(exc), None

    # Inference ran but produced no global score: still show whatever map exists.
    if result.score is None:
        return "TruFor returned no prediction for this image.", result.map_overlay

    is_altered = result.score >= threshold
    prediction = "Altered" if is_altered else "Not Altered"
    # Confidence in the chosen verdict rather than the raw tamper score.
    # NOTE(review): assumes score is a probability in [0, 1] — confirm with TruForEngine.
    confidence = result.score if is_altered else 1.0 - result.score

    # Two trailing spaces before "\n" make a Markdown hard line break.
    summary = f"**Prediction:** {prediction}  \n**Confidence:** {confidence:.2%}"

    return summary, result.map_overlay


def analyze_image(image: Image.Image) -> Tuple[Dict[str, float], str, str, Optional[Image.Image]]:
    """Run both detectors on ``image`` and flatten their outputs for the UI.

    Returns a 4-tuple matching the Gradio output components: AI-vs-human
    scores, AI-vs-human summary, TruFor summary, TruFor overlay image.
    """
    # AI detector first, then TruFor; each yields a 2-tuple spliced into the result.
    return (*analyze_ai_vs_human(image), *analyze_trufor(image))


with gr.Blocks() as demo:
    gr.Markdown(
        """# Image Authenticity Workbench\nUpload an image to compare the AI-vs-human classifier with the TruFor forgery detector."""
    )

    # Load-time status of both detectors. Two trailing spaces before "\n" make a
    # Markdown hard line break so each status renders on its own line.
    status_box = gr.Markdown(f"`TruFor: {TRUFOR_STATUS}`  \n`AI Detector: {MODEL_STATUS}`")

    image_input = gr.Image(label="Input Image", type="pil")
    analyze_button = gr.Button("Analyze", variant="primary", size="sm")

    with gr.Tabs():
        with gr.TabItem("AI vs Human"):
            ai_label_output = gr.Label(label="Prediction", num_top_classes=2)
            ai_summary_output = gr.Markdown(AI_INTRO_SUMMARY)
        with gr.TabItem("TruFor Forgery Detection"):
            # Like the AI tab: prompt for an upload when the engine loaded, and
            # only ask for asset configuration when it is unavailable.
            trufor_summary_output = gr.Markdown(
                "Upload an image to run TruFor."
                if TRUFOR_ENGINE is not None
                else "Configure TruFor assets to enable tamper analysis."
            )
            tamper_overlay_output = gr.Image(label="Altered Regions Map", type="pil", interactive=False)

    # Order must match the 4-tuple returned by analyze_image.
    output_components = [
        ai_label_output,
        ai_summary_output,
        trufor_summary_output,
        tamper_overlay_output,
    ]

    # Analysis runs on an explicit button press and whenever the input image changes.
    analyze_button.click(
        fn=analyze_image,
        inputs=image_input,
        outputs=output_components,
    )

    image_input.change(
        fn=analyze_image,
        inputs=image_input,
        outputs=output_components,
    )


if __name__ == "__main__":
    demo.launch()