# Source header (was raw Git commit metadata): commit 7392f98 by Shianndri-Zanniko — "fix deprecated issue".
import gradio as gr
import os
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
# Model definition (Same as before)
class CLIPImageClassifier(nn.Module):
    """Binary image classifier built on a pretrained CLIP vision backbone.

    CLIP's projected image embeddings are passed through a small MLP head
    that produces logits over ``num_classes`` classes (here: real vs.
    AI-generated).
    """

    def __init__(
        self,
        clip_model_name="openai/clip-vit-base-patch32",
        num_classes=2,
        freeze_backbone=False,
    ):
        """
        Args:
            clip_model_name: Hugging Face model id of the CLIP checkpoint.
            num_classes: Number of output classes of the classifier head.
            freeze_backbone: If True, freeze the CLIP vision encoder so only
                the classification head receives gradients.
        """
        # Zero-argument super() — the Py3 idiom; the two-argument
        # super(CLIPImageClassifier, self) form is legacy Py2 style.
        super().__init__()
        # Load pretrained CLIP model (only the vision tower is used below).
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        # Freeze CLIP backbone if specified.
        # NOTE(review): this freezes self.clip.vision_model only — the
        # visual_projection layer stays trainable; confirm that is intended.
        if freeze_backbone:
            for param in self.clip.vision_model.parameters():
                param.requires_grad = False
        # CLIP's projected image-embedding dimension (input size of the head).
        self.embedding_dim = self.clip.config.projection_dim
        # MLP classification head with dropout regularization.
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, pixel_values):
        """Return classification logits for a batch of preprocessed images.

        Args:
            pixel_values: Tensor of CLIP-preprocessed images — shape is
                determined by the CLIP processor (presumably
                (batch, 3, H, W); confirm against the checkpoint config).

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized logits.
        """
        # Get projected image features from CLIP, then classify them.
        image_outputs = self.clip.get_image_features(pixel_values=pixel_values)
        logits = self.classifier(image_outputs)
        return logits
# --- Module-level setup: device, processor, model, checkpoint loading ---

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the CLIP processor matching the backbone checkpoint.
print("Loading CLIP processor...")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Instantiate the classifier architecture (fine-tuned weights loaded below).
print("Initializing model architecture...")
model = CLIPImageClassifier(
    clip_model_name="openai/clip-vit-base-patch32", num_classes=2
).to(device)

# Path to the fine-tuned checkpoint (plain string — no trailing comma,
# which would silently make this a tuple).
weights_path = "best_clip_ai_detector.pth"
print(f"Loading weights from {weights_path}...")
try:
    # weights_only=True restricts torch.load to tensor/state-dict data,
    # avoiding arbitrary pickle code execution and the deprecation warning
    # for the old default (which torch 2.6 flipped to weights_only=True).
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Weights loaded successfully!")
except FileNotFoundError:
    print(
        f"ERROR: Could not find {weights_path}. Please ensure the file is in the same directory."
    )
except Exception as e:
    # Best-effort startup: report the failure but keep the app runnable
    # (predictions will use the untrained head).
    print(f"ERROR loading weights: {e}")

# Inference mode: disables the dropout layers in the classifier head.
model.eval()
def predict_image(image):
    """Classify an uploaded image as Real vs. AI-generated.

    Args:
        image: PIL image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        A dict mapping class label -> probability (for gr.Label), None when
        no image was supplied, or an error string if inference fails.
    """
    if image is None:
        return None
    try:
        # Preprocess into CLIP pixel tensors and move to the model's device.
        batch = processor(images=image, return_tensors="pt", padding=True)
        pixels = batch["pixel_values"].to(device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            logits = model(pixels)
            probabilities = torch.softmax(logits, dim=1)

        # Index 0 -> "Real", index 1 -> "AI Generated (Fake)".
        real_score, fake_score = probabilities[0].tolist()
        return {"Real": real_score, "AI Generated (Fake)": fake_score}
    except Exception as e:
        return f"Error during prediction: {str(e)}"
# Gradio UI — launched only when the file is executed as a script.
if __name__ == "__main__":
    # Build the I/O components first, then wire them into the Interface.
    image_input = gr.Image(type="pil", label="Upload Image")
    label_output = gr.Label(num_top_classes=2, label="Prediction")

    demo = gr.Interface(
        fn=predict_image,
        inputs=image_input,
        outputs=label_output,
        title="AI Image Detector",
        description="Upload an image to detect if it is Real or AI-Generated. Uses a fine-tuned CLIP model.",
        theme=gr.themes.Soft(),
        # Gradio 5.0+ renamed allow_flagging -> flagging_mode.
        flagging_mode="never",
    )
    demo.launch()