# deepfake-detector / pytorch_model.py
# (Residue of the Hugging Face file page: uploaded by yaya36095,
#  "Upload 8 files", commit eda917f (verified), 3.58 kB.)
"""
PyTorch model implementation for AI image detection
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
# Define the model architecture
class AIDetectorModel(nn.Module):
    """Three-stage CNN that maps an RGB image batch to two class logits.

    Each stage is a 3x3 conv (padding=1, so spatial size is preserved),
    a ReLU, and a 2x2 max-pool (which halves height and width). After three
    stages the spatial size is divided by 8. The classifier head is sized for
    224x224 input (224 / 8 = 28), i.e. a 256 x 28 x 28 feature map.

    Note: This is a placeholder architecture and should match your actual
    model architecture. Attribute names (conv1, conv2, conv3, fc1, fc2, ...)
    must stay as they are because they are the state_dict keys used when
    loading saved weights.
    """

    # Flattened per-sample feature count after the conv/pool stages,
    # valid only for 224x224 inputs.
    _FLAT_FEATURES = 256 * 28 * 28

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 512)
        self.fc2 = nn.Linear(512, 2)  # 2 classes: real or AI-generated
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Image batch of shape (N, 3, 224, 224).

        Returns:
            Logits tensor of shape (N, 2).

        Raises:
            RuntimeError: If the spatial size is not 224x224, because the
                flattened features then do not match fc1's input size.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample while keeping the batch dimension. A wrongly
        # sized input then fails loudly at fc1 instead of being silently
        # reshaped into extra "batch" rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
class PyTorchAIDetector:
    """Inference wrapper that loads AIDetectorModel weights and scores images."""

    def __init__(self, model_path='best_model_improved.pth'):
        """
        Initialize the PyTorch-based AI image detector.

        Args:
            model_path: Path to the trained state_dict file. A relative path is
                resolved against the directory containing this module. An
                absolute path is used unchanged, because os.path.join discards
                earlier components when given an absolute one.
        """
        # Prefer the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AIDetectorModel()
        model_path = os.path.join(os.path.dirname(__file__), model_path)
        # The checkpoint is a bare state_dict. weights_only=True limits
        # unpickling to tensors and plain containers, so a tampered file
        # cannot execute arbitrary code at load time.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()  # Disable dropout for inference.
        # Resize to the 224x224 input the model head expects, then apply
        # ImageNet mean/std normalization.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def analyze_image(self, image_path, threshold=0.5):
        """
        Analyze an image to detect whether it is AI-generated.

        Args:
            image_path: Path to the image file.
            threshold: Score above which (strictly greater) the image is
                flagged as AI-generated. Defaults to 0.5.

        Returns:
            Dictionary with keys "image_path", "overall_score" (softmax
            probability of class 1), "is_ai_generated" and "model_type".

        Raises:
            ValueError: If loading, preprocessing or inference fails. The
                original exception is chained as ``__cause__``.
        """
        try:
            # Use a context manager so the file handle is released as soon as
            # the RGB copy exists. convert() returns a new, fully loaded image.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)
            # Assumes class index 1 means "AI-generated"; confirm against the
            # label order used in training.
            ai_score = probabilities[0, 1].item()
        except Exception as e:
            raise ValueError(f"Failed to analyze image: {str(e)}") from e

        return {
            "image_path": image_path,
            "overall_score": float(ai_score),
            "is_ai_generated": bool(ai_score > threshold),
            "model_type": "pytorch",
        }