import logging
from typing import IO

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from .model_loader import models
|
|
class ImagePreprocessor:
    """
    Handles preprocessing of images for the FFT CNN model.

    Each input is processed in four steps:
    1. Decode the upload as a grayscale image.
    2. Convert it to its centred, log-scaled 2-D FFT magnitude spectrum (uint8).
    3. Resize the spectrum to 224x224 and convert it to a tensor.
    4. Add a batch dimension and move the tensor to the model's device.
    """

    # Diagnostics go through logging instead of being printed to stdout.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """
        Initializes the preprocessor.

        Reads the target device from the shared ``models`` loader and builds
        the torchvision pipeline that is applied to the spectrum image.
        """
        self.device = models.device

        # uint8 ndarray -> PIL image -> resized to 224x224 -> float tensor.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def process(self, image_file: IO) -> torch.Tensor:
        """
        Opens an image file, applies FFT, preprocesses it, and returns a tensor.

        Args:
            image_file (IO): The image file object (e.g., from a file upload).
                Everything from its current position to the end is read as
                encoded image bytes.

        Returns:
            torch.Tensor: The preprocessed image as a tensor, ready for the
            model, with a leading batch dimension of 1 and placed on
            ``self.device``.

        Raises:
            ValueError: If the file cannot be read, is empty, or cannot be
                decoded as an image.
        """
        img = self._decode_grayscale(image_file)
        spectrum = self._magnitude_spectrum(img)
        image_tensor = self.transform(spectrum)
        # Add the batch dimension and move the tensor to the model's device.
        return image_tensor.unsqueeze(0).to(self.device)

    @classmethod
    def _decode_grayscale(cls, image_file: IO) -> np.ndarray:
        """Read ``image_file`` and decode it as a single-channel image.

        Raises:
            ValueError: If reading or decoding fails, or the data is empty.
        """
        try:
            buffer = np.frombuffer(image_file.read(), np.uint8)
            # Skip decoding an empty buffer so it is reported below as an
            # empty file rather than surfacing as a generic decode failure.
            img = cv2.imdecode(buffer, cv2.IMREAD_GRAYSCALE) if buffer.size else None
        except Exception as e:
            cls._logger.error("Error reading or decoding image: %s", e)
            # Chain the cause so the underlying error stays in the traceback.
            raise ValueError("Invalid or corrupted image file.") from e

        # cv2.imdecode signals undecodable data by returning None.
        if img is None:
            raise ValueError("Could not decode image. File may be empty or corrupted.")
        return img

    @staticmethod
    def _magnitude_spectrum(img: np.ndarray) -> np.ndarray:
        """Return the centred, log-scaled FFT magnitude of ``img`` as uint8 in [0, 255]."""
        # Shift the zero-frequency component to the centre of the spectrum.
        fshift = np.fft.fftshift(np.fft.fft2(img))
        # Log scaling compresses the spectrum's dynamic range; the +1 avoids log(0).
        spectrum = 20 * np.log(np.abs(fshift) + 1)
        # Min-max stretch to the full 8-bit range, then truncate to uint8.
        spectrum = cv2.normalize(spectrum, None, 0, 255, cv2.NORM_MINMAX)
        return spectrum.astype(np.uint8)
|
|
| |
# Shared module-level instance. It is built at import time, so importing this
# module reads ``models.device`` immediately.
preprocessor: ImagePreprocessor = ImagePreprocessor()
|
|