boluobobo commited on
Commit
623b6fe
·
verified ·
1 Parent(s): 31968cd

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +36 -17
handler.py CHANGED
@@ -5,9 +5,9 @@ Custom handler for Hugging Face Inference API
5
 
6
  import torch
7
  import json
 
8
  from PIL import Image
9
  from transformers import AutoModelForImageClassification, AutoImageProcessor
10
- from huggingface_hub import hf_hub_download
11
  from io import BytesIO
12
 
13
 
@@ -20,18 +20,15 @@ class EndpointHandler:
20
 
21
  # Load source metadata
22
  try:
23
- meta_path = hf_hub_download(repo_id=path, filename="source_meta.json", local_files_only=True)
24
- except:
25
- meta_path = f"{path}/source_meta.json"
26
-
27
- try:
28
  with open(meta_path) as f:
29
  meta = json.load(f)
30
  self.source_names = meta["source_names"]
31
  self.source_is_real = meta["source_is_real"]
32
  except Exception:
33
- # Fallback
34
- self.source_names = list(self.model.config.id2label.values())
35
  self.source_is_real = {
36
  "afhq": True, "celebahq": True, "coco": True, "ffhq": True,
37
  "imagenet": True, "landscape": True, "lsun": True, "metfaces": True
@@ -39,19 +36,14 @@ class EndpointHandler:
39
 
40
  def __call__(self, data):
41
  """Process inference request"""
42
- # Handle input
43
  if isinstance(data, dict):
44
- image_data = data.get("inputs") or data.get("image")
45
  else:
46
  image_data = data
47
 
48
- # Load image
49
- if isinstance(image_data, bytes):
50
- image = Image.open(BytesIO(image_data)).convert("RGB")
51
- elif isinstance(image_data, Image.Image):
52
- image = image_data.convert("RGB")
53
- else:
54
- image = Image.open(BytesIO(image_data)).convert("RGB")
55
 
56
  # Inference
57
  inputs = self.processor(image, return_tensors="pt")
@@ -81,3 +73,30 @@ class EndpointHandler:
81
  "human_probability": round(human_prob, 3),
82
  "top3_sources": top3_sources
83
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import torch
7
  import json
8
+ import base64
9
  from PIL import Image
10
  from transformers import AutoModelForImageClassification, AutoImageProcessor
 
11
  from io import BytesIO
12
 
13
 
 
20
 
21
  # Load source metadata
22
  try:
23
+ import os
24
+ meta_path = os.path.join(path, "source_meta.json")
 
 
 
25
  with open(meta_path) as f:
26
  meta = json.load(f)
27
  self.source_names = meta["source_names"]
28
  self.source_is_real = meta["source_is_real"]
29
  except Exception:
30
+ # Fallback - use model config
31
+ self.source_names = [self.model.config.id2label[i] for i in range(len(self.model.config.id2label))]
32
  self.source_is_real = {
33
  "afhq": True, "celebahq": True, "coco": True, "ffhq": True,
34
  "imagenet": True, "landscape": True, "lsun": True, "metfaces": True
 
36
 
37
  def __call__(self, data):
38
  """Process inference request"""
39
+ # Handle different input formats
40
  if isinstance(data, dict):
41
+ image_data = data.get("inputs") or data.get("image") or data.get("data")
42
  else:
43
  image_data = data
44
 
45
+ # Convert to PIL Image
46
+ image = self._load_image(image_data)
 
 
 
 
 
47
 
48
  # Inference
49
  inputs = self.processor(image, return_tensors="pt")
 
73
  "human_probability": round(human_prob, 3),
74
  "top3_sources": top3_sources
75
  }
76
+
77
+ def _load_image(self, image_data):
78
+ """Load image from various formats"""
79
+ # Already a PIL Image
80
+ if isinstance(image_data, Image.Image):
81
+ return image_data.convert("RGB")
82
+
83
+ # Bytes
84
+ if isinstance(image_data, bytes):
85
+ return Image.open(BytesIO(image_data)).convert("RGB")
86
+
87
+ # Base64 encoded string
88
+ if isinstance(image_data, str):
89
+ # Remove data URL prefix if present
90
+ if "base64," in image_data:
91
+ image_data = image_data.split("base64,")[1]
92
+
93
+ # Decode base64
94
+ image_bytes = base64.b64decode(image_data)
95
+ return Image.open(BytesIO(image_bytes)).convert("RGB")
96
+
97
+ # List (could be from JSON)
98
+ if isinstance(image_data, list):
99
+ # Assume it's a nested structure, try first element
100
+ return self._load_image(image_data[0])
101
+
102
+ raise ValueError(f"Unsupported image format: {type(image_data)}")