|
|
from typing import Dict, List, Any |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import torch |
|
|
import base64 |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import albumentations as A |
|
|
from albumentations.pytorch import ToTensorV2 |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
from models import DeepfakeDetector |
|
|
|
|
|
class EndpointHandler:
    """Inference handler that labels a single image as ``"FAKE"`` or ``"REAL"``.

    The handler builds a ``DeepfakeDetector``, loads its weights from
    ``best_model.safetensors`` once at construction time and then serves one
    image per call.
    """

    def __init__(self, path="."):
        """Build the model, load its weights and set up preprocessing.

        Args:
            path: Directory expected to contain ``best_model.safetensors``.
                If loading from there fails for any reason, the same file
                name is tried relative to the current working directory.

        Raises:
            Exception: Whatever the fallback load raises when the weights
                cannot be loaded from either location.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = DeepfakeDetector(pretrained=False)

        try:
            self._load_weights(f"{path}/best_model.safetensors")
        except Exception as e:
            print(f"Error loading weights: {e}")
            # Second attempt is deliberately unguarded: with no weights at
            # all the handler must not come up.
            self._load_weights("best_model.safetensors")

        self.model.to(device)
        self.model.eval()

        # Resize to 224x224, normalize per channel, convert to a tensor.
        # NOTE(review): mean/std look like the usual ImageNet statistics;
        # confirm they match what the model was trained with.
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

    def _load_weights(self, weights_file):
        """Load a safetensors checkpoint into the model and report key mismatches."""
        state_dict = load_file(weights_file)
        result = self.model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates key mismatches, which would otherwise leave
        # layers at their initial values without any hint. Surface them.
        # getattr: presumably the model is a torch.nn.Module whose
        # load_state_dict returns missing/unexpected keys - TODO confirm.
        missing = getattr(result, "missing_keys", None)
        unexpected = getattr(result, "unexpected_keys", None)
        if missing:
            print(f"Warning: no weights loaded for keys: {list(missing)}")
        if unexpected:
            print(f"Warning: unexpected keys in checkpoint: {list(unexpected)}")

    @staticmethod
    def _decode_image(inputs):
        """Turn a supported payload into an RGB PIL image, or return ``None``.

        Supported payloads are a base64 string (optionally prefixed with a
        ``...base64,`` data-URI header), raw encoded image bytes, or an
        already opened PIL image. Anything else, and anything that fails to
        decode, yields ``None``.
        """
        try:
            if isinstance(inputs, str):
                if "base64," in inputs:
                    # Drop the data-URI header and keep only the payload.
                    inputs = inputs.split("base64,")[1]
                image = Image.open(BytesIO(base64.b64decode(inputs)))
            elif isinstance(inputs, (bytes, bytearray)):
                image = Image.open(BytesIO(inputs))
            elif isinstance(inputs, Image.Image):
                image = inputs
            else:
                return None
            # Image.open only identifies the file; pixel data is read on
            # first use, so convert inside the try to catch truncated or
            # corrupt images here as well.
            return image.convert("RGB")
        except Exception as e:
            # Request boundary with untrusted input: report and reject
            # instead of swallowing silently or crashing the request.
            print(f"Error decoding image: {e}")
            return None

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Classify one image.

        Args:
            data: Request payload. Normally a dict whose ``"inputs"`` entry
                holds the image (the entry is popped from the dict); a dict
                without that key, or a non-dict value, is itself treated as
                the image payload.

        Returns:
            ``[{"label": "FAKE" | "REAL", "score": float}]`` where ``score``
            is the probability of the reported label, or
            ``[{"error": "Invalid input format"}]`` when the payload cannot
            be turned into an image.
        """
        inputs = data.pop("inputs", data) if isinstance(data, dict) else data

        image = self._decode_image(inputs)
        if image is None:
            return [{"error": "Invalid input format"}]

        image_np = np.array(image)
        augmented = self.transform(image=image_np)
        # Add the batch dimension the model is called with.
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(image_tensor)
            # .item() requires a single-element output.
            # NOTE(review): presumably one logit for the "fake" class - confirm.
            prob = torch.sigmoid(output).item()

        label = "FAKE" if prob > 0.5 else "REAL"
        # Report the probability of the chosen label, so score >= 0.5.
        score = prob if prob > 0.5 else 1 - prob

        return [{"label": label, "score": score}]
|
|
|