"""Gradio demo for the BAILU AI-generated image detector.

The script downloads a BAILU checkpoint from the HuggingFace Hub, runs it on
an uploaded image and reports one sigmoid probability per entry in ``VAES``.
"""

import os
import warnings
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

warnings.filterwarnings("ignore")


# ============ MODEL DEFINITION ============
class BAILU(nn.Module):
    """Convolutional classifier that emits four logits per image.

    The demo pairs the four outputs, in order, with the names in ``VAES``.
    Attribute names (``conv_blocks``, ``head``) and layer order must stay as
    they are: they determine the ``state_dict`` keys of the checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0),
            nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=4, padding=0),
            nn.GELU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            nn.GELU(),
            # Collapse whatever spatial extent is left to 1x1 so the head
            # always receives a 256-dimensional vector.
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
        )

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for a batch of images."""
        features = self.conv_blocks(x)
        # (batch, 256, 1, 1) -> (batch, 256)
        features = features.flatten(1)
        return self.head(features)


# ============ GLOBALS ============
VAES = ['FLUX', 'FLUX2', 'SDXL', 'SD1.5']
# A detector output above this probability counts as "AI".
THRESHOLD = 0.5
# Above this probability a row gets the strongest warning icon.
ALERT_THRESHOLD = 0.7
# Side length of the centre crop fed to the model.
CROP_SIZE = 512

# ============ HUGGINGFACE REPO CONFIG ============
HF_REPO_ID = "LoliRimuru/BAILU"
HF_MODEL_FILENAME = "model.pt"

# Built once at import time instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.CenterCrop(CROP_SIZE),
    transforms.ToTensor(),
])

# ============ UI TEXT ============
DESCRIPTION_MD = (
    "### Detect AI-generated images\n"
    "\n"
    "BAILU analyzes artifacts to identify images generated by popular "
    "diffusion models.\n"
    "The model checks for traces from:\n"
    "\n"
    "**🎨 FLUX.1 | 🚀 FLUX.2 | 🖼️ SDXL | 🎯 Stable Diffusion 1.5**\n"
    "\n"
    "**⚠️ IMPORTANT**: This is a research tool. Results should be verified "
    "by human experts for critical decisions. The model may produce false "
    "positives/negatives.\n"
)

CUSTOM_CSS = """
.verdict-box {
    font-size: 24px !important;
    font-weight: bold !important;
    text-align: center !important;
}
.confidence-box {
    font-size: 20px !important;
    font-weight: bold !important;
}
.results-table {
    font-size: 16px !important;
}
.gradio-container {
    max-width: 1000px !important;
    margin: auto !important;
}
"""


# ============ LOAD MODEL ============
def load_model() -> Tuple[Optional[BAILU], torch.device]:
    """Download the BAILU checkpoint from the HuggingFace Hub and load it.

    Returns:
        ``(model, device)``. ``model`` is in eval mode on ``device``, or
        ``None`` when the download or the weight loading failed (the error
        is printed, not raised). ``device`` is CUDA when available, else CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BAILU().to(device)

    try:
        print(f"📥 Downloading model from HuggingFace: {HF_REPO_ID}")
        # NOTE(review): ``local_dir_use_symlinks`` is reported as deprecated
        # in recent huggingface_hub releases — confirm against the pinned
        # version before removing it.
        model_file = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=HF_MODEL_FILENAME,
            repo_type="model",
            local_dir="./checkpoints",
            local_dir_use_symlinks=False,
        )
        # weights_only=True restricts unpickling to tensors/primitive types.
        checkpoint = torch.load(model_file, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint["model_state_dict"])
    except Exception as e:
        # Deliberately broad: this is the boundary between "no model" and a
        # degraded UI, so every failure is reported and turned into None.
        print(f"❌ Failed to download/load model from HuggingFace: {e}")
        print(" Check your internet connection and huggingface_hub installation.")
        return None, device

    model.eval()
    print(f"✅ Model loaded from HuggingFace: {HF_REPO_ID}")
    return model, device


# ============ INFERENCE ============
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a single-item batch tensor for the model.

    The image is converted to RGB if needed, centre-cropped to
    ``CROP_SIZE`` x ``CROP_SIZE`` and turned into a tensor; a leading batch
    dimension is then added.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0)


def predict_image(model, device, image: Image.Image):
    """Run the model on one image and summarise the result.

    Args:
        model: The loaded network; called as ``model(tensor)``.
        device: Device the input tensor is moved to before the forward pass.
        image: The image to analyse.

    Returns:
        ``(probabilities, is_ai, confidence)`` where ``probabilities`` is a
        NumPy array of per-output sigmoid probabilities, ``is_ai`` is a
        built-in ``bool`` that is true when any probability exceeds
        ``THRESHOLD``, and ``confidence`` is a built-in ``float``.
    """
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)
        logits = model(image_tensor)
        probabilities = torch.sigmoid(logits).cpu().numpy()[0]

    # Convert NumPy scalars to built-ins so UI/serialisation layers never
    # have to deal with np.bool_ / np.float32.
    is_ai = bool(np.any(probabilities > THRESHOLD))
    if is_ai:
        confidence = float(np.max(probabilities))
    else:
        # NOTE(review): this uses the *lowest* probability, i.e. the most
        # optimistic "real" score; 1 - max would be the conservative choice.
        # Kept as-is — confirm intent.
        confidence = float(1 - np.min(probabilities))

    return probabilities, is_ai, confidence


# ============ GRADIO INTERFACE ============
def _status_icon(prob) -> str:
    """Return the warning icon matching a single detector probability."""
    if prob > ALERT_THRESHOLD:
        return "🚨"
    if prob > THRESHOLD:
        return "⚠️"
    return "✅"


def _build_result_rows(probs) -> List[List[str]]:
    """Build the per-detector table rows, highest AI probability first."""
    # Sort on the raw probabilities rather than re-parsing formatted text.
    ranked = sorted(zip(VAES, probs), key=lambda pair: pair[1], reverse=True)

    rows = []
    for vae_name, prob in ranked:
        flagged = prob > THRESHOLD
        conf = prob if flagged else (1 - prob)
        rows.append([
            f"{_status_icon(prob)} {vae_name}",
            f"{prob:.2%}",
            "AI" if flagged else "Real",
            f"{conf:.1%}",
        ])
    return rows


def _build_error_interface() -> gr.Interface:
    """Build the fallback UI shown when the model could not be loaded."""

    def error_demo(image):
        return "❌ MODEL NOT LOADED", 0.0, [["ERROR", "0%", "N/A", "0%"]]

    return gr.Interface(
        fn=error_demo,
        inputs=gr.Image(type="pil", label="Upload Image"),
        outputs=[
            gr.Textbox(label="Overall Verdict"),
            gr.Number(label="Confidence Score", precision=2),
            gr.Dataframe(
                headers=["Detector", "AI Probability", "Prediction", "Confidence"],
                label="Per-Model Analysis",
            ),
        ],
        title="BAILU AI Detection Demo",
        description="Model failed to load. Please check console for details.",
    )


def create_demo():
    """Create the Gradio interface.

    Loads the model first; when loading fails, a reduced interface that only
    reports the failure is returned instead of the detector UI.
    """
    model, device = load_model()

    if model is None:
        return _build_error_interface()

    def inference(image):
        """Gradio callback: returns (verdict text, confidence, table rows)."""
        if image is None:
            return "🤔 NO IMAGE UPLOADED", 0.0, []

        probs, is_ai, confidence = predict_image(model, device, image)

        verdict_icon = "🔴 AI GENERATED" if is_ai else "🟢 HUMAN/REAL IMAGE"
        verdict_text = f"{verdict_icon}\n(Confidence: {confidence:.1%})"

        return verdict_text, confidence, _build_result_rows(probs)

    return gr.Interface(
        fn=inference,
        inputs=gr.Image(
            type="pil",
            label="Upload Image (PNG, JPG, WEBP)",
            height=400,
        ),
        outputs=[
            gr.Textbox(
                label="🎯 Overall Verdict",
                lines=2,
                elem_classes="verdict-box",
            ),
            gr.Number(
                label="📊 Overall Confidence",
                precision=2,
                elem_classes="confidence-box",
            ),
            gr.Dataframe(
                headers=["🧠 Detector", "AI Probability", "Prediction", "Confidence"],
                label="🔍 Per-Model Breakdown",
                elem_classes="results-table",
                wrap=True,
            ),
        ],
        title="BAILU AI-Generated Image Detector",
        description=DESCRIPTION_MD,
        theme=gr.themes.Soft(),
        css=CUSTOM_CSS,
    )


# ============ MAIN ============
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )