malusama committed on
Commit
e53cbe6
·
verified ·
1 Parent(s): 88726e9

Add Inference Endpoints handler

Browse files
Files changed (2) hide show
  1. README.md +27 -0
  2. handler.py +120 -0
README.md CHANGED
@@ -83,6 +83,33 @@ git commit -m "Upload M2-Encoder HF export"
83
  git push origin main
84
  ```
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  ## Notes
87
 
88
  - This is a Hugging Face remote-code adapter, not a native `transformers` implementation.
 
83
  git push origin main
84
  ```
85
 
86
+ ## Inference Endpoints
87
+
88
+ This repo also includes a `handler.py` for Hugging Face Inference Endpoints custom deployments.
89
+
90
+ Example request body:
91
+
92
+ ```json
93
+ {
94
+ "inputs": {
95
+ "text": ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"],
96
+ "image": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
97
+ },
98
+ "parameters": {
99
+ "return_probs": true,
100
+ "return_logits": false
101
+ }
102
+ }
103
+ ```
104
+
105
+ Example response fields:
106
+
107
+ - `text_embedding`
108
+ - `image_embedding`
109
+ - `scores`
110
+ - `probs`
111
+ - `logits_per_image` when `return_logits=true`
112
+
113
  ## Notes
114
 
115
  - This is a Hugging Face remote-code adapter, not a native `transformers` implementation.
handler.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import os
4
+ from typing import Any, Dict, List
5
+ from urllib.parse import urlparse
6
+ from urllib.request import urlopen
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from transformers import AutoModel, AutoProcessor
11
+
12
+
13
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a CLIP-style text/image model.

    Loads the model and processor with ``trust_remote_code`` and serves
    requests whose ``inputs`` carry ``text``/``texts`` and/or
    ``image``/``images``. Returns embeddings for each supplied modality and,
    when both are present, image-to-text similarity scores.
    """

    def __init__(self, path: str = ""):
        # `path` is the model repository directory supplied by the Endpoints runtime.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request body. ``inputs`` may be a bare string (treated as a
                single text) or a dict with ``text``/``texts`` and/or
                ``image``/``images``. ``parameters`` supports
                ``return_probs`` (default ``True``) and ``return_logits``
                (default ``False``).

        Returns:
            Dict with ``text_embedding`` / ``image_embedding`` for the given
            modalities and, when both are present, ``scores`` plus the
            optional ``probs`` and ``logits_per_image`` fields.

        Raises:
            ValueError: If neither texts nor images are supplied.
        """
        # Use .get (not .pop) so the caller's request dict is never mutated.
        payload = data.get("inputs", data)
        parameters = data.get("parameters") or {}

        texts = self._coerce_texts(payload)
        images = self._coerce_images(payload)
        if not texts and not images:
            raise ValueError(
                "Expected `inputs` to include `text`/`texts` and/or `image`/`images`."
            )

        result: Dict[str, Any] = {}
        with torch.no_grad():
            text_embeds = None
            image_embeds = None

            if texts:
                text_inputs = self._move_to_device(
                    self.processor(text=texts, return_tensors="pt")
                )
                text_embeds = self.model(**text_inputs).text_embeds
                result["text_embedding"] = text_embeds.cpu().tolist()

            if images:
                image_inputs = self._move_to_device(
                    self.processor(images=images, return_tensors="pt")
                )
                image_embeds = self.model(**image_inputs).image_embeds
                result["image_embedding"] = image_embeds.cpu().tolist()

            if text_embeds is not None and image_embeds is not None:
                # Similarity matrix: one row per image, one column per text.
                scores = image_embeds @ text_embeds.t()
                result["scores"] = scores.cpu().tolist()
                if parameters.get("return_probs", True):
                    result["probs"] = scores.softmax(dim=-1).cpu().tolist()
                if parameters.get("return_logits", False):
                    # Reuse `scores` rather than recomputing the matmul.
                    # NOTE(review): assumes the remote-code wrapper exposes the
                    # inner CLIP model at `self.model.model` — confirm.
                    logit_scale = self.model.model.logit_scale.exp()
                    result["logits_per_image"] = (logit_scale * scores).cpu().tolist()

        return result

    def _move_to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move every tensor-like value in `batch` to the handler's device."""
        return {
            key: value.to(self.device) if hasattr(value, "to") else value
            for key, value in batch.items()
        }

    def _coerce_texts(self, payload: Any) -> List[str]:
        """Normalize `payload` into a list of text strings (possibly empty)."""
        if isinstance(payload, str):
            return [payload]
        if not isinstance(payload, dict):
            return []

        texts = payload.get("text", payload.get("texts"))
        if texts is None:
            return []
        if isinstance(texts, str):
            return [texts]
        # Stringify each item so numeric payload entries don't break the processor.
        return [str(item) for item in texts]

    def _coerce_images(self, payload: Any) -> "List[Image.Image]":
        """Normalize `payload` into a list of loaded PIL images (possibly empty)."""
        if not isinstance(payload, dict):
            return []

        images = payload.get("image", payload.get("images"))
        if images is None:
            return []
        if not isinstance(images, (list, tuple)):
            images = [images]
        return [self._load_image(item) for item in images]

    def _load_image(self, value: Any) -> "Image.Image":
        """Load one image from a PIL image, dict wrapper, path, URL, or base64 string.

        Raises:
            TypeError: If `value` is not a supported type.
            ValueError: If a string input matches no supported format.
        """
        if isinstance(value, Image.Image):
            return value.convert("RGB")

        # Unwrap common {"data"|"image"|"url"|"path": ...} envelopes.
        if isinstance(value, dict):
            for key in ("data", "image", "url", "path"):
                if key in value:
                    value = value[key]
                    break

        if not isinstance(value, str):
            raise TypeError(f"Unsupported image input type: {type(value)!r}")

        if os.path.exists(value):
            return Image.open(value).convert("RGB")

        parsed = urlparse(value)
        if parsed.scheme in ("http", "https"):
            # Bound the fetch so a stalled remote host cannot hang the worker.
            # NOTE(review): fetching caller-supplied URLs is an SSRF surface;
            # restrict allowed hosts if this endpoint is exposed publicly.
            with urlopen(value, timeout=10) as response:
                return Image.open(io.BytesIO(response.read())).convert("RGB")

        if value.startswith("data:image/"):
            _, encoded = value.split(",", 1)
            return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")

        # Last resort: treat the string as raw base64-encoded image bytes.
        try:
            return Image.open(io.BytesIO(base64.b64decode(value))).convert("RGB")
        except Exception as exc:
            raise ValueError("Unsupported image string. Use URL, local path, or base64.") from exc