FactSight / services /ai_image_service.py
DeepActionPotential's picture
Initial project upload via Python API for Flask Space
e0f2d0e verified
raw
history blame
2.75 kB
import torch
import timm
from torchvision import transforms
from PIL import Image
from typing import Any
from schemas.vision_schemas import AIImageDetector
class ENetAIImageDetector(AIImageDetector):
    """
    EfficientNet-B3 AI image detector that classifies whether an image
    is AI-generated or real using a trained PyTorch model.

    Attributes:
        model_path (str): Path to the trained model file.
        threshold (float): Minimum sigmoid probability at which an image is
            reported as AI-generated.
        device (str): Device inference runs on ('cuda' or 'cpu').
        model (Any): Loaded PyTorch model, placed on ``device`` in eval mode.
        transform (Any): Preprocessing pipeline applied to input images.
    """

    def __init__(
        self,
        model_path: str = "./models/efficientnet_b3_full_ai_image_classifier.pt",
        threshold: float = 0.001,
    ):
        """
        Initialize the detector and load the model.

        Args:
            model_path (str, optional): Path to the trained EfficientNet model.
                A path ending in ``.pt`` is loaded as a fully pickled model;
                any other path is loaded as a ``state_dict`` for a timm
                ``efficientnet_b3`` with a single output.
            threshold (float, optional): Decision threshold applied to the
                sigmoid probability in :meth:`detect`. Defaults to 0.001 to
                preserve the original behaviour.

        Raises:
            ValueError: If ``threshold`` is not within [0, 1].
        """
        # NOTE(review): 0.001 is far below the conventional 0.5 decision
        # boundary, so nearly every image will be flagged as AI-generated.
        # The default is kept for backward compatibility — confirm it against
        # the model's calibration and what callers expect.
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be in [0, 1], got {threshold!r}")
        self.model_path = model_path
        self.threshold = threshold
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self._load_model()
        self.transform = self._build_transform()

    def _load_model(self) -> Any:
        """
        Load the trained EfficientNet-B3 model onto ``self.device`` in eval mode.

        Returns:
            Any: The model, ready for inference.
        """
        # str() so that os.PathLike values (e.g. pathlib.Path) work as well.
        if str(self.model_path).endswith(".pt"):
            # SECURITY: weights_only=False unpickles arbitrary Python objects
            # and can execute code embedded in the file. Only load .pt files
            # from a trusted source.
            model = torch.load(self.model_path, map_location=self.device, weights_only=False)
        else:
            # Any other path is treated as a state_dict for a timm
            # EfficientNet-B3 with a single output logit; a plain state_dict
            # needs no arbitrary unpickling, so restrict loading to tensors.
            model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=1)
            model.load_state_dict(
                torch.load(self.model_path, map_location=self.device, weights_only=True)
            )
        model.to(self.device)
        model.eval()
        return model

    def _build_transform(self) -> Any:
        """Return the preprocessing pipeline for input images."""
        # Resize to 300x300 and normalize with the standard ImageNet mean/std;
        # presumably this matches the preprocessing used at training time —
        # TODO confirm.
        return transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """
        Convert a PIL image into a normalized batch of one on ``self.device``.

        Images in any mode other than RGB (e.g. RGBA, L, P) are converted to
        RGB first. The normalization step has exactly three per-channel
        statistics, so 1- or 4-channel tensors would otherwise fail there.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        return self.transform(image).unsqueeze(0).to(self.device)

    def detect(self, image: Image.Image) -> bool:
        """
        Detect whether a given PIL image is AI-generated.

        Args:
            image (PIL.Image.Image): The input image, in any PIL mode.

        Returns:
            bool: True if the model's sigmoid probability is at least
            ``self.threshold`` (reported as AI-generated), False otherwise.

        Raises:
            TypeError: If ``image`` is not a ``PIL.Image.Image``.
        """
        if not isinstance(image, Image.Image):
            raise TypeError("Input must be a PIL.Image.Image object.")
        img_tensor = self._preprocess_image(image)
        with torch.no_grad():
            outputs = self.model(img_tensor)
        # .item() requires the model to emit a single value for this batch of
        # one (num_classes=1). Assumes the positive class means
        # "AI-generated" — TODO confirm against the training labels.
        prob = torch.sigmoid(outputs).item()
        return prob >= self.threshold