# Hugging Face Space status: Running
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import numpy as np | |
| import warnings | |
| from huggingface_hub import hf_hub_download | |
| import os | |
# Silence every Python warning for the whole process (all categories, all modules).
warnings.simplefilter("ignore")
# ============ MODEL DEFINITION ============
class BAILU(nn.Module):
    """Convolutional classifier that emits four independent logits per image.

    Seven unpadded convolutions, each followed by GELU, shrink the spatial
    size. Global average pooling then collapses it to a 256-channel vector,
    and a two-layer MLP maps that vector to 4 logits.

    Input is an ``(N, 3, H, W)`` batch. With this stack of unpadded convs,
    H and W must each be at least 204 pixels; anything smaller leaves the
    final 3x3 convolution with no valid output position.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride); padding is always 0.
        # Layers are appended in this exact order so the Sequential indices
        # (conv_blocks.0, .2, ..., .12, pool at .14) match saved checkpoints.
        conv_specs = [
            (3, 16, 4, 1),
            (16, 32, 4, 1),
            (32, 64, 4, 2),
            (64, 128, 4, 2),
            (128, 256, 4, 4),
            (256, 256, 4, 4),
            (256, 256, 3, 2),
        ]
        layers = []
        for in_ch, out_ch, kernel, step in conv_specs:
            layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=step, padding=0)
            )
            layers.append(nn.GELU())
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.conv_blocks = nn.Sequential(*layers)

        self.head = nn.Sequential(
            nn.Linear(256, 32),
            nn.GELU(),
            nn.Linear(32, 4),
        )

    def forward(self, x):
        pooled = self.conv_blocks(x)  # (N, 256, 1, 1) after adaptive pooling
        flat = pooled.view(pooled.shape[0], -1)  # (N, 256)
        return self.head(flat)
# ============ GLOBALS ============
# Detector names, one per model output logit (zipped with the probabilities
# in display order).
VAES = ['FLUX', 'FLUX2', 'SDXL', 'SD1.5']
# A sigmoid probability strictly above this value counts as an "AI" vote.
THRESHOLD: float = 0.5

# ============ HUGGINGFACE REPO CONFIG ============
HF_REPO_ID: str = "LoliRimuru/BAILU"  # Hub model repo holding the checkpoint
HF_MODEL_FILENAME: str = "model.pt"  # checkpoint file name inside that repo
# ============ LOAD MODEL ============
def load_model(checkpoint_path=None):
    """Build a BAILU model and load its pre-trained weights.

    By default, the checkpoint ``HF_MODEL_FILENAME`` is downloaded from the
    Hugging Face Hub model repo ``HF_REPO_ID`` into ``./checkpoints``. If
    ``checkpoint_path`` is given, that local file is loaded instead and no
    download is attempted.

    The checkpoint is read with ``torch.load(..., weights_only=True)``. It may
    be either a dict holding the weights under ``"model_state_dict"`` or a
    bare state dict.

    Args:
        checkpoint_path: Optional path to a local checkpoint file.

    Returns:
        ``(model, device)``. ``model`` is the BAILU instance in eval mode on
        ``device``, or ``None`` if the checkpoint could not be obtained or
        applied; the reason is printed. ``device`` is CUDA when available,
        otherwise CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BAILU().to(device)

    # Step 1: locate the checkpoint file (download unless a local path is given).
    if checkpoint_path is None:
        source = HF_REPO_ID
        try:
            print(f"π₯ Downloading model from HuggingFace: {HF_REPO_ID}")
            # local_dir_use_symlinks is deprecated and ignored by recent
            # huggingface_hub releases, so it is intentionally not passed.
            model_file = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=HF_MODEL_FILENAME,
                repo_type="model",
                local_dir="./checkpoints",
            )
        except Exception as e:  # UI boundary: report and let the caller fall back
            print(f"β Failed to download/load model from HuggingFace: {e}")
            print(" Check your internet connection and huggingface_hub installation.")
            return None, device
    else:
        source = checkpoint_path
        model_file = checkpoint_path

    # Step 2: deserialize and apply the weights. A failure here is not a
    # network problem: the file could not be read, or its tensors do not
    # match BAILU's architecture.
    try:
        checkpoint = torch.load(model_file, map_location=device, weights_only=True)
        # Accept both {"model_state_dict": ...} wrappers and bare state dicts.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    except Exception as e:  # UI boundary: report and let the caller fall back
        print(f"β Failed to load model weights from {source}: {e}")
        return None, device

    model.eval()
    if checkpoint_path is None:
        print(f"β Model loaded from HuggingFace: {HF_REPO_ID}")
    else:
        print(f"Model loaded from local checkpoint: {checkpoint_path}")
    return model, device
# ============ INFERENCE ============
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a single-image batch for BAILU.

    Non-RGB images are converted to RGB first. The image is then
    center-cropped to 512x512 (torchvision zero-pads any edge shorter than
    512). ``ToTensor`` turns it into a float tensor, with values in [0, 1]
    for 8-bit images. Finally a leading batch axis is added, giving shape
    ``(1, 3, 512, 512)``.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    to_model_input = transforms.Compose(
        [
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ]
    )
    chw = to_model_input(rgb_image)
    return chw[None, ...]
def predict_image(model, device, image: Image.Image):
    """Score one image and derive an overall verdict.

    Args:
        model: Network returning one logit per detector for a batch produced
            by ``preprocess_image``.
        device: Device the model lives on; the input batch is moved there.
        image: PIL image to classify.

    Returns:
        ``(probabilities, is_ai, confidence)``:

        - ``probabilities`` is a NumPy array of per-output sigmoid scores for
          the single image.
        - ``is_ai`` is true when any score exceeds ``THRESHOLD``.
        - ``confidence`` is the highest score when ``is_ai`` is true;
          otherwise it is one minus the lowest score.
    """
    with torch.no_grad():
        batch = preprocess_image(image).to(device)
        scores = torch.sigmoid(model(batch))
        probabilities = scores.cpu().numpy()[0]  # drop the batch axis

    is_ai = np.any(probabilities > THRESHOLD)
    if is_ai:
        confidence = np.max(probabilities)
    else:
        confidence = 1 - np.min(probabilities)
    return probabilities, is_ai, confidence
# ============ GRADIO INTERFACE ============
# NOTE(review): several emoji in the UI strings below (e.g. "π€", "π΄",
# "β οΈ") look like UTF-8 text that was mis-decoded; restore the intended
# characters once they are confirmed.
_DESCRIPTION = """
### Detect AI-generated images
BAILU analyzes artifacts to identify
images generated by popular diffusion models. The model checks for traces from:
**π¨ FLUX.1 | π FLUX.2 | πΌοΈ SDXL | π― Stable Diffusion 1.5**
**β οΈ IMPORTANT**: This is a research tool. Results should be verified by human experts
for critical decisions. The model may produce false positives/negatives.
"""

_CSS = """
.verdict-box {
font-size: 24px !important;
font-weight: bold !important;
text-align: center !important;
}
.confidence-box {
font-size: 20px !important;
font-weight: bold !important;
}
.results-table {
font-size: 16px !important;
}
.gradio-container {
max-width: 1000px !important;
margin: auto !important;
}
"""


def _format_results(probs):
    """Build per-detector table rows, sorted by AI probability (highest first).

    Args:
        probs: One sigmoid probability per entry of ``VAES``, in the same order.

    Returns:
        A list of ``[detector, ai_probability, prediction, confidence]`` rows
        of display strings.
    """
    high_risk = 0.7  # flagged probabilities above this get the strongest icon
    rows = []
    # Sort on the raw probabilities rather than re-parsing formatted "xx.xx%"
    # strings, which is lossy and breaks if the display format changes.
    ranked = sorted(zip(VAES, probs), key=lambda pair: pair[1], reverse=True)
    for vae_name, prob in ranked:
        flagged = prob > THRESHOLD
        prediction = "AI" if flagged else "Real"
        conf = prob if flagged else (1 - prob)
        # Icons follow THRESHOLD so they never contradict the prediction column.
        if flagged and prob > high_risk:
            status = "π¨"
        elif flagged:
            status = "β οΈ"
        else:
            status = "β "
        rows.append([
            f"{status} {vae_name}",
            f"{prob:.2%}",
            prediction,
            f"{conf:.1%}",
        ])
    return rows


def _build_error_interface():
    """Return a placeholder Interface shown when the model failed to load."""

    def error_demo(image):
        return "β MODEL NOT LOADED", 0.0, [["ERROR", "0%", "N/A", "0%"]]

    return gr.Interface(
        fn=error_demo,
        inputs=gr.Image(type="pil", label="Upload Image"),
        outputs=[
            gr.Textbox(label="Overall Verdict"),
            gr.Number(label="Confidence Score", precision=2),
            gr.Dataframe(
                headers=["Detector", "AI Probability", "Prediction", "Confidence"],
                label="Per-Model Analysis",
            ),
        ],
        title="BAILU AI Detection Demo",
        description="Model failed to load. Please check console for details.",
    )


def create_demo():
    """Create the Gradio interface for the BAILU detector.

    Loads the model once. If loading fails, a placeholder interface that
    reports the failure is returned instead.

    Returns:
        A ``gr.Interface`` taking a PIL image and producing three outputs:
        a verdict string, an overall confidence number, and a per-detector
        table.
    """
    model, device = load_model()
    if model is None:
        return _build_error_interface()

    def inference(image):
        """Classify one uploaded image and format the three UI outputs."""
        if image is None:
            return "π€ NO IMAGE UPLOADED", 0.0, []
        probs, is_ai, confidence = predict_image(model, device, image)
        verdict_icon = "π΄ AI GENERATED" if is_ai else "π’ HUMAN/REAL IMAGE"
        verdict_text = f"{verdict_icon}\n(Confidence: {confidence:.1%})"
        # Plain Python float for the Number component instead of a NumPy scalar.
        return verdict_text, float(confidence), _format_results(probs)

    return gr.Interface(
        fn=inference,
        inputs=gr.Image(
            type="pil",
            label="Upload Image (PNG, JPG, WEBP)",
            height=400,
        ),
        outputs=[
            gr.Textbox(
                label="π― Overall Verdict",
                lines=2,
                elem_classes="verdict-box",
            ),
            gr.Number(
                label="π Overall Confidence",
                precision=2,
                elem_classes="confidence-box",
            ),
            gr.Dataframe(
                headers=["π§ Detector", "AI Probability", "Prediction", "Confidence"],
                label="π Per-Model Breakdown",
                elem_classes="results-table",
                wrap=True,
            ),
        ],
        title="BAILU AI-Generated Image Detector",
        description=_DESCRIPTION,
        theme=gr.themes.Soft(),
        css=_CSS,
    )
# ============ MAIN ============
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0) at port 7860.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860)
    demo = create_demo()
    demo.launch(**launch_kwargs)