# Harshasnade's picture
# Upload folder using huggingface_hub
# d33766a verified
from typing import Dict, List, Any
from io import BytesIO
from PIL import Image
import torch
import base64
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from safetensors.torch import load_file
# Import your model definition
from models import DeepfakeDetector
class EndpointHandler:
    """Inference endpoint wrapper around ``DeepfakeDetector``.

    Constructing the handler loads the model weights and builds the
    preprocessing pipeline. Calling the instance classifies a single image
    and returns a one-element list holding either a ``label``/``score`` dict
    or an ``error`` dict.
    """

    # File name of the safetensors checkpoint loaded at construction time.
    WEIGHTS_FILE = "best_model.safetensors"
    # Sigmoid outputs strictly greater than this value are labelled "FAKE".
    FAKE_THRESHOLD = 0.5

    def __init__(self, path="."):
        """Build the model, load its weights and set up preprocessing.

        Args:
            path: Directory expected to contain ``best_model.safetensors``.
                If loading from there fails, the same file name is retried
                relative to the current working directory.

        Raises:
            Exception: Whatever the fallback load raises when the weights
                cannot be loaded from either location.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DeepfakeDetector(pretrained=False)  # Architecture only

        try:
            self._load_weights(f"{path}/{self.WEIGHTS_FILE}")
        except Exception as exc:
            print(f"Error loading weights: {exc}")
            # Fallback: look for the checkpoint in the working directory.
            self._load_weights(self.WEIGHTS_FILE)

        self.model.to(self.device)
        self.model.eval()

        # Resize to 224x224, normalize with the given per-channel mean/std,
        # then convert the array to a tensor.
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

    def _load_weights(self, weights_path):
        """Load a safetensors checkpoint into ``self.model`` (non-strict).

        ``strict=False`` tolerates key mismatches, so any mismatch is printed
        rather than passing unnoticed with partially initialised weights.
        """
        state_dict = load_file(weights_path)
        result = self.model.load_state_dict(state_dict, strict=False)
        # NOTE(review): assumes the model follows torch's ``load_state_dict``
        # contract (result exposing missing_keys / unexpected_keys); getattr
        # keeps this harmless if it does not.
        missing = getattr(result, "missing_keys", None)
        unexpected = getattr(result, "unexpected_keys", None)
        if missing:
            print(f"Warning: keys missing from {weights_path}: {list(missing)}")
        if unexpected:
            print(f"Warning: unexpected keys in {weights_path}: {list(unexpected)}")

    @staticmethod
    def _decode_image(inputs):
        """Return ``inputs`` as an RGB PIL image, or ``None`` if unusable.

        Accepts a base64 string (optionally carrying a ``...base64,`` prefix),
        raw ``bytes``/``bytearray``, or an already-built ``PIL.Image.Image``.
        """
        try:
            if isinstance(inputs, str):
                if "base64," in inputs:
                    inputs = inputs.split("base64,")[1]
                raw = base64.b64decode(inputs)
                image = Image.open(BytesIO(raw))
            elif isinstance(inputs, (bytes, bytearray)):
                image = Image.open(BytesIO(inputs))
            elif isinstance(inputs, Image.Image):
                image = inputs
            else:
                return None
            # Image.open is lazy; convert() forces the actual decode, so it
            # stays inside the try to catch truncated or corrupt data too.
            return image.convert("RGB")
        except Exception as exc:
            # Request boundary: a bad payload must yield an error response,
            # not crash the endpoint. The reason is printed, not swallowed.
            print(f"Error decoding image: {exc}")
            return None

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Classify one image as ``"FAKE"`` or ``"REAL"``.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the image, or
                the image payload itself. The dict is not modified.

        Returns:
            ``[{"label": ..., "score": ...}]`` where ``score`` is the
            probability of the reported label, or
            ``[{"error": "Invalid input format"}]`` when no image could be
            decoded from the payload.
        """
        # Read without popping so the caller's payload is left intact.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data

        image = self._decode_image(inputs)
        if image is None:
            return [{"error": "Invalid input format"}]

        # The augmentation pipeline is fed a numpy array.
        image_np = np.array(image)
        augmented = self.transform(image=image_np)
        image_tensor = augmented["image"].unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(image_tensor)
            # NOTE(review): .item() requires a single-element output;
            # confirm against DeepfakeDetector's forward().
            prob = torch.sigmoid(output).item()

        is_fake = prob > self.FAKE_THRESHOLD
        label = "FAKE" if is_fake else "REAL"
        score = prob if is_fake else 1 - prob
        return [{"label": label, "score": score}]