FactSight / services /ai_image_service.py
DeepActionPotential's picture
Initial project upload via Python API for Flask Space
e0f2d0e verified
import torch
import timm
from torchvision import transforms
from PIL import Image
from typing import Any
from schemas.vision_schemas import AIImageDetector
class ENetAIImageDetector(AIImageDetector):
    """
    EfficientNet-B3 AI image detector that classifies whether an image
    is AI-generated or real using a trained PyTorch model.

    Attributes:
        model_path (str): Path to the trained model file. A path ending in
            ``.pt`` is loaded as a fully pickled model; any other path is
            loaded as a ``state_dict`` into a fresh ``timm`` EfficientNet-B3
            with a single output logit.
        device (str): Device used for inference (``"cuda"`` if available,
            otherwise ``"cpu"``).
        model (Any): Loaded PyTorch model, moved to ``device`` and in eval mode.
        transform (Any): Preprocessing pipeline applied to every input image.
        threshold (float): Sigmoid-probability cut-off at or above which an
            image is reported as AI-generated.
    """

    # Decision cut-off previously hard-coded in ``detect``; kept as the default
    # so existing callers see identical results.
    # NOTE(review): 0.001 is an unusually low cut-off for a sigmoid output —
    # almost any logit above about -6.9 is reported as AI-generated. Confirm
    # this was calibrated deliberately for this model.
    DEFAULT_THRESHOLD: float = 0.001

    # Class-level fallback so instances created without ``__init__`` still work.
    threshold: float = DEFAULT_THRESHOLD

    def __init__(
        self,
        model_path: str = "./models/efficientnet_b3_full_ai_image_classifier.pt",
        threshold: float = DEFAULT_THRESHOLD,
    ):
        """
        Initialize the ENetAIImageDetector.

        Args:
            model_path (str, optional): Path to the trained EfficientNet model
                (``str`` or ``pathlib.Path``).
            threshold (float, optional): Probability cut-off in [0, 1]; images
                whose predicted probability is >= this value are reported as
                AI-generated. Defaults to ``DEFAULT_THRESHOLD`` (0.001).

        Raises:
            ValueError: If ``threshold`` lies outside [0, 1].
        """
        # Validate cheap arguments before the (slow) model load.
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be in [0, 1], got {threshold!r}")
        self.model_path = model_path
        self.threshold = threshold
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self._load_model()
        self.transform = self._build_transform()

    def _load_model(self) -> Any:
        """
        Load the trained EfficientNet-B3 model onto ``self.device`` in eval mode.

        A path ending in ``.pt`` is treated as a fully pickled model object;
        any other path is treated as a ``state_dict`` for a ``timm``
        EfficientNet-B3 with ``num_classes=1``.
        """
        # str() so pathlib.Path values work too (Path has no .endswith).
        if str(self.model_path).endswith(".pt"):
            # SECURITY: weights_only=False fully unpickles the file, which can
            # execute arbitrary code. Only point model_path at trusted files.
            model = torch.load(
                self.model_path, map_location=self.device, weights_only=False
            )
        else:
            model = timm.create_model(
                "efficientnet_b3", pretrained=False, num_classes=1
            )
            # NOTE: torch.load's default for weights_only depends on the torch
            # version (False before 2.6), so this may also unpickle; trusted
            # files only.
            model.load_state_dict(
                torch.load(self.model_path, map_location=self.device)
            )
        model.to(self.device)
        model.eval()
        return model

    def _build_transform(self) -> Any:
        """Return the preprocessing pipeline for input images."""
        return transforms.Compose([
            # Presumably the resolution the model was trained at — confirm.
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
            # Standard ImageNet channel statistics (3 channels: R, G, B).
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """
        Convert a PIL image to a normalized batch-of-one tensor on ``self.device``.

        Images that are not in RGB mode (for example RGBA PNGs, grayscale "L"
        or palette "P") are converted to RGB first, because the 3-channel
        ``Normalize`` step fails on any other channel count. Any alpha channel
        is dropped, not composited onto a background.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        return self.transform(image).unsqueeze(0).to(self.device)

    def detect(self, image: Image.Image) -> bool:
        """
        Detect whether a given PIL image is AI-generated.

        Args:
            image (PIL.Image.Image): The input image, in any PIL mode.

        Returns:
            bool: True if the sigmoid probability is >= ``self.threshold``
            (AI-generated), False otherwise (real).

        Raises:
            TypeError: If ``image`` is not a ``PIL.Image.Image``.
        """
        if not isinstance(image, Image.Image):
            raise TypeError("Input must be a PIL.Image.Image object.")
        img_tensor = self._preprocess_image(image)
        # inference_mode: like no_grad, but also skips autograd version tracking.
        with torch.inference_mode():
            logits = self.model(img_tensor)
        # .item() requires a single output element, i.e. a one-logit model
        # (num_classes=1) and a batch of one.
        prob = torch.sigmoid(logits).item()
        return prob >= self.threshold