Image Classification
Transformers
Safetensors
PyTorch
English
Chinese
beit
ai-detection
ai-image-detection
deepfake-detection
fake-image-detection
ai-art-detection
stable-diffusion-detection
midjourney-detection
dall-e-detection
image-forensics
digital-art-verification
vit
computer-vision
Eval Results (legacy)
Upload handler.py with huggingface_hub
Browse files- handler.py +36 -17
handler.py
CHANGED
|
@@ -5,9 +5,9 @@ Custom handler for Hugging Face Inference API
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import json
|
|
|
|
| 8 |
from PIL import Image
|
| 9 |
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
| 10 |
-
from huggingface_hub import hf_hub_download
|
| 11 |
from io import BytesIO
|
| 12 |
|
| 13 |
|
|
@@ -20,18 +20,15 @@ class EndpointHandler:
|
|
| 20 |
|
| 21 |
# Load source metadata
|
| 22 |
try:
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
meta_path = f"{path}/source_meta.json"
|
| 26 |
-
|
| 27 |
-
try:
|
| 28 |
with open(meta_path) as f:
|
| 29 |
meta = json.load(f)
|
| 30 |
self.source_names = meta["source_names"]
|
| 31 |
self.source_is_real = meta["source_is_real"]
|
| 32 |
except Exception:
|
| 33 |
-
# Fallback
|
| 34 |
-
self.source_names =
|
| 35 |
self.source_is_real = {
|
| 36 |
"afhq": True, "celebahq": True, "coco": True, "ffhq": True,
|
| 37 |
"imagenet": True, "landscape": True, "lsun": True, "metfaces": True
|
|
@@ -39,19 +36,14 @@ class EndpointHandler:
|
|
| 39 |
|
| 40 |
def __call__(self, data):
|
| 41 |
"""Process inference request"""
|
| 42 |
-
# Handle input
|
| 43 |
if isinstance(data, dict):
|
| 44 |
-
image_data = data.get("inputs") or data.get("image")
|
| 45 |
else:
|
| 46 |
image_data = data
|
| 47 |
|
| 48 |
-
#
|
| 49 |
-
|
| 50 |
-
image = Image.open(BytesIO(image_data)).convert("RGB")
|
| 51 |
-
elif isinstance(image_data, Image.Image):
|
| 52 |
-
image = image_data.convert("RGB")
|
| 53 |
-
else:
|
| 54 |
-
image = Image.open(BytesIO(image_data)).convert("RGB")
|
| 55 |
|
| 56 |
# Inference
|
| 57 |
inputs = self.processor(image, return_tensors="pt")
|
|
@@ -81,3 +73,30 @@ class EndpointHandler:
|
|
| 81 |
"human_probability": round(human_prob, 3),
|
| 82 |
"top3_sources": top3_sources
|
| 83 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import json
|
| 8 |
+
import base64
|
| 9 |
from PIL import Image
|
| 10 |
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
|
|
|
| 11 |
from io import BytesIO
|
| 12 |
|
| 13 |
|
|
|
|
| 20 |
|
| 21 |
# Load source metadata
|
| 22 |
try:
|
| 23 |
+
import os
|
| 24 |
+
meta_path = os.path.join(path, "source_meta.json")
|
|
|
|
|
|
|
|
|
|
| 25 |
with open(meta_path) as f:
|
| 26 |
meta = json.load(f)
|
| 27 |
self.source_names = meta["source_names"]
|
| 28 |
self.source_is_real = meta["source_is_real"]
|
| 29 |
except Exception:
|
| 30 |
+
# Fallback - use model config
|
| 31 |
+
self.source_names = [self.model.config.id2label[i] for i in range(len(self.model.config.id2label))]
|
| 32 |
self.source_is_real = {
|
| 33 |
"afhq": True, "celebahq": True, "coco": True, "ffhq": True,
|
| 34 |
"imagenet": True, "landscape": True, "lsun": True, "metfaces": True
|
|
|
|
| 36 |
|
| 37 |
def __call__(self, data):
|
| 38 |
"""Process inference request"""
|
| 39 |
+
# Handle different input formats
|
| 40 |
if isinstance(data, dict):
|
| 41 |
+
image_data = data.get("inputs") or data.get("image") or data.get("data")
|
| 42 |
else:
|
| 43 |
image_data = data
|
| 44 |
|
| 45 |
+
# Convert to PIL Image
|
| 46 |
+
image = self._load_image(image_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
# Inference
|
| 49 |
inputs = self.processor(image, return_tensors="pt")
|
|
|
|
| 73 |
"human_probability": round(human_prob, 3),
|
| 74 |
"top3_sources": top3_sources
|
| 75 |
}
|
| 76 |
+
|
| 77 |
+
def _load_image(self, image_data):
|
| 78 |
+
"""Load image from various formats"""
|
| 79 |
+
# Already a PIL Image
|
| 80 |
+
if isinstance(image_data, Image.Image):
|
| 81 |
+
return image_data.convert("RGB")
|
| 82 |
+
|
| 83 |
+
# Bytes
|
| 84 |
+
if isinstance(image_data, bytes):
|
| 85 |
+
return Image.open(BytesIO(image_data)).convert("RGB")
|
| 86 |
+
|
| 87 |
+
# Base64 encoded string
|
| 88 |
+
if isinstance(image_data, str):
|
| 89 |
+
# Remove data URL prefix if present
|
| 90 |
+
if "base64," in image_data:
|
| 91 |
+
image_data = image_data.split("base64,")[1]
|
| 92 |
+
|
| 93 |
+
# Decode base64
|
| 94 |
+
image_bytes = base64.b64decode(image_data)
|
| 95 |
+
return Image.open(BytesIO(image_bytes)).convert("RGB")
|
| 96 |
+
|
| 97 |
+
# List (could be from JSON)
|
| 98 |
+
if isinstance(image_data, list):
|
| 99 |
+
# Assume it's a nested structure, try first element
|
| 100 |
+
return self._load_image(image_data[0])
|
| 101 |
+
|
| 102 |
+
raise ValueError(f"Unsupported image format: {type(image_data)}")
|