Forrest Wargo
committed on
Commit
·
416a2e8
1
Parent(s):
6600256
Accept OpenAI-style input; return annotated_image+boxes
Browse files — handler.py: +75 −17
handler.py
CHANGED
|
@@ -49,17 +49,48 @@ class EndpointHandler:
|
|
| 49 |
self.annotator = BoxAnnotator()
|
| 50 |
|
| 51 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
| 52 |
-
#
|
| 53 |
-
#
|
| 54 |
-
#
|
| 55 |
-
#
|
| 56 |
-
|
| 57 |
-
#
|
| 58 |
-
|
| 59 |
-
data
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
ocr_texts, ocr_bboxes = self.check_ocr_bboxes(
|
| 65 |
image,
|
|
@@ -68,17 +99,44 @@ class EndpointHandler:
|
|
| 68 |
)
|
| 69 |
annotated_image, filtered_bboxes_out = self.get_som_labeled_img(
|
| 70 |
image,
|
| 71 |
-
image_size=
|
| 72 |
ocr_texts=ocr_texts,
|
| 73 |
ocr_bboxes=ocr_bboxes,
|
| 74 |
-
bbox_threshold=
|
| 75 |
-
iou_threshold=
|
| 76 |
)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
}
|
| 81 |
|
|
|
|
|
|
|
| 82 |
def check_ocr_bboxes(
|
| 83 |
self,
|
| 84 |
image: ImageType,
|
|
|
|
| 49 |
self.annotator = BoxAnnotator()
|
| 50 |
|
| 51 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
| 52 |
+
# Flexible input contract:
|
| 53 |
+
# 1) OpenAI-style: { messages: [ { role:'user', content:[ {type:'image_url', image_url:{url:data-url}}, {type:'text', text:'...'} ] } ] }
|
| 54 |
+
# 2) Legacy HF inputs: { inputs: { image: <url|data-url> } }
|
| 55 |
+
# 3) PropParse-style: { inputs: { image_b64: <data-url>, ... } }
|
| 56 |
+
|
| 57 |
+
# Normalize payload
|
| 58 |
+
payload: Dict[str, Any]
|
| 59 |
+
if isinstance(data, dict) and "inputs" in data:
|
| 60 |
+
payload = data.get("inputs") or {}
|
| 61 |
+
else:
|
| 62 |
+
payload = data
|
| 63 |
+
|
| 64 |
+
# Extract image source
|
| 65 |
+
img_source: Optional[str] = None
|
| 66 |
+
if "image_b64" in payload and isinstance(payload["image_b64"], str):
|
| 67 |
+
img_source = payload["image_b64"]
|
| 68 |
+
elif "image" in payload and isinstance(payload["image"], str):
|
| 69 |
+
img_source = payload["image"]
|
| 70 |
+
elif isinstance(payload.get("messages"), list):
|
| 71 |
+
for msg in payload["messages"]:
|
| 72 |
+
if isinstance(msg, Dict) and msg.get("role") == "user":
|
| 73 |
+
for part in msg.get("content", []):
|
| 74 |
+
if part.get("type") == "image_url":
|
| 75 |
+
img_source = part.get("image_url", {}).get("url")
|
| 76 |
+
break
|
| 77 |
+
if img_source:
|
| 78 |
+
break
|
| 79 |
+
|
| 80 |
+
if not img_source:
|
| 81 |
+
return {"error": "No image provided (image/image_b64/messages)"}
|
| 82 |
+
|
| 83 |
+
# Load image from data URL or external URL
|
| 84 |
+
try:
|
| 85 |
+
if isinstance(img_source, str) and img_source.startswith("data:"):
|
| 86 |
+
header, b64data = img_source.split(",", 1)
|
| 87 |
+
decoded = base64.b64decode(b64data)
|
| 88 |
+
image = Image.open(io.BytesIO(decoded))
|
| 89 |
+
image.load()
|
| 90 |
+
else:
|
| 91 |
+
image = load_image(img_source)
|
| 92 |
+
except Exception as e:
|
| 93 |
+
return {"error": f"Failed to load image: {e}"}
|
| 94 |
|
| 95 |
ocr_texts, ocr_bboxes = self.check_ocr_bboxes(
|
| 96 |
image,
|
|
|
|
| 99 |
)
|
| 100 |
annotated_image, filtered_bboxes_out = self.get_som_labeled_img(
|
| 101 |
image,
|
| 102 |
+
image_size=payload.get("image_size", None),
|
| 103 |
ocr_texts=ocr_texts,
|
| 104 |
ocr_bboxes=ocr_bboxes,
|
| 105 |
+
bbox_threshold=payload.get("bbox_threshold", 0.05),
|
| 106 |
+
iou_threshold=payload.get("iou_threshold", None),
|
| 107 |
)
|
| 108 |
+
|
| 109 |
+
# Legacy fields
|
| 110 |
+
legacy = {"image": annotated_image, "bboxes": filtered_bboxes_out}
|
| 111 |
+
|
| 112 |
+
# PropParse-style fields
|
| 113 |
+
try:
|
| 114 |
+
w, h = image.size # type: ignore
|
| 115 |
+
except Exception:
|
| 116 |
+
w, h = None, None
|
| 117 |
+
annotated_data_url = (
|
| 118 |
+
f"data:image/png;base64,{annotated_image}"
|
| 119 |
+
if isinstance(annotated_image, str) and not annotated_image.startswith("data:")
|
| 120 |
+
else annotated_image
|
| 121 |
+
)
|
| 122 |
+
elements = [
|
| 123 |
+
{
|
| 124 |
+
"type": box.get("type", "icon"),
|
| 125 |
+
"bbox_xyxy_norm": box.get("bbox"),
|
| 126 |
+
"interactivity": box.get("interactivity", True),
|
| 127 |
+
"content": box.get("content"),
|
| 128 |
+
}
|
| 129 |
+
for box in filtered_bboxes_out
|
| 130 |
+
if isinstance(box.get("bbox"), list)
|
| 131 |
+
]
|
| 132 |
+
propparse_style = {
|
| 133 |
+
"annotated_image": annotated_data_url,
|
| 134 |
+
"boxes": {"elements": elements},
|
| 135 |
+
**({"width": w, "height": h} if w and h else {}),
|
| 136 |
}
|
| 137 |
|
| 138 |
+
return {**legacy, **propparse_style}
|
| 139 |
+
|
| 140 |
def check_ocr_bboxes(
|
| 141 |
self,
|
| 142 |
image: ImageType,
|