# NOTE: the lines below were page chrome from the Hugging Face Spaces listing
# ("Spaces: Sleeping") captured by the extraction; kept here as a comment so
# the module remains valid Python.
"""Gradio Space comparing an AI-vs-human image classifier with TruFor forgery detection."""

import os
from typing import Dict, Optional, Tuple

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import GatedRepoError, HfHubHTTPError
from PIL import Image
from timm import create_model
from torchvision import transforms

from trufor_runner import TruForEngine, TruForResult, TruForUnavailableError

# Square input resolution expected by the EfficientNet-B4 classifier.
IMG_SIZE = 380
# Class-index -> class-name mapping; its length also sizes the model head.
LABEL_MAPPING = {0: "human", 1: "ai"}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assigned by the loading block below; left as None when the detector
# could not be initialized so the handlers can degrade gracefully.
transform: Optional[transforms.Compose]
model: Optional[torch.nn.Module]
MODEL_STATUS: str
# Load the AI-vs-human detector once at import time. On any failure the app
# still starts: `model`/`transform` stay None and MODEL_STATUS explains why
# the detector is disabled. Setting the None defaults once up front removes
# the duplicated resets the original repeated in every except branch.
transform = None
model = None
try:
    # HF_TOKEN may be required because the weights repo can be gated.
    token = os.getenv("HF_TOKEN")
    model_path = hf_hub_download(
        repo_id="Dafilab/ai-image-detector",
        filename="pytorch_model.pth",
        token=token,
    )
    _transform = transforms.Compose([
        transforms.Resize(IMG_SIZE + 20),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    _model = create_model("efficientnet_b4", pretrained=False, num_classes=len(LABEL_MAPPING))
    # NOTE(review): torch.load unpickles arbitrary objects from the downloaded
    # checkpoint; pass weights_only=True once the minimum torch version allows.
    _model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    _model.to(DEVICE)
    _model.eval()
    # Publish only after every step succeeded, so a mid-way failure cannot
    # leave a half-initialized model visible to the request handlers.
    transform = _transform
    model = _model
    MODEL_STATUS = "AI detector ready."
except GatedRepoError:
    MODEL_STATUS = (
        "AI detector requires approved Hugging Face access. Configure HF_TOKEN with a permitted token."
    )
except (HfHubHTTPError, OSError) as exc:
    MODEL_STATUS = f"AI detector unavailable: {exc}"
except Exception as exc:  # pragma: no cover - surface loading issues early.
    MODEL_STATUS = f"AI detector failed to initialize: {exc}"

# Default text shown in the AI tab before the first prediction runs.
AI_INTRO_SUMMARY = MODEL_STATUS if model is None else "Upload an image to view the prediction."
# Initialize the TruFor engine eagerly; on failure keep the engine None and
# surface the reason through TRUFOR_STATUS instead of crashing the app.
TRUFOR_ENGINE: Optional[TruForEngine] = None
try:
    TRUFOR_ENGINE = TruForEngine(device="cpu")
    TRUFOR_STATUS = TRUFOR_ENGINE.status_message
except TruForUnavailableError as exc:
    TRUFOR_STATUS = str(exc)
def analyze_ai_vs_human(image: Image.Image) -> Tuple[Dict[str, float], str]:
    """Run the EfficientNet-based detector and return confidences with a readable summary.

    Args:
        image: Uploaded picture; converted to RGB before preprocessing.

    Returns:
        A ``{label: probability}`` dict for the Gradio Label component and a
        markdown summary string. When the model is unavailable or no image is
        given, all scores are 0.0 and the summary explains why.
    """
    empty_scores = {label: 0.0 for label in LABEL_MAPPING.values()}
    if model is None or transform is None:
        return empty_scores, MODEL_STATUS
    if image is None:
        return empty_scores, "No image provided."
    inputs = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(inputs)
        probabilities = torch.softmax(logits, dim=1)[0]
    # Sort (label, probability) pairs descending so index 0 is the prediction.
    ordered_scores = sorted(
        ((LABEL_MAPPING[idx], float(probabilities[idx])) for idx in LABEL_MAPPING),
        key=lambda item: item[1],
        reverse=True,
    )
    scores = dict(ordered_scores)
    top_label, top_score = ordered_scores[0]
    second_label, second_score = ordered_scores[1]
    # Fix: the original used a backslash line-continuation *inside* the
    # f-string, which produced no newline (and no markdown break) between the
    # label and confidence lines. Use an explicit hard break ("  \n") instead.
    summary = (
        f"**Predicted Label:** {top_label}  \n"
        f"**Confidence:** {top_score:.2%}\n"
        f"`{top_label}: {top_score:.2%} | {second_label}: {second_score:.2%}`"
    )
    return scores, summary
def analyze_trufor(image: Image.Image) -> Tuple[str, Optional[Image.Image]]:
    """Run TruFor inference when available, otherwise return diagnostics."""
    if TRUFOR_ENGINE is None:
        return TRUFOR_STATUS, None
    if image is None:
        return "Upload an image to run TruFor.", None
    try:
        result: TruForResult = TRUFOR_ENGINE.infer(image)
    except TruForUnavailableError as exc:
        return str(exc), None
    score = result.score
    if score is None:
        return "TruFor returned no prediction for this image.", result.map_overlay
    # Tamper-score cutoff separating "Altered" from "Not Altered" (tunable).
    altered_threshold = 0.5
    if score >= altered_threshold:
        prediction, confidence = "Altered", score
    else:
        prediction, confidence = "Not Altered", 1.0 - score
    summary = f"**Prediction:** {prediction}\n**Confidence:** {confidence:.2%}"
    return summary, result.map_overlay
def analyze_image(image: Image.Image) -> Tuple[Dict[str, float], str, str, Optional[Image.Image]]:
    """Fan the uploaded image out to both detectors and collect all four outputs."""
    scores, classifier_summary = analyze_ai_vs_human(image)
    forgery_summary, overlay = analyze_trufor(image)
    return scores, classifier_summary, forgery_summary, overlay
# --- UI definition -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """# Image Authenticity Workbench\nUpload an image to compare the AI-vs-human classifier with the TruFor forgery detector."""
    )
    # Surface both backends' readiness so a user understands disabled tabs.
    status_box = gr.Markdown(f"`TruFor: {TRUFOR_STATUS}`\n`AI Detector: {MODEL_STATUS}`")

    image_input = gr.Image(label="Input Image", type="pil")
    analyze_button = gr.Button("Analyze", variant="primary", size="sm")

    with gr.Tabs():
        with gr.TabItem("AI vs Human"):
            ai_label_output = gr.Label(label="Prediction", num_top_classes=2)
            ai_summary_output = gr.Markdown(AI_INTRO_SUMMARY)
        with gr.TabItem("TruFor Forgery Detection"):
            trufor_summary_output = gr.Markdown("Configure TruFor assets to enable tamper analysis.")
            tamper_overlay_output = gr.Image(label="Altered Regions Map", type="pil", interactive=False)

    output_components = [
        ai_label_output,
        ai_summary_output,
        trufor_summary_output,
        tamper_overlay_output,
    ]
    # Run the full analysis both on explicit click and whenever the image changes.
    for trigger in (analyze_button.click, image_input.change):
        trigger(fn=analyze_image, inputs=image_input, outputs=output_components)

if __name__ == "__main__":
    demo.launch()