animeaidetect / handler.py
hirawaru's picture
Add handler.py for Transformers pipeline support
cee52b6 verified
Raw
History Blame Contribute Delete
3.16 kB
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import os
import io
import base64
class EndpointHandler:
"""
Custom handler for Hugging Face Inference API.
This class loads the EfficientNet-B3 model and performs inference.
"""
def __init__(self, path=""):
# Load model architecture (EfficientNet-B3)
self.model = models.efficientnet_b3(weights=None)
self.model.classifier = nn.Sequential(
nn.Dropout(0.3),
nn.Linear(self.model.classifier[1].in_features, 2),
)
# Load weights
model_path = os.path.join(path, "best_model.pth")
if not os.path.exists(model_path):
# In some environments, path might be the directory containing the file
model_path = os.path.join(path, "best_model.pth")
checkpoint = torch.load(model_path, map_location="cpu")
# Handle both state_dict only and full checkpoint
state_dict = checkpoint.get("model_state_dict", checkpoint)
self.model.load_state_dict(state_dict)
self.model.eval()
# Define transforms (Standard ImageNet normalization for EfficientNet)
self.transform = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
self.labels = ["Natural", "Synthetic"]
def __call__(self, data):
"""
Args:
data (:obj:`dict`):
data contains the raw data and some parameters.
"inputs" should contain the image data.
Return:
A :obj:`list` | `dict`: will be serialized and returned
"""
inputs = data.pop("inputs", data)
# Handle image data (can be bytes, PIL, or base64 string)
if isinstance(inputs, bytes):
image = Image.open(io.BytesIO(inputs)).convert("RGB")
elif isinstance(inputs, str):
# Assume base64
image = Image.open(io.BytesIO(base64.b64decode(inputs))).convert("RGB")
elif isinstance(inputs, Image.Image):
image = inputs.convert("RGB")
else:
# Try to treat it as a dict or other format if needed
raise ValueError(f"Unsupported input type: {type(inputs)}")
# Preprocess
img_tensor = self.transform(image).unsqueeze(0)
# Inference
with torch.no_grad():
logits = self.model(img_tensor)
probs = torch.softmax(logits, dim=1)[0]
# Format results as a list of dicts (label, score)
results = []
for i, label in enumerate(self.labels):
results.append({
"label": label,
"score": float(probs[i])
})
# Sort by score descending (highest confidence first)
results.sort(key=lambda x: x["score"], reverse=True)
return results