Rohan Kumar Shah
added real and forgery detection model
b7c5baf
raw
history blame
2.32 kB
from PIL import Image
import torch
import numpy as np
from typing import IO
import cv2
from torchvision import transforms
# Import the globally loaded models instance
from model_loader import models
class ImagePreprocessor:
    """
    Handles preprocessing of images for the FFT CNN model.

    An uploaded image is decoded as grayscale, converted to a log-scaled,
    centre-shifted FFT magnitude spectrum, rescaled to 8-bit, pushed through
    the torchvision transform pipeline and returned as a batched tensor on
    the device exposed by the globally loaded ``models`` instance.
    """

    def __init__(self) -> None:
        """
        Initializes the preprocessor.

        Captures the target device from ``models`` and builds the transform
        pipeline applied to the 8-bit magnitude spectrum.
        """
        self.device = models.device
        # Define the image transformations, matching the training process
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def process(self, image_file: IO[bytes]) -> torch.Tensor:
        """
        Opens an image file, applies FFT, preprocesses it, and returns a tensor.

        Args:
            image_file (IO[bytes]): The binary image file object (e.g., from a
                file upload). It is read in full via ``read()``.

        Returns:
            torch.Tensor: The preprocessed magnitude spectrum with a leading
                batch dimension of size 1, moved to ``self.device``.

        Raises:
            ValueError: If reading or decoding the file fails (the underlying
                error is attached as ``__cause__``), or if the decoder yields
                no image.
        """
        try:
            # Read the image file into a flat uint8 numpy buffer
            image_np = np.frombuffer(image_file.read(), np.uint8)
            # Decode the image as grayscale
            img = cv2.imdecode(image_np, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            print(f"Error reading or decoding image: {e}")
            # Chain the original error so the root cause is kept as
            # ``__cause__`` instead of being reported only as incidental
            # context of the ValueError.
            raise ValueError("Invalid or corrupted image file.") from e

        # The decoder can signal failure by returning None instead of raising.
        if img is None:
            raise ValueError("Could not decode image. File may be empty or corrupted.")

        # 1. Apply Fast Fourier Transform (FFT)
        f = np.fft.fft2(img)
        # Move the zero-frequency component to the centre of the spectrum
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)

        # Normalize the magnitude spectrum to be in the range [0, 255]
        magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)
        magnitude_spectrum = np.uint8(magnitude_spectrum)

        # 2. Apply torchvision transforms
        image_tensor = self.transform(magnitude_spectrum)

        # Add a batch dimension and move to the correct device
        image_tensor = image_tensor.unsqueeze(0).to(self.device)
        return image_tensor
# Create a single shared instance of the preprocessor when this module is
# imported. Construction reads `models.device`, so `model_loader` must have
# imported successfully before this line runs.
preprocessor = ImagePreprocessor()