# Jatin-tec
# Enhance AI detector initialization and improve TruFor result handling
# aa30915
import os
import gradio as gr
import torch
from PIL import Image
from typing import Dict, Optional, Tuple
from torchvision import transforms
from timm import create_model
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import GatedRepoError, HfHubHTTPError
from trufor_runner import TruForEngine, TruForResult, TruForUnavailableError
# Side length (pixels) of the square center crop fed to the EfficientNet detector.
IMG_SIZE = 380
# Detector output index -> human-readable label.
LABEL_MAPPING = {
    0: "human",
    1: "ai",
}
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Assigned by the model-loading try/except that follows; transform and model
# are set to None there when loading fails, and MODEL_STATUS explains why.
transform: Optional[transforms.Compose]
model: Optional[torch.nn.Module]
MODEL_STATUS: str
def _load_ai_detector() -> Tuple[transforms.Compose, torch.nn.Module]:
    """Download the EfficientNet-B4 detector weights and build its preprocessing.

    Returns:
        ``(transform, model)``: the torchvision pipeline that turns a PIL image
        into a normalized tensor, and the model moved to ``DEVICE`` in eval mode.

    Raises:
        Propagates whatever ``hf_hub_download``, ``create_model`` or
        ``load_state_dict`` raise; the module-level caller maps those to a
        ``MODEL_STATUS`` message instead of crashing.
    """
    # The token is passed straight through rather than stored in a module-level
    # global, so the secret does not linger in the module namespace.
    model_path = hf_hub_download(
        repo_id="Dafilab/ai-image-detector",
        filename="pytorch_model.pth",
        token=os.getenv("HF_TOKEN"),
    )
    preprocess = transforms.Compose([
        transforms.Resize(IMG_SIZE + 20),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    net = create_model("efficientnet_b4", pretrained=False, num_classes=len(LABEL_MAPPING))
    # NOTE(security): torch.load unpickles the downloaded checkpoint, so a
    # compromised repo could execute arbitrary code. Consider passing
    # weights_only=True once it is confirmed the file is a plain state dict.
    net.load_state_dict(torch.load(model_path, map_location=DEVICE))
    net.to(DEVICE)
    net.eval()
    return preprocess, net


# Stay None unless loading fully succeeds; the analysis function checks for None.
transform = None
model = None
try:
    transform, model = _load_ai_detector()
    MODEL_STATUS = "AI detector ready."
except GatedRepoError:
    MODEL_STATUS = (
        "AI detector requires approved Hugging Face access. Configure HF_TOKEN with a permitted token."
    )
except (HfHubHTTPError, OSError) as exc:
    MODEL_STATUS = f"AI detector unavailable: {exc}"
except Exception as exc:  # pragma: no cover - surface loading issues early.
    MODEL_STATUS = f"AI detector failed to initialize: {exc}"

# Initial text for the AI tab: the failure reason, or a prompt when the model is ready.
AI_INTRO_SUMMARY = MODEL_STATUS if model is None else "Upload an image to view the prediction."
# Either a ready engine or None; TRUFOR_STATUS always holds a user-facing message.
TRUFOR_ENGINE: Optional[TruForEngine]
TRUFOR_STATUS: str
try:
    # NOTE(review): TruFor is pinned to CPU even when DEVICE is CUDA; confirm
    # whether that is deliberate before switching it to DEVICE.
    TRUFOR_ENGINE = TruForEngine(device="cpu")
    TRUFOR_STATUS = TRUFOR_ENGINE.status_message
except TruForUnavailableError as exc:
    TRUFOR_ENGINE = None
    TRUFOR_STATUS = str(exc)
except Exception as exc:  # pragma: no cover - keep the app (and the AI detector) usable.
    # Mirror the AI detector's handling: report the failure in the UI status
    # instead of aborting module import for both detectors.
    TRUFOR_ENGINE = None
    TRUFOR_STATUS = f"TruFor failed to initialize: {exc}"
def analyze_ai_vs_human(image: Image.Image) -> Tuple[Dict[str, float], str]:
    """Classify *image* as AI-generated or human-made with the EfficientNet detector.

    Args:
        image: PIL image from the Gradio input, or None when the input is empty.

    Returns:
        A ``{label: probability}`` dict covering every label in ``LABEL_MAPPING``
        (ordered from most to least likely) and a Markdown summary. When the
        detector failed to load, or no image was provided, every probability is
        0.0 and the summary explains why.
    """
    empty_scores = {label: 0.0 for label in LABEL_MAPPING.values()}
    if model is None or transform is None:
        return empty_scores, MODEL_STATUS
    if image is None:
        return empty_scores, "No image provided."

    # Force 3 channels so RGBA / grayscale / palette images match the normalization.
    inputs = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(inputs)
        # Single-image batch: take the probability row for that image.
        probabilities = torch.softmax(logits, dim=1)[0]

    ordered_scores = sorted(
        ((LABEL_MAPPING[idx], float(probabilities[idx])) for idx in LABEL_MAPPING),
        key=lambda item: item[1],
        reverse=True,
    )
    (top_label, top_score), (second_label, second_score) = ordered_scores[:2]

    # "  \n" is a Markdown hard line break. The previous version ended the first
    # line with a backslash inside the string, which Python treats as a line
    # continuation, so label and confidence were glued onto one line.
    summary = (
        f"**Predicted Label:** {top_label}  \n"
        f"**Confidence:** {top_score:.2%}  \n"
        f"`{top_label}: {top_score:.2%} | {second_label}: {second_score:.2%}`"
    )
    return dict(ordered_scores), summary
def analyze_trufor(
    image: Image.Image, *, threshold: float = 0.5
) -> Tuple[str, Optional[Image.Image]]:
    """Run TruFor forgery detection on *image* and summarize the verdict.

    Args:
        image: PIL image from the Gradio input, or None when the input is empty.
        threshold: Tamper score at or above which the image is reported as
            "Altered". Keyword-only; defaults to 0.5.

    Returns:
        A Markdown summary and the engine's ``result.map_overlay``. The overlay
        is None when TruFor is unavailable, no image was given, or inference
        raised ``TruForUnavailableError``.
    """
    if TRUFOR_ENGINE is None:
        return TRUFOR_STATUS, None
    if image is None:
        return "Upload an image to run TruFor.", None
    try:
        result: TruForResult = TRUFOR_ENGINE.infer(image)
    except TruForUnavailableError as exc:
        return str(exc), None

    if result.score is None:
        return "TruFor returned no prediction for this image.", result.map_overlay

    # Decide altered vs not altered from the tamper score, then report confidence
    # in the chosen verdict. The complement assumes score lies in [0, 1];
    # TODO confirm against TruForEngine.
    is_altered = result.score >= threshold
    prediction = "Altered" if is_altered else "Not Altered"
    confidence = result.score if is_altered else 1.0 - result.score
    # "  \n" forces a Markdown hard line break; a bare "\n" is a soft break and
    # would render both fields on the same line.
    summary = f"**Prediction:** {prediction}  \n**Confidence:** {confidence:.2%}"
    return summary, result.map_overlay
def analyze_image(image: Image.Image) -> Tuple[Dict[str, float], str, str, Optional[Image.Image]]:
    """Run both detectors on *image* and return their outputs in UI order.

    Returns:
        ``(ai_scores, ai_summary, trufor_summary, tamper_overlay)``, which is the
        order of the ``output_components`` list wired to the Gradio events.
    """
    ai_outputs = analyze_ai_vs_human(image)
    trufor_outputs = analyze_trufor(image)
    return (*ai_outputs, *trufor_outputs)
with gr.Blocks() as demo:
    gr.Markdown(
        """# Image Authenticity Workbench\nUpload an image to compare the AI-vs-human classifier with the TruFor forgery detector."""
    )
    # Startup diagnostics for both detectors. "  \n" is a Markdown hard line
    # break, so the two statuses render on separate lines.
    status_box = gr.Markdown(f"`TruFor: {TRUFOR_STATUS}`  \n`AI Detector: {MODEL_STATUS}`")
    image_input = gr.Image(label="Input Image", type="pil")
    analyze_button = gr.Button("Analyze", variant="primary", size="sm")

    with gr.Tabs():
        with gr.TabItem("AI vs Human"):
            ai_label_output = gr.Label(label="Prediction", num_top_classes=len(LABEL_MAPPING))
            ai_summary_output = gr.Markdown(AI_INTRO_SUMMARY)
        with gr.TabItem("TruFor Forgery Detection"):
            # Mirror the AI tab: show why TruFor is unavailable, otherwise prompt
            # for an image. The old fixed text claimed TruFor needed configuring
            # even when the engine had loaded.
            trufor_summary_output = gr.Markdown(
                TRUFOR_STATUS if TRUFOR_ENGINE is None else "Upload an image to run TruFor."
            )
            tamper_overlay_output = gr.Image(label="Altered Regions Map", type="pil", interactive=False)

    # Order must match the tuple returned by analyze_image.
    output_components = [
        ai_label_output,
        ai_summary_output,
        trufor_summary_output,
        tamper_overlay_output,
    ]
    # Analysis runs on an explicit click and also whenever the input image changes.
    analyze_button.click(
        fn=analyze_image,
        inputs=image_input,
        outputs=output_components,
    )
    image_input.change(
        fn=analyze_image,
        inputs=image_input,
        outputs=output_components,
    )

if __name__ == "__main__":
    demo.launch()