File size: 2,633 Bytes
521c54a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | from typing import Dict, List, Any
import torch
from PIL import Image
import io
import base64
import requests
from torchvision import transforms
class EndpointHandler:
    """Inference-endpoint handler that scores one image as "real" vs. "fake".

    The model is a TorchScript file loaded once at construction. Each call
    decodes an image, applies fixed preprocessing, runs the model, and
    returns softmax scores for the two labels.
    """

    # Seconds to wait for a remote image download before giving up.
    request_timeout = 10

    def __init__(self, path=""):
        """Load the TorchScript model and build the preprocessing pipeline.

        Args:
            path: Directory containing ``model.torchscript``. An empty string
                means the current working directory.
        """
        import os  # local import: os is not among this module's top-level imports

        # Preload all elements at inference
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # os.path.join keeps the default path="" relative instead of producing
        # the root path "/model.torchscript".
        model_path = os.path.join(path, "model.torchscript")
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()
        # CLIP normalization constants. Images are resized straight to 224x224,
        # so the aspect ratio is not preserved (no center crop).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one image and return real/fake scores.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``;
                if that key is absent, ``data`` itself is treated as the input.
                Accepted inputs: a PIL image, raw image bytes, an
                ``http://``/``https://`` URL, a base64 string (optionally a
                ``data:...;base64,`` URI), or a local file path.

        Returns:
            ``[{"label": ..., "score": ...}, {"label": ..., "score": ...}]``
            sorted by descending score, or ``[{"error": message}]`` when the
            input format is unsupported or any step fails.
        """
        # .get instead of .pop: leave the caller's payload unmodified.
        inputs = data.get("inputs", data)
        try:
            image = self._load_image(inputs)
            if image is None:
                return [{"error": "Invalid input format"}]
            tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(tensor)
            return self._build_prediction(outputs)
        except Exception as e:
            # Top-level endpoint boundary: report any failure as an error payload.
            return [{"error": str(e)}]

    def _load_image(self, inputs):
        """Decode ``inputs`` into an RGB PIL image; return None if the type is unsupported."""
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, (bytes, bytearray)):
            return self._open_rgb(io.BytesIO(inputs))
        if not isinstance(inputs, str):
            return None
        if inputs.startswith(("http://", "https://")):
            response = requests.get(inputs, timeout=self.request_timeout)
            # Surface the HTTP status instead of a decode error on an error page.
            response.raise_for_status()
            return self._open_rgb(io.BytesIO(response.content))
        try:
            raw = base64.b64decode(self._strip_data_uri(inputs))
            return self._open_rgb(io.BytesIO(raw))
        except (ValueError, OSError):
            # binascii.Error (bad base64) subclasses ValueError and PIL's
            # UnidentifiedImageError subclasses OSError. Fallback: treat the
            # string as a local image path.
            # NOTE(review): this lets remote callers make the server open
            # arbitrary local paths; consider dropping it for public endpoints.
            return self._open_rgb(inputs)

    @staticmethod
    def _open_rgb(source):
        """Open ``source`` (path or file-like) with PIL, return an RGB copy, and close the file."""
        with Image.open(source) as img:
            # convert() returns a new image, so it stays valid after img is closed.
            return img.convert("RGB")

    @staticmethod
    def _strip_data_uri(text):
        """Return the payload of a ``data:...;base64,`` URI, or ``text`` unchanged."""
        if text.startswith("data:"):
            header, sep, payload = text.partition(",")
            if sep and ";base64" in header:
                return payload
        return text

    @staticmethod
    def _build_prediction(outputs):
        """Convert raw model outputs into real/fake score dicts, highest score first.

        Softmax is taken over dim 1 and only the first row is used, so
        ``outputs`` is expected to be 2-D (batch, classes). Class 0 is treated
        as "real" and all remaining classes are pooled as "fake".
        NOTE(review): confirm this index mapping against the model's labels.
        """
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
        real_prob = probs[0].item()
        # With two classes this is exactly probs[1]; with several fake
        # classes their probabilities are summed.
        fake_prob = probs[1:].sum().item()
        prediction = [
            {"label": "fake", "score": fake_prob},
            {"label": "real", "score": real_prob},
        ]
        # Sort by score for consistency; sorted() is stable, so ties keep "fake" first.
        return sorted(prediction, key=lambda item: item["score"], reverse=True)
|