| import torch
|
| import torch.nn as nn
|
| import torchvision.models as models
|
| import torchvision.transforms as T
|
| from PIL import Image
|
| import os
|
| import io
|
| import base64
|
|
|
class EndpointHandler:
    """Custom handler for the Hugging Face Inference API.

    Loads a fine-tuned EfficientNet-B3 classifier from ``<path>/best_model.pth``
    and scores single images, returning one ``{"label", "score"}`` entry per
    class sorted by descending probability.
    """

    # File name of the checkpoint expected inside ``path``.
    CHECKPOINT_NAME = "best_model.pth"

    def __init__(self, path=""):
        """Build the network, load its weights and set up preprocessing.

        Args:
            path: Directory containing ``best_model.pth``. The default ``""``
                resolves the checkpoint relative to the current working
                directory.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist. This is
                checked before the network is constructed so a misconfigured
                deployment fails fast.
        """
        # Class names in output-index order. The classifier head is sized from
        # this list so the two cannot drift apart.
        self.labels = ["Natural", "Synthetic"]

        model_path = os.path.join(path, self.CHECKPOINT_NAME)
        # NOTE(review): the previous code had an `if not exists:` fallback that
        # recomputed this exact same path (a no-op). If a second search
        # location was intended, add it here.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model checkpoint not found: {model_path!r}")

        self.model = self._build_model(len(self.labels))

        # NOTE(review): torch.load unpickles the file. On PyTorch < 2.6 this can
        # execute arbitrary code unless weights_only=True is passed. Only ship
        # trusted checkpoints, and consider passing weights_only=True
        # explicitly once the minimum supported torch version allows it.
        checkpoint = torch.load(model_path, map_location="cpu")
        # Accept both a training checkpoint ({"model_state_dict": ..., ...})
        # and a bare state dict.
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            # ImageNet mean/std; presumably the same normalization used in
            # training -- confirm against the training pipeline.
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _build_model(num_classes):
        """Return an untrained EfficientNet-B3 with a ``num_classes``-way head."""
        # weights=None: no download; real weights come from the checkpoint.
        model = models.efficientnet_b3(weights=None)
        # Read the width from the stock head before replacing it.
        in_features = model.classifier[1].in_features
        # Keep Dropout at index 0 and Linear at index 1 so state-dict keys
        # ("classifier.1.weight", ...) match the saved checkpoint.
        model.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes),
        )
        return model

    @staticmethod
    def _extract_inputs(data):
        """Return the image payload without mutating the caller's request.

        A dict yields its ``"inputs"`` entry (or the dict itself when the key
        is absent, which is later rejected as an unsupported type). Any other
        value is assumed to be the image itself.
        """
        if isinstance(data, dict):
            return data.get("inputs", data)
        return data

    @staticmethod
    def _decode_base64(text):
        """Decode base64 image text, tolerating a ``data:<mime>;base64,`` prefix."""
        # b64decode silently discards non-alphabet characters, so a data-URI
        # header must be stripped explicitly or it corrupts the decoded bytes.
        if text.startswith("data:"):
            _, _, text = text.partition(",")
        return base64.b64decode(text)

    @classmethod
    def _load_image(cls, inputs):
        """Convert a supported payload into an RGB ``PIL.Image.Image``."""
        if isinstance(inputs, (bytes, bytearray)):
            raw = bytes(inputs)
        elif isinstance(inputs, str):
            raw = cls._decode_base64(inputs)
        elif isinstance(inputs, Image.Image):
            return inputs.convert("RGB")
        else:
            raise ValueError(f"Unsupported input type: {type(inputs)}")
        # convert() returns a new image, so closing the decoded source is safe.
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("RGB")

    def __call__(self, data):
        """Classify a single image.

        Args:
            data: Either the request payload dict, whose ``"inputs"`` entry
                holds the image, or the image itself. The image may be raw
                encoded bytes (``bytes``/``bytearray``), a base64 string
                (optionally a ``data:<mime>;base64,`` URI), or a
                ``PIL.Image.Image``. Other payload keys are ignored and the
                payload is not modified.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts, one per class,
            sorted by descending softmax probability.

        Raises:
            ValueError: If the image has an unsupported type, or the base64
                text is incorrectly padded (``binascii.Error``).
            PIL.UnidentifiedImageError: If the bytes are not a readable image.
        """
        image = self._load_image(self._extract_inputs(data))
        # The model expects a batch dimension; this handler scores one image.
        batch = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            probs = torch.softmax(self.model(batch), dim=1)[0]

        results = [
            {"label": label, "score": float(score)}
            for label, score in zip(self.labels, probs.tolist())
        ]
        return sorted(results, key=lambda item: item["score"], reverse=True)
|
|
|