|
|
import re |
|
|
from typing import Any, Callable, Dict, List, Tuple |
|
|
|
|
|
|
|
|
def get_region_transform(
    region_type: str = "bbox",
    region_format: str = "xyxy",
    coordinate_format: str = "000",
    coord_decimal: int = 3,
) -> "RegionTransform":
    """Build a ``RegionTransform`` configured with the given options.

    All arguments are forwarded unchanged to ``RegionTransform.__init__``,
    which validates them.

    Args:
        region_type: Kind of region placeholder to handle ("bbox" or "mask").
        region_format: Coordinate layout of the emitted boxes
            ("xyxy", "xywh" or "polygon").
        coordinate_format: Textual encoding of each coordinate
            ("000" or "standard").
        coord_decimal: Number of digits used for each coordinate.

    Returns:
        The configured ``RegionTransform`` instance (a callable object).
    """
    # The previous annotation claimed ``Tuple[Callable, int]``, but a single
    # transform object is (and always was) returned.
    return RegionTransform(
        region_type=region_type,
        region_format=region_format,
        coordinate_format=coordinate_format,
        coord_decimal=coord_decimal,
    )
|
|
|
|
|
|
|
|
class RegionTransform:
    """Substitute region placeholders in conversation text with coordinates.

    Calling an instance normalises each region by the image size, renders
    it as a bracketed coordinate string, and replaces every
    ``<|bboxN|>`` / ``<|maskN|>`` placeholder in the conversations with the
    string of region ``N``.
    """

    # Placeholder pattern per region type; group 1 captures the region index.
    _PLACEHOLDER_PATTERNS = {
        "bbox": re.compile(r"<\|bbox(\d+)\|>"),
        "mask": re.compile(r"<\|mask(\d+)\|>"),
    }

    def __init__(
        self,
        region_type: str = "bbox",
        region_format: str = "xyxy",
        coordinate_format: str = "000",
        coord_decimal: int = 3,
    ):
        """Store and validate the transform configuration.

        Args:
            region_type: "bbox" or "mask".
            region_format: "xyxy", "xywh" or "polygon".
            coordinate_format: "000" (zero-padded integers) or "standard"
                (fixed-point decimals).
            coord_decimal: Number of digits used for each coordinate.

        Raises:
            AssertionError: If an option is not one of the accepted values.
        """
        # NOTE: asserts are kept so callers still see AssertionError; be aware
        # they are skipped when Python runs with -O.
        assert region_type in ["bbox", "mask"], region_type
        assert region_format in ["xyxy", "xywh", "polygon"], region_format
        assert coordinate_format in ["000", "standard"], coordinate_format

        self.region_type = region_type
        self.region_format = region_format
        self.coordinate_format = coordinate_format
        self.coord_decimal = coord_decimal

    def clamp(self, x: float, min_x: float, max_x: float) -> float:
        """Return ``x`` limited to the closed interval ``[min_x, max_x]``."""
        return max(min(x, max_x), min_x)

    def format_bounding_box(
        self,
        box: List[float],
        box_format: str = "000",
        coord_decimal: int = 3,
    ) -> str:
        """Render normalised coordinates as a bracketed, comma-separated string.

        Every value is first clamped to ``[0.0, 0.999]``.

        Args:
            box: Coordinates to render (expected to be normalised fractions).
            box_format: "standard" renders each value with ``coord_decimal``
                decimals (e.g. ``0.250``); "000" renders
                ``int(value * 10**coord_decimal)`` zero-padded to
                ``coord_decimal`` digits (e.g. ``250``).
            coord_decimal: Number of digits per coordinate.

        Returns:
            A string such as ``"[250,250,750,750]"``.

        Raises:
            ValueError: If ``box_format`` is not a supported format.
        """
        clamped = [self.clamp(b, 0.0, 0.999) for b in box]

        if box_format == "standard":
            # Right-pad with zeros up to "0." plus coord_decimal digits.
            width = coord_decimal + 2
            coords = [
                (f"%.{coord_decimal}f" % b).ljust(width, "0") for b in clamped
            ]
        elif box_format == "000":
            scale = 10**coord_decimal
            # int() truncates toward zero; the result is left-padded with zeros.
            coords = [str(int(b * scale)).zfill(coord_decimal) for b in clamped]
        else:
            # Previously the clamped list was returned silently, which broke
            # the declared ``str`` return type further downstream.
            raise ValueError(f"Unknown box format: {box_format}")

        return "[" + ",".join(coords) + "]"

    def _transform_regions(
        self, regions: List[Any], img_w: float, img_h: float
    ) -> List[str]:
        """Normalise and format every region.

        Each region is unpacked as ``x, y, w, h``. For the "xyxy" format the
        output corners are ``(x, y)`` and ``(x + w, y + h)``; for "xywh" the
        four values are kept as-is. All values are divided by the image
        width/height before formatting.

        Args:
            regions: Iterable of 4-element regions.
            img_w: Image width used to normalise horizontal values.
            img_h: Image height used to normalise vertical values.

        Returns:
            One formatted string per input region, in order.

        Raises:
            ValueError: If a region is processed while ``region_type`` is not
                "bbox", or ``region_format`` is neither "xyxy" nor "xywh".
        """
        regions_out = []
        for region in regions:
            # Checked per region so that an empty region list never raises.
            if self.region_type != "bbox":
                raise ValueError(f"Unknown region type: {self.region_type}")

            x, y, w, h = region
            if self.region_format == "xyxy":
                width, height = float(img_w), float(img_h)
                normalized = [
                    x / width,
                    y / height,
                    (x + w) / width,
                    (y + h) / height,
                ]
            elif self.region_format == "xywh":
                normalized = [x / img_w, y / img_h, w / img_w, h / img_h]
            else:
                raise ValueError(f"Unknown region format: {self.region_format}")

            regions_out.append(
                self.format_bounding_box(
                    normalized, self.coordinate_format, self.coord_decimal
                )
            )
        return regions_out

    def _transform_conv(self, conv: str, regions: List[str]) -> str:
        """Replace each region placeholder in ``conv`` with its region string.

        Args:
            conv: Text possibly containing ``<|bboxN|>`` or ``<|maskN|>``
                placeholders (which one depends on ``self.region_type``).
            regions: Formatted region strings, indexed by ``N``.

        Returns:
            ``conv`` with every placeholder substituted.

        Raises:
            ValueError: If ``self.region_type`` has no placeholder pattern.
            IndexError: If a placeholder refers to a region that is missing.
        """
        pattern = self._PLACEHOLDER_PATTERNS.get(self.region_type)
        if pattern is None:
            raise ValueError(f"Unknown region type: {self.region_type}")

        # The index comes from the capture group rather than from slicing by
        # the length of a hard-coded "<|bbox" prefix, so it is correct for
        # every placeholder name.
        return pattern.sub(lambda match: regions[int(match.group(1))], conv)

    def __call__(
        self,
        convs: List[Dict[str, Any]],
        regions: List[Any],
        img_w: float,
        img_h: float,
    ) -> List[Dict[str, Any]]:
        """Substitute region placeholders in every conversation turn.

        Args:
            convs: Conversation turns; each dict's ``"value"`` entry is
                rewritten in place.
            regions: Regions referenced by the placeholders.
            img_w: Image width used for normalisation.
            img_h: Image height used for normalisation.

        Returns:
            The same ``convs`` list, with updated ``"value"`` entries.
        """
        formatted = self._transform_regions(regions, img_w, img_h)
        for conv in convs:
            conv["value"] = self._transform_conv(conv["value"], formatted)
        return convs
|
|
|