from typing import Dict, List, Any, Optional
import io
import base64

import requests
import torch
from PIL import Image
from torchvision import transforms


class EndpointHandler:
    """Inference-endpoint handler that scores an image as "real" vs "fake".

    The TorchScript model at ``{path}/model.torchscript`` is loaded once in
    ``__init__`` and reused for every request handled by ``__call__``.
    """

    # Seconds to wait for a remote image download. Without a timeout, a slow or
    # unresponsive host would block the request indefinitely.
    request_timeout = 10

    def __init__(self, path=""):
        """Load the TorchScript model and build the preprocessing pipeline.

        Args:
            path: Directory containing ``model.torchscript``.
        """
        # Load the model once at startup, on GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = torch.jit.load(f"{path}/model.torchscript", map_location=self.device)
        self.model.eval()
        # Standard CLIP preprocessing: 224x224 bicubic resize + CLIP mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the image in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` (or ``data`` itself when
                that key is absent) may be a PIL image, raw image bytes, an
                http(s) URL, a base64 string (optionally a ``data:`` URI), or a
                local file path.

        Returns:
            Two ``{"label", "score"}`` dicts ("fake" and "real") sorted by
            descending score, or a single-element ``[{"error": message}]``
            list when the input is unsupported or processing fails.
        """
        # .get rather than .pop so the caller's payload dict is not mutated.
        inputs = data.get("inputs", data)
        try:
            image = self._load_image(inputs)
            if image is None:
                return [{"error": "Invalid input format"}]

            tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(tensor)
            # NOTE(review): assumes the model returns logits shaped
            # (batch, num_classes) — TODO confirm against the exported model.
            probs = torch.nn.functional.softmax(outputs, dim=1)[0]
            return self._to_prediction(probs)
        except Exception as e:
            # Endpoint boundary: report the failure to the client instead of
            # crashing the serving process.
            return [{"error": str(e)}]

    def _load_image(self, inputs: Any) -> Optional[Image.Image]:
        """Decode request ``inputs`` into an RGB image; ``None`` if unsupported.

        Raises:
            requests.RequestException: If a URL cannot be fetched or returns an
                HTTP error status.
            OSError: If no supported interpretation yields a readable image.
        """
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, (bytes, bytearray)):
            return self._open_bytes(bytes(inputs))
        if not isinstance(inputs, str):
            return None

        if inputs.startswith(("http://", "https://")):
            response = requests.get(inputs, timeout=self.request_timeout)
            # Fail fast on 4xx/5xx rather than trying to decode an error page.
            response.raise_for_status()
            return self._open_bytes(response.content)

        payload = inputs
        if payload.startswith("data:"):
            # Strip a "data:<mime>;base64," prefix, keeping only the payload.
            payload = payload.partition(",")[2]
        try:
            return self._open_bytes(base64.b64decode(payload))
        except (ValueError, OSError):
            # binascii.Error (a ValueError) or PIL.UnidentifiedImageError (an
            # OSError): not a base64 image, so fall back to a filesystem path.
            with Image.open(inputs) as img:
                return img.convert("RGB")

    @staticmethod
    def _open_bytes(raw: bytes) -> Image.Image:
        """Decode encoded image bytes (PNG, JPEG, ...) into an RGB image."""
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("RGB")

    @staticmethod
    def _to_prediction(probs: torch.Tensor) -> List[Dict[str, Any]]:
        """Collapse class probabilities into sorted "real"/"fake" scores.

        Index 0 is treated as "real" and every remaining index as a "fake"
        class whose probabilities are summed. For a two-class model this is
        just ``probs[1]``.
        """
        real_prob = probs[0].item()
        fake_prob = probs[1:].sum().item()
        prediction = [
            {"label": "fake", "score": fake_prob},
            {"label": "real", "score": real_prob},
        ]
        # Sort by score for consistency; the stable sort keeps "fake" first on ties.
        return sorted(prediction, key=lambda x: x["score"], reverse=True)