File size: 3,757 Bytes
baf60bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Custom handler for Hugging Face Inference API
Returns aggregated AI vs. human probabilities.
"""

import torch
import json
import base64
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
from io import BytesIO


class EndpointHandler:
    """Inference handler that scores an image as AI-generated vs. human-made.

    The wrapped model is an image classifier whose classes are image
    *sources*.  ``source_is_real`` flags which sources count as real; a source
    missing from that mapping is treated as AI-generated.

    The reported probabilities come from the top-1 class only: its confidence
    is assigned to the side (real or AI) the class belongs to and the other
    side receives the complement.  This is not a sum over all real/AI classes,
    so ``ai_probability`` can exceed 0.5 even when ``predicted_source`` is a
    real source.  NOTE(review): confirm this is the intended scoring.
    """

    # Sources assumed to be real when source_meta.json cannot be read.
    _FALLBACK_REAL_SOURCES = (
        "afhq", "celebahq", "coco", "ffhq",
        "imagenet", "landscape", "lsun", "metfaces",
    )

    def __init__(self, path: str = "") -> None:
        """Load the model, its image processor and the source metadata.

        Args:
            path: Directory (or identifier) handed to ``from_pretrained``;
                ``source_meta.json`` is looked up inside it.
        """
        # Local imports: the file's top-level import block does not provide these.
        import logging
        import os

        self.model = AutoModelForImageClassification.from_pretrained(path)
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model.eval()

        meta_path = os.path.join(path, "source_meta.json")
        try:
            with open(meta_path, encoding="utf-8") as meta_file:
                meta = json.load(meta_file)
            source_names = meta["source_names"]
            source_is_real = meta["source_is_real"]
        except (OSError, ValueError, KeyError, TypeError) as exc:
            # Missing/unreadable file, malformed JSON or unexpected schema:
            # fall back to the labels stored in the model config, but say so
            # instead of degrading silently.
            logging.getLogger(__name__).warning(
                "Could not load %s (%s); falling back to model config labels",
                meta_path,
                exc,
            )
            id2label = self.model.config.id2label
            # id2label is looked up by class index, hence the index range.
            source_names = [id2label[i] for i in range(len(id2label))]
            source_is_real = dict.fromkeys(self._FALLBACK_REAL_SOURCES, True)

        # Class-index -> source name, and source name -> "is a real source".
        self.source_names = source_names
        self.source_is_real = source_is_real

    def __call__(self, data):
        """Run inference for one request.

        Args:
            data: Either a dict carrying the image under ``"inputs"``,
                ``"image"`` or ``"data"`` (first truthy value wins), or the
                image payload itself in any format ``_load_image`` accepts.

        Returns:
            dict with ``ai_probability``, ``human_probability`` (both rounded
            to 3 decimals), ``predicted_source`` and ``top3_sources``.

        Raises:
            ValueError: If the request carries no image or the payload type
                is unsupported.
        """
        if isinstance(data, dict):
            image_data = data.get("inputs") or data.get("image") or data.get("data")
            if image_data is None:
                raise ValueError(
                    "Request contains no image: expected one of the keys "
                    "'inputs', 'image' or 'data'"
                )
        else:
            image_data = data

        image = self._load_image(image_data)
        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs)
            # Single image per request: keep the first (only) row as floats.
            probs = torch.softmax(outputs.logits, dim=-1)[0].tolist()

        return self._summarize(probs)

    def _summarize(self, probs):
        """Turn per-class probabilities into the response payload.

        Args:
            probs: Sequence of floats, one per class index, aligned with
                ``self.source_names``.

        Returns:
            The response dict described in ``__call__``.
        """
        # Top-1 decision + its confidence.
        top_idx = max(range(len(probs)), key=probs.__getitem__)
        top_source = self.source_names[top_idx]
        top_confidence = probs[top_idx]

        if self.source_is_real.get(top_source, False):
            human_prob = top_confidence
            ai_prob = 1.0 - human_prob
        else:
            ai_prob = top_confidence
            human_prob = 1.0 - ai_prob

        # zip() keeps this safe if the metadata lists more names than the
        # model has classes (indexing probs by position would raise).
        ai_sources = [
            {"label": name, "score": round(score, 3)}
            for name, score in zip(self.source_names, probs)
            if not self.source_is_real.get(name, False)
        ]
        ai_sources.sort(key=lambda entry: entry["score"], reverse=True)

        return {
            "ai_probability": round(ai_prob, 3),
            "human_probability": round(human_prob, 3),
            "predicted_source": top_source,
            "top3_sources": ai_sources[:3],
        }

    def _load_image(self, image_data) -> "Image.Image":
        """Convert a request payload into an RGB PIL image.

        Accepted payloads: a list (its first element is used), raw bytes-like
        data, a base64 string (optionally with a ``data:...;base64,`` prefix),
        or a PIL image.

        Raises:
            ValueError: For an empty list or any other unsupported type.
        """
        # List (e.g. decoded from JSON): use the first element.
        if isinstance(image_data, list):
            if not image_data:
                raise ValueError("Unsupported image format: empty list")
            return self._load_image(image_data[0])

        # Raw encoded image bytes.
        if isinstance(image_data, (bytes, bytearray, memoryview)):
            return Image.open(BytesIO(image_data)).convert("RGB")

        # Base64 string, possibly wrapped in a data URL.
        if isinstance(image_data, str):
            _, marker, payload = image_data.partition("base64,")
            if marker:
                image_data = payload
            image_bytes = base64.b64decode(image_data)
            return Image.open(BytesIO(image_bytes)).convert("RGB")

        # Already a PIL image.
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")

        raise ValueError(f"Unsupported image format: {type(image_data)}")