# Hugging Face Space: AI Image Detector (status at capture time: Sleeping)
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
class CLIPImageClassifier(nn.Module):
    """Binary image classifier built on a pretrained CLIP vision backbone.

    CLIP image embeddings are projected through a small MLP head that
    outputs logits for ``num_classes`` classes (here: real vs. AI-generated).
    """

    def __init__(
        self,
        clip_model_name="openai/clip-vit-base-patch32",
        num_classes=2,
        freeze_backbone=False,
    ):
        """
        Args:
            clip_model_name: Hugging Face hub id of the CLIP checkpoint.
            num_classes: Number of output classes for the head.
            freeze_backbone: If True, freeze the CLIP vision tower so only
                the classification head is trainable.
        """
        super().__init__()

        # Load the pretrained CLIP model (only the vision side is used
        # in forward()).
        self.clip = CLIPModel.from_pretrained(clip_model_name)

        # Optionally freeze the vision backbone.
        if freeze_backbone:
            for param in self.clip.vision_model.parameters():
                param.requires_grad = False

        # CLIP's projected image-embedding size (e.g. 512 for ViT-B/32).
        self.embedding_dim = self.clip.config.projection_dim

        # MLP classification head: embedding -> 512 -> 256 -> num_classes.
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, pixel_values):
        """Return raw class logits for a batch of preprocessed images.

        Args:
            pixel_values: Image tensor as produced by ``CLIPProcessor``.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized logits.
        """
        # Project the images into CLIP's embedding space...
        image_features = self.clip.get_image_features(pixel_values=pixel_values)
        # ...then classify the embeddings.
        return self.classifier(image_features)
# --- One-time setup: device, processor, model, trained weights ---

# Prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the CLIP preprocessor (resizing / normalization for the model).
print("Loading CLIP processor...")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Build the classifier architecture (weights are random until loaded below).
print("Initializing model architecture...")
model = CLIPImageClassifier(
    clip_model_name="openai/clip-vit-base-patch32", num_classes=2
).to(device)

weights_path = "best_clip_ai_detector.pth"
print(f"Loading weights from {weights_path}...")
try:
    # weights_only=True restricts torch.load to tensors/state dicts,
    # avoiding pickle's arbitrary-code-execution risk on untrusted files.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Weights loaded successfully!")
except FileNotFoundError:
    print(
        f"ERROR: Could not find {weights_path}. Please ensure the file is in the same directory."
    )
except Exception as e:
    # Best-effort boundary: the app still starts (with random weights) so
    # the UI can surface the problem instead of crashing at import time.
    print(f"ERROR loading weights: {e}")

# Inference mode: disables the dropout layers in the classification head.
model.eval()
def predict_image(image):
    """Classify an uploaded image as Real vs. AI-generated.

    Args:
        image: PIL image from the Gradio input (None if nothing uploaded).

    Returns:
        A dict mapping class labels to probabilities (for gr.Label),
        None when no image was provided, or an error string on failure.
    """
    if image is None:
        return None
    try:
        # Resize/normalize exactly as the CLIP backbone expects.
        batch = processor(images=image, return_tensors="pt", padding=True)
        tensors = batch["pixel_values"].to(device)

        # Forward pass with gradient tracking disabled.
        with torch.no_grad():
            probabilities = torch.softmax(model(tensors), dim=1)

        # Index 0 = real, index 1 = AI-generated.
        return {
            "Real": probabilities[0][0].item(),
            "AI Generated (Fake)": probabilities[0][1].item(),
        }
    except Exception as e:
        return f"Error during prediction: {str(e)}"
# Gradio UI: build and launch the demo only when run as a script.
if __name__ == "__main__":
    image_input = gr.Image(type="pil", label="Upload Image")
    label_output = gr.Label(num_top_classes=2, label="Prediction")
    demo = gr.Interface(
        fn=predict_image,
        inputs=image_input,
        outputs=label_output,
        title="AI Image Detector",
        description="Upload an image to detect if it is Real or AI-Generated. Uses a fine-tuned CLIP model.",
        theme=gr.themes.Soft(),
        flagging_mode="never",  # Gradio 5.x replacement for allow_flagging
    )
    demo.launch()