peterproofpath committed on
Commit
7a3ba2e
·
verified ·
1 Parent(s): 5ccac9e

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +108 -3
  2. handler.py +205 -0
  3. requirements.txt +3 -0
README.md CHANGED
@@ -1,3 +1,108 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - object-detection
5
+ - owlv2
6
+ - zero-shot
7
+ - visual-prompting
8
+ license: apache-2.0
9
+ ---
10
+
11
+ # OWLv2 Inference Endpoint
12
+
13
+ Custom handler for OWLv2 (Open-World Localization v2) supporting both **image-conditioned** and **text-conditioned** object detection.
14
+
15
+ ## Features
16
+
17
+ - **Image-conditioned detection**: Find objects similar to a reference image
18
+ - **Text-conditioned detection**: Find objects matching text descriptions
19
+ - **Multiple query images**: Search for several different objects at once
20
+
21
+ ## Usage
22
+
23
+ ### Image-Conditioned Detection
24
+
25
+ Find all instances of an icon/object in a target image:
26
+
27
+ ```python
28
+ import requests
29
+ import base64
30
+
31
+ API_URL = "https://your-endpoint.endpoints.huggingface.cloud"
32
+ headers = {"Authorization": "Bearer YOUR_TOKEN"}
33
+
34
+ # Load images as base64
35
+ with open("screenshot.png", "rb") as f:
36
+ target_b64 = base64.b64encode(f.read()).decode()
37
+ with open("icon.png", "rb") as f:
38
+ query_b64 = base64.b64encode(f.read()).decode()
39
+
40
+ response = requests.post(API_URL, headers=headers, json={
41
+ "inputs": {
42
+ "target_image": target_b64,
43
+ "query_image": query_b64,
44
+ "threshold": 0.5
45
+ }
46
+ })
47
+
48
+ print(response.json())
49
+ # {"detections": [{"box": [100, 200, 150, 250], "confidence": 0.92}]}
50
+ ```
51
+
52
+ ### Text-Conditioned Detection
53
+
54
+ Find objects by description:
55
+
56
+ ```python
57
+ response = requests.post(API_URL, headers=headers, json={
58
+ "inputs": {
59
+ "target_image": target_b64,
60
+ "queries": ["a play button", "a settings icon"],
61
+ "threshold": 0.1
62
+ }
63
+ })
64
+ ```
65
+
66
+ ### Multiple Query Images
67
+
68
+ Find several different objects:
69
+
70
+ ```python
71
+ response = requests.post(API_URL, headers=headers, json={
72
+ "inputs": {
73
+ "target_image": target_b64,
74
+ "query_images": [icon1_b64, icon2_b64, icon3_b64],
75
+ "threshold": 0.5
76
+ }
77
+ })
78
+ # Results include "label": "query_0", "query_1", etc.
79
+ ```
80
+
81
+ ## Parameters
82
+
83
+ | Parameter | Type | Default | Description |
84
+ |-----------|------|---------|-------------|
85
+ | `target_image` | string | required | Base64-encoded target image |
86
+ | `query_image` | string | - | Base64-encoded reference image |
87
+ | `query_images` | array | - | Multiple base64-encoded reference images |
88
+ | `queries` | array | - | Text descriptions to search for |
89
+ | `threshold` | float | 0.5 | Confidence threshold (0-1) |
90
+ | `nms_threshold` | float | 0.3 | Non-max suppression threshold |
91
+
92
+ ## Response Format
93
+
94
+ ```json
95
+ {
96
+ "detections": [
97
+ {
98
+ "box": [x1, y1, x2, y2],
99
+ "confidence": 0.95,
100
+ "label": "query_0"
101
+ }
102
+ ]
103
+ }
104
+ ```
105
+
106
+ ## Model
107
+
108
+ Uses `google/owlv2-large-patch14-ensemble` for best accuracy.
handler.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OWLv2 Custom Handler for HuggingFace Inference Endpoints
3
+
4
+ Supports:
5
+ - Image-conditioned detection (find objects similar to a reference image)
6
+ - Text-conditioned detection (find objects matching text descriptions)
7
+ """
8
+
9
+ from typing import Dict, Any, List, Union
10
+ import torch
11
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
12
+ from PIL import Image
13
+ import base64
14
+ import io
15
+
16
+
17
+ class EndpointHandler:
18
+ def __init__(self, path=""):
19
+ """Load model on endpoint startup."""
20
+ model_id = "google/owlv2-large-patch14-ensemble"
21
+
22
+ self.processor = Owlv2Processor.from_pretrained(model_id)
23
+ self.model = Owlv2ForObjectDetection.from_pretrained(model_id)
24
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ self.model = self.model.to(self.device)
26
+ self.model.eval()
27
+
28
+ print(f"OWLv2 loaded on {self.device}")
29
+
30
+ def _decode_image(self, image_data: str) -> Image.Image:
31
+ """Decode base64 image string to PIL Image."""
32
+ # Handle data URL format (e.g., "data:image/jpeg;base64,...")
33
+ if "," in image_data:
34
+ image_data = image_data.split(",")[1]
35
+
36
+ image_bytes = base64.b64decode(image_data)
37
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
38
+ return image
39
+
40
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
41
+ """
42
+ Process detection request.
43
+
44
+ === Image-Conditioned Detection ===
45
+ Find objects similar to a reference image.
46
+
47
+ Request:
48
+ {
49
+ "inputs": {
50
+ "target_image": "base64...",
51
+ "query_image": "base64...",
52
+ "threshold": 0.5,
53
+ "nms_threshold": 0.3
54
+ }
55
+ }
56
+
57
+ === Text-Conditioned Detection ===
58
+ Find objects matching text descriptions.
59
+
60
+ Request:
61
+ {
62
+ "inputs": {
63
+ "target_image": "base64...",
64
+ "queries": ["a button", "an icon"],
65
+ "threshold": 0.1
66
+ }
67
+ }
68
+
69
+ === Multiple Query Images ===
70
+ Find multiple different objects by image.
71
+
72
+ Request:
73
+ {
74
+ "inputs": {
75
+ "target_image": "base64...",
76
+ "query_images": ["base64...", "base64..."],
77
+ "threshold": 0.5,
78
+ "nms_threshold": 0.3
79
+ }
80
+ }
81
+
82
+ Response:
83
+ {
84
+ "detections": [
85
+ {"box": [x1, y1, x2, y2], "confidence": 0.95, "label": "query_0"}
86
+ ]
87
+ }
88
+ """
89
+ try:
90
+ # Handle both {"inputs": {...}} and direct {...} format
91
+ inputs = data.get("inputs", data)
92
+
93
+ # Validate required field
94
+ if "target_image" not in inputs:
95
+ return {"error": "Missing required field: target_image"}
96
+
97
+ target_image = self._decode_image(inputs["target_image"])
98
+ threshold = float(inputs.get("threshold", 0.5))
99
+ nms_threshold = float(inputs.get("nms_threshold", 0.3))
100
+
101
+ # Route to appropriate detection method
102
+ if "query_image" in inputs:
103
+ # Single query image
104
+ query_image = self._decode_image(inputs["query_image"])
105
+ return self._detect_with_image(
106
+ target_image, [query_image], threshold, nms_threshold
107
+ )
108
+
109
+ elif "query_images" in inputs:
110
+ # Multiple query images
111
+ query_images = [
112
+ self._decode_image(img) for img in inputs["query_images"]
113
+ ]
114
+ return self._detect_with_image(
115
+ target_image, query_images, threshold, nms_threshold
116
+ )
117
+
118
+ elif "queries" in inputs:
119
+ # Text queries
120
+ return self._detect_with_text(
121
+ target_image, inputs["queries"], threshold
122
+ )
123
+
124
+ else:
125
+ return {
126
+ "error": "Provide 'query_image', 'query_images', or 'queries'"
127
+ }
128
+
129
+ except Exception as e:
130
+ return {"error": str(e)}
131
+
132
+ def _detect_with_image(
133
+ self,
134
+ target: Image.Image,
135
+ query_images: List[Image.Image],
136
+ threshold: float,
137
+ nms_threshold: float
138
+ ) -> Dict[str, Any]:
139
+ """Image-conditioned detection."""
140
+
141
+ inputs = self.processor(
142
+ images=target,
143
+ query_images=query_images,
144
+ return_tensors="pt"
145
+ )
146
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
147
+
148
+ with torch.no_grad():
149
+ outputs = self.model.image_guided_detection(**inputs)
150
+
151
+ target_sizes = torch.tensor([target.size[::-1]]) # (height, width)
152
+ results = self.processor.post_process_image_guided_detection(
153
+ outputs=outputs,
154
+ threshold=threshold,
155
+ nms_threshold=nms_threshold,
156
+ target_sizes=target_sizes
157
+ )[0]
158
+
159
+ detections = []
160
+ for i, (box, score) in enumerate(zip(results["boxes"], results["scores"])):
161
+ det = {
162
+ "box": [round(c, 2) for c in box.tolist()],
163
+ "confidence": round(score.item(), 4)
164
+ }
165
+ # Add label if multiple query images
166
+ if len(query_images) > 1 and "labels" in results:
167
+ det["label"] = f"query_{results['labels'][i].item()}"
168
+ detections.append(det)
169
+
170
+ return {"detections": detections}
171
+
172
+ def _detect_with_text(
173
+ self,
174
+ target: Image.Image,
175
+ queries: List[str],
176
+ threshold: float
177
+ ) -> Dict[str, Any]:
178
+ """Text-conditioned detection."""
179
+
180
+ inputs = self.processor(
181
+ text=[queries],
182
+ images=target,
183
+ return_tensors="pt"
184
+ )
185
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
186
+
187
+ with torch.no_grad():
188
+ outputs = self.model(**inputs)
189
+
190
+ target_sizes = torch.tensor([target.size[::-1]])
191
+ results = self.processor.post_process_object_detection(
192
+ outputs, threshold=threshold, target_sizes=target_sizes
193
+ )[0]
194
+
195
+ detections = []
196
+ for box, score, label_idx in zip(
197
+ results["boxes"], results["scores"], results["labels"]
198
+ ):
199
+ detections.append({
200
+ "box": [round(c, 2) for c in box.tolist()],
201
+ "confidence": round(score.item(), 4),
202
+ "label": queries[label_idx.item()]
203
+ })
204
+
205
+ return {"detections": detections}
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.35.0
3
+ pillow>=10.0.0