import base64
import io
import os
from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from torchvision import transforms
|
|
class EndpointHandler:
    """Real-vs-fake image classifier wrapped as a callable inference handler.

    The interface (``__init__(path)`` plus ``__call__(data) -> list of dicts``)
    appears to follow the Hugging Face Inference Endpoints custom-handler
    convention — confirm against the deployment.
    """

    # Seconds to wait for a remote image download before giving up; without a
    # timeout ``requests.get`` can block the worker indefinitely.
    request_timeout: float = 10.0

    def __init__(self, path=""):
        """Load the TorchScript model from ``<path>/model.torchscript``.

        Args:
            path: Directory containing ``model.torchscript``. The default
                ``""`` means the current working directory. It does not mean
                the filesystem root, which is where the previous f-string
                path resolved to.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model_file = os.path.join(path, "model.torchscript")
        self.model = torch.jit.load(model_file, map_location=self.device)
        self.model.eval()

        # Resize to 224x224, then normalize per channel. The mean/std values
        # look like the OpenAI CLIP preprocessing statistics; presumably the
        # model was trained with them — confirm.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711)),
        ])

    @staticmethod
    def _strip_data_uri(payload: str) -> str:
        """Return the part after the first comma of a ``data:`` URI, else *payload* unchanged."""
        if payload.startswith("data:"):
            _, sep, body = payload.partition(",")
            if sep:
                return body
        return payload

    @staticmethod
    def _open_rgb(source) -> "Image.Image":
        """Open *source* (path or file-like object) and return an RGB copy.

        ``convert`` produces a new, fully loaded image, so the underlying
        file handle can be closed immediately by the ``with`` block.
        """
        with Image.open(source) as img:
            return img.convert("RGB")

    def _load_image(self, inputs):
        """Turn the request payload into an RGB PIL image.

        Accepts a PIL image, raw bytes, an http(s) URL, a base64 string
        (optionally a ``data:`` URI), or, as a last resort, a local file path.

        Returns:
            The RGB image, or ``None`` when *inputs* has an unsupported type.

        Raises:
            requests.RequestException: If a URL download fails or times out.
            OSError / ValueError: If the payload cannot be decoded as an image.
        """
        if isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        if isinstance(inputs, (bytes, bytearray)):
            return self._open_rgb(io.BytesIO(inputs))
        if not isinstance(inputs, str):
            return None

        if inputs.startswith(("http://", "https://")):
            response = requests.get(inputs, timeout=self.request_timeout)
            # Fail fast on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
            return self._open_rgb(io.BytesIO(response.content))

        try:
            raw = base64.b64decode(self._strip_data_uri(inputs))
            return self._open_rgb(io.BytesIO(raw))
        except (ValueError, OSError):
            # binascii.Error (bad base64) subclasses ValueError, and PIL's
            # UnidentifiedImageError subclasses OSError. Anything else propagates.
            # NOTE(security): this fallback opens a server-side path supplied by
            # the caller. If requests are untrusted, consider removing it.
            return self._open_rgb(inputs)

    @staticmethod
    def _scores_from_probs(probs: torch.Tensor) -> List[Dict[str, Any]]:
        """Collapse a 1-D class-probability vector into sorted fake/real scores.

        Index 0 is treated as "real" and every other index as a "fake" variant.
        This assumes the training label order — confirm. With a single output,
        the fake score is 0.0.
        The result is sorted by score, highest first. Ties keep "fake" first
        because ``sorted`` is stable.
        """
        real_prob = probs[0].item()
        fake_prob = probs[1:].sum().item()  # equals probs[1] for two classes
        prediction = [
            {"label": "fake", "score": fake_prob},
            {"label": "real", "score": real_prob},
        ]
        return sorted(prediction, key=lambda x: x["score"], reverse=True)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one image as real or fake.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``
                (popped from the dict). If that key is absent, the dict itself
                is used and is rejected as an invalid format.

        Returns:
            ``[{"label", "score"}, ...]`` sorted by descending score, or
            ``[{"error": message}]`` on any failure. Errors are reported in the
            response rather than raised, deliberately.
        """
        inputs = data.pop("inputs", data)

        try:
            image = self._load_image(inputs)
            if image is None:
                return [{"error": "Invalid input format"}]

            tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                logits = self.model(tensor)
                # Softmax over the class dimension; [0] takes the single batch row.
                probs = torch.nn.functional.softmax(logits, dim=1)[0]

            return self._scores_from_probs(probs)

        except Exception as e:
            return [{"error": str(e)}]
|
|