"""Output parsing for LocateAnything-3B bounding box responses."""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any
from src.config import COORD_MAX, DEFAULT_CONFIDENCE
@dataclass
class BBox:
    """A parsed bounding box in pixel coordinates.

    ``(x1, y1)`` and ``(x2, y2)`` are opposite corners. ``width`` and
    ``height`` assume ``x1 <= x2`` and ``y1 <= y2`` and report 0.0 for an
    inverted box.
    """

    x1: float
    y1: float
    x2: float
    y2: float
    confidence: float = DEFAULT_CONFIDENCE
    label: str = ""

    @property
    def width(self) -> float:
        """Horizontal extent, floored at 0.0 for inverted boxes."""
        return max(0.0, self.x2 - self.x1)

    @property
    def height(self) -> float:
        """Vertical extent, floored at 0.0 for inverted boxes."""
        return max(0.0, self.y2 - self.y1)

    @property
    def area(self) -> float:
        """Product of ``width`` and ``height`` (0.0 for inverted boxes)."""
        return self.width * self.height

    @property
    def center(self) -> tuple[float, float]:
        """Midpoint of the two corners as ``(x, y)``."""
        return ((self.x1 + self.x2) / 2, (self.y1 + self.y2) / 2)

    def is_valid(self, img_w: int, img_h: int) -> bool:
        """Return True if the box fits the image and is larger than 1 px.

        The near edges (``x1``, ``y1``) must be non-negative. The far edges
        (``x2``, ``y2``) may exceed the image size by at most 1 px of slack.
        Width and height must both be strictly greater than 1 px, so
        degenerate and inverted boxes are rejected.
        """
        return (
            self.x1 >= 0
            and self.y1 >= 0
            and self.x2 <= img_w + 1
            and self.y2 <= img_h + 1
            and self.width > 1
            and self.height > 1
        )

    def clamp(self, img_w: int, img_h: int) -> BBox:
        """Return a copy with x clamped to [0, img_w] and y to [0, img_h].

        All coordinates in the copy are floats. Corner order is not fixed
        up, so an inverted box stays inverted. ``confidence`` and ``label``
        are carried over unchanged.
        """

        def _fit(value: float, limit: int) -> float:
            # Without float(), a clamped value would come back as the int
            # bound (0 or limit), contradicting the float fields and leaking
            # ints into to_dict() output.
            return float(max(0, min(value, limit)))

        return BBox(
            x1=_fit(self.x1, img_w),
            y1=_fit(self.y1, img_h),
            x2=_fit(self.x2, img_w),
            y2=_fit(self.y2, img_h),
            confidence=self.confidence,
            label=self.label,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; geometry is rounded to 2 decimals."""
        return {
            "x1": round(self.x1, 2),
            "y1": round(self.y1, 2),
            "x2": round(self.x2, 2),
            "y2": round(self.y2, 2),
            "width": round(self.width, 2),
            "height": round(self.height, 2),
            "confidence": self.confidence,
            "label": self.label,
        }
@dataclass
class ParseResult:
    """Container for what was extracted from a single model response.

    Holds the kept boxes, the model text as received, and any diagnostic
    messages recorded while parsing (``parse_boxes`` in this module builds
    one of these).
    """

    # Boxes kept by the parser, in the order they were appended.
    boxes: list[BBox] = field(default_factory=list)
    # The model text exactly as it was handed to the parser.
    raw_output: str = ""
    # Human-readable notes about matches that were skipped or rejected.
    parse_errors: list[str] = field(default_factory=list)

    @property
    def num_detections(self) -> int:
        """How many boxes are held in ``boxes``."""
        return len(self.boxes)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to plain Python types, delegating each box to its ``to_dict``.

        The returned ``parse_errors`` entry is the same list object as the
        attribute, not a copy.
        """
        payload: dict[str, Any] = {"num_detections": self.num_detections}
        payload["boxes"] = [item.to_dict() for item in self.boxes]
        payload["raw_output"] = self.raw_output
        payload["parse_errors"] = self.parse_errors
        return payload
# Tagged box form: four angle-bracketed integers, e.g. "<12><34><560><780>".
BOX_PATTERN_4 = re.compile(r"<(\d+)>" * 4)
# Fallback box form: four comma-separated integers with optional whitespace,
# e.g. "12, 34, 560, 780". This also matches any other run of four such numbers.
BOX_PATTERN_4_ALT = re.compile(r"\s*" + r"\s*,\s*".join([r"(\d+)"] * 4) + r"\s*")
# Tagged point form: two angle-bracketed integers, e.g. "<12><34>".
BOX_PATTERN_2 = re.compile(r"<(\d+)>" * 2)
def _norm_to_pixel(val: int, scale: int) -> float:
    """Map a coordinate on the model's ``[0, COORD_MAX]`` grid to pixels.

    ``scale`` is the image extent (width or height) along the same axis.
    """
    fraction = val / COORD_MAX
    return scale * fraction
def _collect_pattern_boxes(
    pattern: re.Pattern[str],
    raw_output: str,
    image_width: int,
    image_height: int,
    result: ParseResult,
    seen: set[tuple[float, float, float, float]],
    *,
    label: str,
    report_discarded: bool,
) -> None:
    """Append every new, valid box matched by ``pattern`` to ``result.boxes``.

    Args:
        pattern: Compiled regex with four integer capture groups (x1, y1, x2, y2)
            on the normalized grid.
        raw_output: Model text to scan.
        image_width: Image width in pixels, used to scale x coordinates.
        image_height: Image height in pixels, used to scale y coordinates.
        result: Receives kept boxes and diagnostic messages (mutated in place).
        seen: Pixel keys (rounded to 0.1 px) already handled. Updated in place
            so duplicates are skipped, including across separate calls.
        label: Noun used in parse-failure messages, e.g. ``"box"``.
        report_discarded: Whether to record a message for boxes rejected by
            ``BBox.is_valid``.
    """
    for match in pattern.finditer(raw_output):
        try:
            x1 = _norm_to_pixel(int(match.group(1)), image_width)
            y1 = _norm_to_pixel(int(match.group(2)), image_height)
            x2 = _norm_to_pixel(int(match.group(3)), image_width)
            y2 = _norm_to_pixel(int(match.group(4)), image_height)
        except (ValueError, IndexError, OverflowError) as exc:
            # ValueError: digit runs beyond Python's int-conversion limit (3.11+).
            # OverflowError: integers too large to scale as a float. Model output
            # is untrusted, so neither may crash the parser.
            result.parse_errors.append(f"Failed to parse {label}: {exc}")
            continue

        key = (round(x1, 1), round(y1, 1), round(x2, 1), round(y2, 1))
        if key in seen:
            continue
        # Mark as seen before validation, so a rejected box is not re-tried
        # (or re-reported) when it appears again.
        seen.add(key)

        box = BBox(x1=x1, y1=y1, x2=x2, y2=y2)
        if box.is_valid(image_width, image_height):
            result.boxes.append(box)
        elif report_discarded:
            result.parse_errors.append(f"Out-of-bounds box discarded: {key}")


def parse_boxes(
    raw_output: str,
    image_width: int,
    image_height: int,
) -> ParseResult:
    """Parse model output into structured bounding boxes.

    The model outputs coordinates normalized to ``[0, COORD_MAX]``; this
    function converts them to pixel coordinates.

    The tagged ``<x1><y1><x2><y2>`` form is scanned first. The looser
    ``x1, y1, x2, y2`` comma form is tried only if that pass kept no box.
    Boxes are de-duplicated on their pixel coordinates rounded to 0.1 px,
    across both passes.

    Args:
        raw_output: Raw text produced by the model.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        A ``ParseResult`` holding the kept boxes, the raw text, and messages for
        matches that failed to convert. Boxes rejected by ``BBox.is_valid`` are
        also reported, but only for the tagged form.
    """
    result = ParseResult(raw_output=raw_output)
    seen: set[tuple[float, float, float, float]] = set()

    _collect_pattern_boxes(
        BOX_PATTERN_4,
        raw_output,
        image_width,
        image_height,
        result,
        seen,
        label="box",
        report_discarded=True,
    )
    if not result.boxes:
        # The comma pattern matches any four comma-separated integers, so it is
        # only a fallback; its rejected boxes are not recorded as errors.
        _collect_pattern_boxes(
            BOX_PATTERN_4_ALT,
            raw_output,
            image_width,
            image_height,
            result,
            seen,
            label="alt box",
            report_discarded=False,
        )
    return result
def parse_points(
    raw_output: str,
    image_width: int,
    image_height: int,
) -> list[dict[str, float]]:
    """Parse model output into pixel-coordinate points.

    Each ``<x><y>`` pair on the normalized grid is scaled to pixels and rounded
    to 2 decimals.

    NOTE: pairs are matched sequentially and without overlap, so a four-tag box
    string ``<x1><y1><x2><y2>`` yields two points.

    Args:
        raw_output: Raw text produced by the model.
        image_width: Image width in pixels, used to scale x.
        image_height: Image height in pixels, used to scale y.

    Returns:
        A list of ``{"x": ..., "y": ...}`` dicts in match order. Pairs whose
        numbers cannot be converted are skipped.
    """
    points: list[dict[str, float]] = []
    for match in BOX_PATTERN_2.finditer(raw_output):
        try:
            x = _norm_to_pixel(int(match.group(1)), image_width)
            y = _norm_to_pixel(int(match.group(2)), image_height)
        except (ValueError, IndexError, OverflowError):
            # Best-effort: there is no error channel in the return type, so an
            # unconvertible pair (over-long or float-overflowing digit run) is
            # skipped instead of crashing on untrusted model output.
            continue
        points.append({"x": round(x, 2), "y": round(y, 2)})
    return points