Forrest Wargo commited on
Commit
416a2e8
·
1 Parent(s): 6600256

Accept OpenAI-style input; return annotated_image+boxes

Browse files
Files changed (1) hide show
  1. handler.py +75 -17
handler.py CHANGED
@@ -49,17 +49,48 @@ class EndpointHandler:
49
  self.annotator = BoxAnnotator()
50
 
51
  def __call__(self, data: Dict[str, Any]) -> Any:
52
- # data should contain the following:
53
- # "inputs": {
54
- # "image": url/base64,
55
- # (optional) "image_size": {"w": int, "h": int},
56
- # (optional) "bbox_threshold": float,
57
- # (optional) "iou_threshold": float,
58
- # }
59
- data = data.pop("inputs")
60
-
61
- # read image from either url or base64 encoding
62
- image = load_image(data["image"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  ocr_texts, ocr_bboxes = self.check_ocr_bboxes(
65
  image,
@@ -68,17 +99,44 @@ class EndpointHandler:
68
  )
69
  annotated_image, filtered_bboxes_out = self.get_som_labeled_img(
70
  image,
71
- image_size=data.get("image_size", None),
72
  ocr_texts=ocr_texts,
73
  ocr_bboxes=ocr_bboxes,
74
- bbox_threshold=data.get("bbox_threshold", 0.05),
75
- iou_threshold=data.get("iou_threshold", None),
76
  )
77
- return {
78
- "image": annotated_image,
79
- "bboxes": filtered_bboxes_out,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  }
81
 
 
 
82
  def check_ocr_bboxes(
83
  self,
84
  image: ImageType,
 
49
  self.annotator = BoxAnnotator()
50
 
51
  def __call__(self, data: Dict[str, Any]) -> Any:
52
+ # Flexible input contract:
53
+ # 1) OpenAI-style: { messages: [ { role:'user', content:[ {type:'image_url', image_url:{url:data-url}}, {type:'text', text:'...'} ] } ] }
54
+ # 2) Legacy HF inputs: { inputs: { image: <url|data-url> } }
55
+ # 3) PropParse-style: { inputs: { image_b64: <data-url>, ... } }
56
+
57
+ # Normalize payload
58
+ payload: Dict[str, Any]
59
+ if isinstance(data, dict) and "inputs" in data:
60
+ payload = data.get("inputs") or {}
61
+ else:
62
+ payload = data
63
+
64
+ # Extract image source
65
+ img_source: Optional[str] = None
66
+ if "image_b64" in payload and isinstance(payload["image_b64"], str):
67
+ img_source = payload["image_b64"]
68
+ elif "image" in payload and isinstance(payload["image"], str):
69
+ img_source = payload["image"]
70
+ elif isinstance(payload.get("messages"), list):
71
+ for msg in payload["messages"]:
72
+ if isinstance(msg, Dict) and msg.get("role") == "user":
73
+ for part in msg.get("content", []):
74
+ if part.get("type") == "image_url":
75
+ img_source = part.get("image_url", {}).get("url")
76
+ break
77
+ if img_source:
78
+ break
79
+
80
+ if not img_source:
81
+ return {"error": "No image provided (image/image_b64/messages)"}
82
+
83
+ # Load image from data URL or external URL
84
+ try:
85
+ if isinstance(img_source, str) and img_source.startswith("data:"):
86
+ header, b64data = img_source.split(",", 1)
87
+ decoded = base64.b64decode(b64data)
88
+ image = Image.open(io.BytesIO(decoded))
89
+ image.load()
90
+ else:
91
+ image = load_image(img_source)
92
+ except Exception as e:
93
+ return {"error": f"Failed to load image: {e}"}
94
 
95
  ocr_texts, ocr_bboxes = self.check_ocr_bboxes(
96
  image,
 
99
  )
100
  annotated_image, filtered_bboxes_out = self.get_som_labeled_img(
101
  image,
102
+ image_size=payload.get("image_size", None),
103
  ocr_texts=ocr_texts,
104
  ocr_bboxes=ocr_bboxes,
105
+ bbox_threshold=payload.get("bbox_threshold", 0.05),
106
+ iou_threshold=payload.get("iou_threshold", None),
107
  )
108
+
109
+ # Legacy fields
110
+ legacy = {"image": annotated_image, "bboxes": filtered_bboxes_out}
111
+
112
+ # PropParse-style fields
113
+ try:
114
+ w, h = image.size # type: ignore
115
+ except Exception:
116
+ w, h = None, None
117
+ annotated_data_url = (
118
+ f"data:image/png;base64,{annotated_image}"
119
+ if isinstance(annotated_image, str) and not annotated_image.startswith("data:")
120
+ else annotated_image
121
+ )
122
+ elements = [
123
+ {
124
+ "type": box.get("type", "icon"),
125
+ "bbox_xyxy_norm": box.get("bbox"),
126
+ "interactivity": box.get("interactivity", True),
127
+ "content": box.get("content"),
128
+ }
129
+ for box in filtered_bboxes_out
130
+ if isinstance(box.get("bbox"), list)
131
+ ]
132
+ propparse_style = {
133
+ "annotated_image": annotated_data_url,
134
+ "boxes": {"elements": elements},
135
+ **({"width": w, "height": h} if w and h else {}),
136
  }
137
 
138
+ return {**legacy, **propparse_style}
139
+
140
  def check_ocr_bboxes(
141
  self,
142
  image: ImageType,