# deepfake-detection / handler.py
# Last commit by morhadi (521c54a): "Enable huggingface inference API and
# Use this model button via handler.py" (file size 2.63 kB at that commit).
import base64
import io
import os
from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from torchvision import transforms
class EndpointHandler:
    """Inference handler for a TorchScript real-vs-fake image classifier.

    Written in the shape of a Hugging Face custom ``handler.py``: construct
    once with the model directory, then call the instance with each request
    payload to get ``fake``/``real`` scores.
    """

    # Timeout (seconds) passed to requests when the input is an image URL.
    # Without a timeout, requests can wait indefinitely on a stalled server.
    url_timeout = 30

    def __init__(self, path=""):
        """Load the TorchScript model and build the preprocessing pipeline.

        Args:
            path: Directory containing ``model.torchscript``. The default empty
                string resolves relative to the current working directory.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # os.path.join keeps path="" relative; the old f"{path}/..." form
        # turned it into "/model.torchscript" at the filesystem root.
        model_file = os.path.join(path, "model.torchscript")
        self.model = torch.jit.load(model_file, map_location=self.device)
        self.model.eval()
        # CLIP normalization statistics with a direct 224x224 resize (aspect
        # ratio is not preserved and there is no center crop, unlike the
        # reference CLIP pipeline). Kept as-is because the exported model
        # presumably expects exactly this transform -- TODO confirm.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one image as real or fake.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``;
                if that key is absent the whole payload is used as the input
                (and rejected unless it is a supported type). Supported inputs:
                a PIL image, encoded image bytes, an ``http://``/``https://``
                URL, a base64 string (optionally a ``data:...;base64,`` URI),
                or a local file path.

        Returns:
            ``[{"label": ..., "score": ...}, ...]`` with one entry for "fake"
            and one for "real", sorted by descending score; or
            ``[{"error": message}]`` if the input is unsupported or any step
            fails.
        """
        # .get rather than .pop so the caller's dict is not mutated.
        inputs = data.get("inputs", data)
        try:
            image = self._load_image(inputs)
            if image is None:
                return [{"error": "Invalid input format"}]
            return self._predict(image)
        except Exception as e:
            # Request boundary: report the failure in the response payload
            # rather than raising out of the handler.
            return [{"error": str(e)}]

    def _load_image(self, inputs: Any):
        """Decode ``inputs`` into an RGB PIL image; return None if unsupported."""
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, (bytes, bytearray)):
            return self._open_bytes(inputs)
        if not isinstance(inputs, str):
            return None

        if inputs.startswith(("http://", "https://")):
            response = requests.get(inputs, timeout=self.url_timeout)
            # Surface HTTP errors directly instead of handing an error page
            # to PIL and getting an opaque "cannot identify image" failure.
            response.raise_for_status()
            return self._open_bytes(response.content)

        payload = inputs
        if payload.startswith("data:") and "," in payload:
            # Strip a data-URI header such as "data:image/png;base64,".
            payload = payload.split(",", 1)[1]
        try:
            return self._open_bytes(base64.b64decode(payload))
        except (ValueError, OSError):
            # ValueError covers binascii.Error (bad base64); OSError covers
            # PIL's UnidentifiedImageError. Fall back to treating the string
            # as a local file path.
            # NOTE(review): this lets a request open any image file readable
            # by the server process; consider removing for untrusted callers.
            with Image.open(inputs) as img:
                return img.convert("RGB")

    @staticmethod
    def _open_bytes(raw):
        """Decode encoded image bytes into an RGB PIL image, closing the source."""
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("RGB")

    def _predict(self, image) -> List[Dict[str, Any]]:
        """Run the model on ``image`` and return fake/real scores, highest first."""
        tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
        # Assumes the model returns (1, num_classes) logits -- TODO confirm.
        probs = torch.nn.functional.softmax(logits, dim=1)[0]
        if probs.shape[0] < 2:
            # Softmax over a single output is always 1.0, which would
            # silently report every image as "real".
            raise ValueError(
                f"Model returned {probs.shape[0]} class score(s); "
                "expected at least 2 (real + fake)"
            )
        # Class 0 is treated as "real"; all remaining classes are summed into
        # "fake" (for a 2-class head this is simply probs[1]).
        real_prob = probs[0].item()
        fake_prob = probs[1:].sum().item()
        prediction = [
            {"label": "fake", "score": fake_prob},
            {"label": "real", "score": real_prob},
        ]
        # Sort by score for consistency; stable, so "fake" wins exact ties.
        return sorted(prediction, key=lambda x: x["score"], reverse=True)