boluobobo's picture
Add custom handler for Inference API
baf60bd verified
"""
Custom handler for Hugging Face Inference API
返回聚合后的 AI vs Human 概率
"""
import torch
import json
import base64
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
from io import BytesIO
class EndpointHandler:
    """Inference handler that classifies an image's source and reports AI vs. human scores.

    The handler loads an image-classification model and its processor from
    ``path``. Each class index of the model is treated as a "source"; every
    source is flagged as real or not via ``source_is_real``. Sources missing
    from that mapping are treated as AI-generated.
    """

    # Source labels assumed to be real (non-AI) when source_meta.json cannot
    # be used. Labels absent from the resulting mapping count as AI sources.
    _DEFAULT_REAL_SOURCES = (
        "afhq", "celebahq", "coco", "ffhq",
        "imagenet", "landscape", "lsun", "metfaces",
    )

    def __init__(self, path=""):
        """Load the model, the processor and the source metadata.

        Args:
            path: Directory holding the model files and, optionally,
                ``source_meta.json`` with the keys ``source_names`` and
                ``source_is_real``.
        """
        # Function-scope import: kept local so this block does not depend on
        # the file-level import list.
        import os

        # Load model and processor
        self.model = AutoModelForImageClassification.from_pretrained(path)
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model.eval()

        # Load source metadata; fall back to the model config on any
        # missing/unreadable/malformed metadata file (best effort).
        try:
            meta_path = os.path.join(path, "source_meta.json")
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)
            source_names = meta["source_names"]
            source_is_real = meta["source_is_real"]
        except (OSError, ValueError, KeyError, TypeError):
            # id2label is a mapping keyed by class index, so walk the indices
            # in order to obtain labels aligned with the logits.
            id2label = self.model.config.id2label
            source_names = [id2label[i] for i in range(len(id2label))]
            source_is_real = {name: True for name in self._DEFAULT_REAL_SOURCES}

        self.source_names = source_names
        self.source_is_real = self._as_flag_map(source_names, source_is_real)

    @staticmethod
    def _as_flag_map(source_names, source_is_real):
        """Return ``source_is_real`` as a ``{label: bool}`` mapping.

        A dict is returned unchanged. A list/tuple is assumed to be parallel
        to ``source_names`` (NOTE(review): confirm against the metadata
        writer) and is zipped into a mapping so ``.get`` lookups work.
        Any other value is returned unchanged.
        """
        if isinstance(source_is_real, (list, tuple)):
            return dict(zip(source_names, source_is_real))
        return source_is_real

    def __call__(self, data):
        """Process one inference request.

        Args:
            data: Either the image payload itself or a dict carrying it under
                ``"inputs"``, ``"image"`` or ``"data"`` (first truthy wins).

        Returns:
            dict with ``ai_probability``, ``human_probability``,
            ``predicted_source`` and ``top3_sources``.

        Raises:
            ValueError: If the payload is of an unsupported type.
        """
        # Handle different input formats
        if isinstance(data, dict):
            image_data = data.get("inputs") or data.get("image") or data.get("data")
        else:
            image_data = data

        # Convert to PIL Image
        image = self._load_image(image_data)

        # Inference
        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)[0]

        return self._summarize(probs.tolist())

    def _summarize(self, probs):
        """Turn per-source probabilities into the response payload.

        Args:
            probs: Sequence of floats, one per entry of ``self.source_names``.

        Returns:
            dict with the rounded AI/human probabilities, the top-1 source
            label and up to three non-real sources ranked by probability.
        """
        # Top-1 decision + its confidence
        top_idx = max(range(len(probs)), key=probs.__getitem__)
        top_source = self.source_names[top_idx]
        top_confidence = probs[top_idx]

        # The top-1 confidence is assigned to the side (human/AI) the top
        # source belongs to; the other side gets the complement.
        if self.source_is_real.get(top_source, False):
            human_prob = top_confidence
            ai_prob = 1.0 - human_prob
        else:
            ai_prob = top_confidence
            human_prob = 1.0 - ai_prob

        # Rank AI sources on the raw probabilities and round only afterwards;
        # ranking rounded values would tie sources that differ below 1e-3 and
        # could drop a more likely source from the top 3.
        ai_scores = [
            (name, probs[i])
            for i, name in enumerate(self.source_names)
            if not self.source_is_real.get(name, False)
        ]
        ai_scores.sort(key=lambda item: item[1], reverse=True)
        top3_sources = [
            {"label": name, "score": round(score, 3)}
            for name, score in ai_scores[:3]
        ]

        return {
            "ai_probability": round(ai_prob, 3),
            "human_probability": round(human_prob, 3),
            "predicted_source": top_source,
            "top3_sources": top3_sources,
        }

    def _load_image(self, image_data):
        """Load an RGB PIL image from one of the supported payload formats.

        Supported: a PIL image, raw bytes/bytearray, a base64 string (with or
        without a ``data:...;base64,`` prefix), or a non-empty list whose
        first element is one of the above.

        Raises:
            ValueError: For an empty list or any unsupported payload type.
        """
        # Already a PIL Image
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")

        # Raw encoded image bytes
        if isinstance(image_data, (bytes, bytearray)):
            return Image.open(BytesIO(image_data)).convert("RGB")

        # Base64 encoded string
        if isinstance(image_data, str):
            # Remove data URL prefix if present
            if "base64," in image_data:
                image_data = image_data.split("base64,", 1)[1]
            image_bytes = base64.b64decode(image_data)
            return Image.open(BytesIO(image_bytes)).convert("RGB")

        # List (could be from JSON): use the first element
        if isinstance(image_data, list):
            if not image_data:
                raise ValueError("Unsupported image format: empty list")
            return self._load_image(image_data[0])

        raise ValueError(f"Unsupported image format: {type(image_data)}")