Image Classification
Transformers
Safetensors
PyTorch
English
Chinese
beit
ai-detection
ai-image-detection
deepfake-detection
fake-image-detection
ai-art-detection
stable-diffusion-detection
midjourney-detection
dall-e-detection
flux-detection
image-forensics
digital-art-verification
vit
computer-vision
dual-head
Eval Results (legacy)
File size: 3,757 Bytes
baf60bd
"""
Custom handler for Hugging Face Inference API
Returns aggregated AI vs. Human probabilities.
"""
import torch
import json
import base64
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
from io import BytesIO
class EndpointHandler:
    """Custom handler for the Hugging Face Inference API.

    Runs an image-classification model whose labels are image *sources*
    (real datasets vs. AI generators) and aggregates the per-source
    probabilities into a single AI-vs-Human verdict.
    """

    def __init__(self, path=""):
        """Load the model, processor, and source metadata from *path*.

        Args:
            path: Directory containing the model weights and, optionally,
                a ``source_meta.json`` with ``source_names`` and
                ``source_is_real`` entries.
        """
        import os

        self.model = AutoModelForImageClassification.from_pretrained(path)
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model.eval()

        # Load source metadata. Catch only the expected failure modes
        # (missing/unreadable file, bad JSON, missing keys) so genuine
        # programming errors are not silently swallowed.
        meta_path = os.path.join(path, "source_meta.json")
        try:
            with open(meta_path) as f:
                meta = json.load(f)
            self.source_names = meta["source_names"]
            self.source_is_real = meta["source_is_real"]
        except (OSError, ValueError, KeyError):
            # Fallback: derive label names from the model config and assume
            # a fixed set of "real" dataset names. NOTE(review): this set
            # must match the model's actual label vocabulary — verify when
            # shipping without source_meta.json.
            id2label = self.model.config.id2label
            self.source_names = [id2label[i] for i in range(len(id2label))]
            self.source_is_real = {
                "afhq": True, "celebahq": True, "coco": True, "ffhq": True,
                "imagenet": True, "landscape": True, "lsun": True, "metfaces": True
            }

    def __call__(self, data):
        """Process one inference request.

        Args:
            data: The image payload itself, or a dict carrying it under
                one of the keys ``inputs``, ``image``, or ``data``.

        Returns:
            dict with ``ai_probability``, ``human_probability``,
            ``predicted_source``, and ``top3_sources`` (top-3 AI-source
            scores, each rounded to 3 decimals).
        """
        # Extract the image payload. Check each key explicitly against
        # None — chaining with `or` would skip falsy-but-valid payloads
        # such as empty bytes.
        if isinstance(data, dict):
            image_data = next(
                (data[k] for k in ("inputs", "image", "data")
                 if data.get(k) is not None),
                None,
            )
        else:
            image_data = data

        image = self._load_image(image_data)

        # Inference: single forward pass, softmax over source labels.
        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)[0]

        # The top-1 source decides AI vs. Human; its confidence becomes
        # the probability of the winning side.
        top_idx = probs.argmax().item()
        top_source = self.source_names[top_idx]
        top_confidence = probs[top_idx].item()

        if self.source_is_real.get(top_source, False):
            human_prob = top_confidence
            ai_prob = 1.0 - human_prob
        else:
            ai_prob = top_confidence
            human_prob = 1.0 - ai_prob

        # Collect per-source scores for the AI (non-real) classes and
        # keep the three highest.
        ai_sources = [
            {"label": name, "score": round(probs[i].item(), 3)}
            for i, name in enumerate(self.source_names)
            if not self.source_is_real.get(name, False)
        ]
        top3_sources = sorted(
            ai_sources, key=lambda s: s["score"], reverse=True
        )[:3]

        return {
            "ai_probability": round(ai_prob, 3),
            "human_probability": round(human_prob, 3),
            "predicted_source": top_source,
            "top3_sources": top3_sources
        }

    def _load_image(self, image_data):
        """Convert *image_data* to an RGB PIL image.

        Accepts a PIL Image, raw bytes, a (possibly data-URL-prefixed)
        base64 string, or a list wrapping one of those.

        Raises:
            ValueError: If the payload type is unsupported or an empty
                list is given.
        """
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")
        if isinstance(image_data, bytes):
            return Image.open(BytesIO(image_data)).convert("RGB")
        if isinstance(image_data, str):
            # Strip a data-URL prefix ("data:image/png;base64,...").
            if "base64," in image_data:
                image_data = image_data.split("base64,")[1]
            image_bytes = base64.b64decode(image_data)
            return Image.open(BytesIO(image_bytes)).convert("RGB")
        if isinstance(image_data, list):
            # JSON payloads may nest the image; try the first element.
            # Raise our own error for an empty list rather than a raw
            # IndexError.
            if not image_data:
                raise ValueError("Empty list given as image input")
            return self._load_image(image_data[0])
        raise ValueError(f"Unsupported image format: {type(image_data)}")