"""Gradio app comparing an AI-vs-human image classifier with the TruFor forgery detector."""

import os
from typing import Dict, Optional, Tuple

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import GatedRepoError, HfHubHTTPError
from PIL import Image
from timm import create_model
from torchvision import transforms

from trufor_runner import TruForEngine, TruForResult, TruForUnavailableError

# Side length of the square crop fed to the EfficientNet-B4 detector.
IMG_SIZE = 380
# Output-logit index -> human-readable label.
LABEL_MAPPING = {0: "human", 1: "ai"}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set by the loader below. If any loading step fails, `transform` and `model`
# are reset to None and MODEL_STATUS explains why, so the UI can still start.
transform: Optional[transforms.Compose]
model: Optional[torch.nn.Module]
MODEL_STATUS: str

try:
    # HF_TOKEN is optional; it is only needed when the repository is gated.
    token = os.getenv("HF_TOKEN")
    model_path = hf_hub_download(
        repo_id="Dafilab/ai-image-detector",
        filename="pytorch_model.pth",
        token=token,
    )
    transform = transforms.Compose([
        # Resize the shorter edge slightly above IMG_SIZE, then center-crop to
        # IMG_SIZE x IMG_SIZE.
        transforms.Resize(IMG_SIZE + 20),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        # ImageNet channel mean/std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    model = create_model("efficientnet_b4", pretrained=False, num_classes=len(LABEL_MAPPING))
    # The checkpoint is downloaded from a remote repository. weights_only=True
    # restricts unpickling to tensors and plain containers, so a tampered file
    # cannot execute arbitrary code. A state dict loads fine under this mode.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    MODEL_STATUS = "AI detector ready."
except GatedRepoError:
    # Must come before HfHubHTTPError: GatedRepoError is one of its subclasses.
    transform = None
    model = None
    MODEL_STATUS = (
        "AI detector requires approved Hugging Face access. Configure HF_TOKEN with a permitted token."
    )
except (HfHubHTTPError, OSError) as exc:
    transform = None
    model = None
    MODEL_STATUS = f"AI detector unavailable: {exc}"
except Exception as exc:  # pragma: no cover - surface loading issues early.
    transform = None
    model = None
    MODEL_STATUS = f"AI detector failed to initialize: {exc}"

# Initial text of the "AI vs Human" tab before any image is analyzed.
AI_INTRO_SUMMARY = MODEL_STATUS if model is None else "Upload an image to view the prediction."
# TruFor is constructed on the CPU regardless of DEVICE. Startup must never
# crash because of TruFor: any failure is reported through TRUFOR_STATUS,
# matching how the AI detector loader degrades.
try:
    TRUFOR_ENGINE: Optional[TruForEngine] = TruForEngine(device="cpu")
    TRUFOR_STATUS = TRUFOR_ENGINE.status_message
except TruForUnavailableError as exc:
    TRUFOR_ENGINE = None
    TRUFOR_STATUS = str(exc)
except Exception as exc:  # pragma: no cover - surface loading issues early.
    TRUFOR_ENGINE = None
    TRUFOR_STATUS = f"TruFor failed to initialize: {exc}"

# Tamper score at or above which an image is reported as "Altered".
TRUFOR_ALTERED_THRESHOLD = 0.5

# Initial text of the TruFor tab: the failure reason if TruFor is unavailable.
TRUFOR_INTRO_SUMMARY = (
    TRUFOR_STATUS if TRUFOR_ENGINE is None else "Upload an image to run TruFor."
)


def analyze_ai_vs_human(image: Image.Image) -> Tuple[Dict[str, float], str]:
    """Classify *image* as human- or AI-made with the EfficientNet detector.

    Args:
        image: Image from the Gradio input, or None when the input is cleared.

    Returns:
        A ``{label: probability}`` dict for ``gr.Label``, ordered from most to
        least likely, plus a Markdown summary. If the detector failed to load
        or no image was given, every probability is 0.0 and the summary says why.
    """
    empty_scores = {label: 0.0 for label in LABEL_MAPPING.values()}
    if model is None or transform is None:
        return empty_scores, MODEL_STATUS
    if image is None:
        return empty_scores, "No image provided."

    inputs = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(inputs)
    # One .tolist() copies all class probabilities to the host at once.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()

    ordered_scores = sorted(
        ((label, probabilities[idx]) for idx, label in LABEL_MAPPING.items()),
        key=lambda item: item[1],
        reverse=True,
    )
    (top_label, top_score), (second_label, second_score) = ordered_scores[:2]
    # Two trailing spaces before "\n" form a Markdown hard line break, so each
    # field renders on its own line. A bare newline would be joined into the
    # same paragraph.
    summary = (
        f"**Predicted Label:** {top_label}  \n"
        f"**Confidence:** {top_score:.2%}  \n"
        f"`{top_label}: {top_score:.2%} | {second_label}: {second_score:.2%}`"
    )
    return dict(ordered_scores), summary


def analyze_trufor(
    image: Image.Image,
    threshold: float = TRUFOR_ALTERED_THRESHOLD,
) -> Tuple[str, Optional[Image.Image]]:
    """Run TruFor tamper detection on *image*, or explain why it cannot run.

    Args:
        image: Image from the Gradio input, or None when the input is cleared.
        threshold: Tamper score at or above which the image is called "Altered".
            Assumes ``result.score`` lies in [0, 1] (TODO confirm against
            trufor_runner).

    Returns:
        A Markdown summary and TruFor's overlay image (None when no inference ran).
    """
    if TRUFOR_ENGINE is None:
        return TRUFOR_STATUS, None
    if image is None:
        return "Upload an image to run TruFor.", None
    try:
        result: TruForResult = TRUFOR_ENGINE.infer(image)
    except TruForUnavailableError as exc:
        return str(exc), None

    if result.score is None:
        return "TruFor returned no prediction for this image.", result.map_overlay

    is_altered = result.score >= threshold
    prediction = "Altered" if is_altered else "Not Altered"
    # Report confidence in the predicted class rather than the raw tamper score.
    confidence = result.score if is_altered else 1.0 - result.score
    summary = f"**Prediction:** {prediction}  \n**Confidence:** {confidence:.2%}"
    return summary, result.map_overlay


def analyze_image(image: Image.Image) -> Tuple[Dict[str, float], str, str, Optional[Image.Image]]:
    """Run both analyzers and return values in the order of ``output_components``."""
    ai_scores, ai_summary = analyze_ai_vs_human(image)
    trufor_summary, tamper_overlay = analyze_trufor(image)
    return ai_scores, ai_summary, trufor_summary, tamper_overlay


with gr.Blocks() as demo:
    gr.Markdown(
        """# Image Authenticity Workbench\nUpload an image to compare the AI-vs-human classifier with the TruFor forgery detector."""
    )
    # Hard line break ("  \n") keeps the two status lines on separate rows.
    status_box = gr.Markdown(f"`TruFor: {TRUFOR_STATUS}`  \n`AI Detector: {MODEL_STATUS}`")

    image_input = gr.Image(label="Input Image", type="pil")
    analyze_button = gr.Button("Analyze", variant="primary", size="sm")

    with gr.Tabs():
        with gr.TabItem("AI vs Human"):
            ai_label_output = gr.Label(label="Prediction", num_top_classes=2)
            ai_summary_output = gr.Markdown(AI_INTRO_SUMMARY)
        with gr.TabItem("TruFor Forgery Detection"):
            trufor_summary_output = gr.Markdown(TRUFOR_INTRO_SUMMARY)
            tamper_overlay_output = gr.Image(label="Altered Regions Map", type="pil", interactive=False)

    # Order must match the tuple returned by analyze_image.
    output_components = [
        ai_label_output,
        ai_summary_output,
        trufor_summary_output,
        tamper_overlay_output,
    ]

    # Analyze on explicit click and whenever the input image changes
    # (including when it is cleared, which yields the "no image" messages).
    analyze_button.click(
        fn=analyze_image,
        inputs=image_input,
        outputs=output_components,
    )
    image_input.change(
        fn=analyze_image,
        inputs=image_input,
        outputs=output_components,
    )


if __name__ == "__main__":
    demo.launch()