from typing import Dict, List, Any, Optional
from io import BytesIO
import base64

from PIL import Image
import torch
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from safetensors.torch import load_file

# Import your model definition
from models import DeepfakeDetector


class EndpointHandler:
    """Inference handler that classifies one image as ``FAKE`` or ``REAL``.

    The handler builds a ``DeepfakeDetector``, loads its weights from
    ``best_model.safetensors``, and exposes ``__call__`` to run one image
    through the model.
    """

    def __init__(self, path="."):
        """Build the model, load its weights, and prepare preprocessing.

        Args:
            path: Directory expected to contain ``best_model.safetensors``.
                If loading from there fails, the file is looked up relative
                to the current working directory instead.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Architecture only; the weights come from the safetensors file.
        self.model = DeepfakeDetector(pretrained=False)

        # NOTE(review): strict=False silently ignores missing/unexpected
        # keys, so a mismatched checkpoint loads without any error.
        try:
            state_dict = load_file(f"{path}/best_model.safetensors")
            self.model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            print(f"Error loading weights: {e}")
            # Fallback: resolve the file against the working directory.
            # A failure here propagates so the handler fails at start-up.
            state_dict = load_file("best_model.safetensors")
            self.model.load_state_dict(state_dict, strict=False)

        self.model.to(device)
        self.model.eval()

        # Resize, normalize with the given mean/std, convert to a tensor.
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

    @staticmethod
    def _decode_image(inputs: Any) -> Optional[Image.Image]:
        """Turn a supported payload into an RGB PIL image, or ``None``.

        Accepted payloads are a PIL image, raw encoded image bytes, or a
        base64 string (optionally prefixed by a ``...base64,`` data-URI
        header). URLs are not fetched; any undecodable or unsupported
        payload yields ``None``.
        """
        try:
            if isinstance(inputs, Image.Image):
                image = inputs
            elif isinstance(inputs, str):
                if "base64," in inputs:
                    # Drop the data-URI header, keep the payload.
                    inputs = inputs.split("base64,")[1]
                image = Image.open(BytesIO(base64.b64decode(inputs)))
            elif isinstance(inputs, (bytes, bytearray)):
                image = Image.open(BytesIO(inputs))
            else:
                return None
            # Pillow opens lazily, so truncated data may only fail when the
            # pixels are read; converting here keeps that inside the guard.
            return image.convert("RGB")
        except (ValueError, OSError, Image.DecompressionBombError):
            # binascii.Error is a ValueError; UnidentifiedImageError is an
            # OSError. Bad input is reported to the client, not raised.
            return None

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Classify one image.

        Args:
            data: Either a mapping with the image under ``"inputs"``, or
                the image payload itself (PIL image, bytes, or base64
                string).

        Returns:
            A one-element list: ``{"label": "FAKE" | "REAL", "score": p}``
            where ``p`` is the probability of the reported label, or
            ``{"error": "Invalid input format"}`` when no image could be
            decoded.
        """
        # Read without mutating the caller's payload; a non-dict payload is
        # treated as the image input itself.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data

        image = self._decode_image(inputs)
        if image is None:
            return [{"error": "Invalid input format"}]

        # Augmentations expect a numpy array.
        image_np = np.array(image)
        augmented = self.transform(image=image_np)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(image_tensor)
            # .item() requires the model output to hold a single value.
            prob = torch.sigmoid(output).item()

        is_fake = prob > 0.5
        label = "FAKE" if is_fake else "REAL"
        # Report the probability of the chosen label.
        score = prob if is_fake else 1 - prob

        return [{"label": label, "score": score}]