| import torch |
| import torch.nn.functional as F |
| import numpy as np |
|
|
| |
| from model_loader import models |
|
|
class Interferencer:
    """
    Performs inference using the FFT CNN model.

    Wraps ``models.fft_model`` and converts its raw outputs into a dictionary
    of the form ``{"classification": <label>, "confidence": <float>}``.
    """

    # Maps the argmax class index to its human-readable label. Defined once at
    # class level instead of being rebuilt on every ``predict`` call.
    LABEL_MAP = {0: "fake", 1: "real"}

    def __init__(self):
        """
        Initializes the interferencer with the loaded model.
        """
        # NOTE(review): the model is used exactly as model_loader provides it;
        # nothing here switches it to eval mode. Presumably the loader already
        # does — confirm, otherwise dropout / batch-norm layers would run in
        # training mode during inference.
        self.fft_model = models.fft_model

    @torch.no_grad()
    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Takes a preprocessed image tensor and returns the classification result.

        Gradient tracking is disabled for the whole call (``torch.no_grad``).

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor, passed
                unchanged to the model. Its shape/device must be whatever the
                model expects (not validated here — TODO confirm contract).

        Returns:
            dict: ``{"classification": str, "confidence": float}`` where the
            label is ``"fake"`` for class index 0, ``"real"`` for index 1, and
            ``"unknown"`` for any other index; ``confidence`` is the softmax
            probability of the predicted class.

        Raises:
            RuntimeError: From ``Tensor.item()`` when the model output holds
                more than one sample, i.e. only a batch size of 1 is supported.
        """
        outputs = self.fft_model(image_tensor)

        # Class probabilities along dim 1 — assumes the model returns
        # (batch, num_classes) logits.
        probabilities = F.softmax(outputs, dim=1)

        # Highest probability and its class index for each sample.
        confidence, predicted_idx = torch.max(probabilities, dim=1)

        # .item() requires a single-element tensor, hence batch size 1 only.
        prediction = predicted_idx.item()
        classification_label = self.LABEL_MAP.get(prediction, "unknown")

        return {
            "classification": classification_label,
            "confidence": confidence.item(),
        }
|
|
| |
# Shared module-level instance, constructed once when this module is imported.
interferencer: Interferencer = Interferencer()
|
|