File size: 4,530 Bytes
3cf4fff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import re
from typing import Any, Callable, Dict, List, Tuple
def get_region_transform(
    region_type: str = "bbox",
    region_format: str = "xyxy",
    coordinate_format: str = "000",
    coord_decimal: int = 3,
) -> "RegionTransform":
    """Build a ``RegionTransform`` configured with the given options.

    Args:
        region_type: Region kind, ``"bbox"`` or ``"mask"``.
        region_format: Output box layout, ``"xyxy"``, ``"xywh"`` or
            ``"polygon"``.
        coordinate_format: Coordinate rendering, ``"000"`` (zero-padded
            integers) or ``"standard"`` (fixed-width decimals).
        coord_decimal: Number of decimal digits per coordinate.

    Returns:
        A ``RegionTransform`` instance. It is callable as
        ``transform(convs, regions, img_w, img_h)``. (The previous annotation
        ``Tuple[Callable, int]`` did not match what is actually returned.)
    """
    return RegionTransform(
        region_type=region_type,
        region_format=region_format,
        coordinate_format=coordinate_format,
        coord_decimal=coord_decimal,
    )
class RegionTransform(object):
    """Normalize region annotations and splice them into conversation text.

    Calling an instance as ``transform(convs, regions, img_w, img_h)`` does
    two things:

    1. It converts each pixel-space ``[x, y, w, h]`` box in ``regions`` into
       a normalized coordinate string such as ``"[250,250,500,500]"``.
    2. It replaces every ``<|bboxN|>`` (or ``<|maskN|>``) placeholder in each
       turn's ``"value"`` text with the N-th formatted region.

    Args:
        region_type: ``"bbox"`` or ``"mask"``. Only ``"bbox"`` regions can
            currently be converted. ``"mask"`` is accepted here, but
            converting any mask region raises ``ValueError``.
        region_format: ``"xyxy"``, ``"xywh"`` or ``"polygon"``. ``"polygon"``
            is accepted here, but converting a region with it raises
            ``ValueError``.
        coordinate_format: ``"000"`` (zero-padded integers) or
            ``"standard"`` (fixed-width decimals).
        coord_decimal: Number of decimal digits of precision per coordinate.

    Raises:
        ValueError: If ``region_type``, ``region_format`` or
            ``coordinate_format`` is not one of the supported values.
    """

    # Placeholder pattern per region type. The captured digits are the index
    # into the list of formatted regions.
    _REGION_PATTERNS = {
        "bbox": re.compile(r"<\|bbox(\d+)\|>"),
        "mask": re.compile(r"<\|mask(\d+)\|>"),
    }
    _REGION_TYPES = ("bbox", "mask")
    _REGION_FORMATS = ("xyxy", "xywh", "polygon")
    _COORDINATE_FORMATS = ("000", "standard")

    def __init__(
        self,
        region_type: str = "bbox",
        region_format: str = "xyxy",
        coordinate_format: str = "000",
        coord_decimal: int = 3,
    ):
        # Use explicit checks rather than `assert`. Asserts are stripped
        # under `python -O`, and ValueError matches the error type the rest
        # of this class raises for unknown options.
        if region_type not in self._REGION_TYPES:
            raise ValueError(f"Unknown region type: {region_type}")
        if region_format not in self._REGION_FORMATS:
            raise ValueError(f"Unknown region format: {region_format}")
        if coordinate_format not in self._COORDINATE_FORMATS:
            raise ValueError(f"Unknown coordinate format: {coordinate_format}")
        self.region_type = region_type
        self.region_format = region_format
        self.coordinate_format = coordinate_format
        self.coord_decimal = coord_decimal

    def clamp(self, x: float, min_x: float, max_x: float) -> float:
        """Return ``x`` limited to the closed interval ``[min_x, max_x]``."""
        return max(min(x, max_x), min_x)

    def format_bounding_box(
        self,
        box: List[float],
        box_format: str = "000",
        coord_decimal: int = 3,
    ) -> str:
        """Render normalized coordinates as a bracketed, comma-joined string.

        Args:
            box: Normalized coordinates. Each value is clamped to
                ``[0.0, 0.999]`` before formatting.
            box_format: ``"standard"`` renders each value with
                ``coord_decimal`` decimals (``0.11 -> "0.110"``). ``"000"``
                renders ``int(value * 10**coord_decimal)`` zero-padded to
                ``coord_decimal`` digits (``0.11 -> "110"``).
            coord_decimal: Number of decimal digits per coordinate.

        Returns:
            A string such as ``"[0.110,0.500,0.999,0.000]"`` or
            ``"[110,500,999,000]"``.

        Raises:
            ValueError: If ``box_format`` is not ``"standard"`` or ``"000"``.
                Previously the clamped list was returned silently.
        """
        # The upper bound keeps 3-decimal output from ever reaching
        # "1.000" / "1000".
        # NOTE(review): the bound is fixed at 0.999 and does not track
        # `coord_decimal`; confirm this is intended for other precisions.
        clamped = [self.clamp(b, 0.0, 0.999) for b in box]
        if box_format == "standard":
            # NOTE: always make each coordinate coord_decimal + 2 characters
            # (0.11 -> 0.110) by right-padding with zeros.
            width = coord_decimal + 2
            coords = [
                (f"%.{coord_decimal}f" % b).ljust(width, "0") for b in clamped
            ]
        elif box_format == "000":
            scale = 10**coord_decimal
            # int() truncates toward zero, so this floor-quantizes
            # non-negative values.
            coords = [str(int(b * scale)).zfill(coord_decimal) for b in clamped]
        else:
            raise ValueError(f"Unknown box format: {box_format}")
        return "[" + ",".join(coords) + "]"

    def _transform_regions(
        self, regions: List[Any], img_w: float, img_h: float
    ) -> List[str]:
        """Normalize pixel-space ``[x, y, w, h]`` boxes and format each one.

        Args:
            regions: Boxes as ``[x, y, w, h]`` in pixel units.
            img_w: Image width used to normalize x coordinates and widths.
            img_h: Image height used to normalize y coordinates and heights.

        Returns:
            One formatted coordinate string per input region, in order.

        Raises:
            ValueError: If a region is present and ``region_type`` is not
                ``"bbox"``, or ``region_format`` is neither ``"xyxy"`` nor
                ``"xywh"``.
        """
        regions_out = []
        for region in regions:
            # Checked per region rather than up front, so an empty `regions`
            # list is still accepted for every region type.
            if self.region_type != "bbox":
                raise ValueError(f"Unknown region type: {self.region_type}")
            # region is in [x, y, w, h] format
            x, y, w, h = region
            width, height = float(img_w), float(img_h)
            if self.region_format == "xyxy":
                normalized = [
                    x / width,
                    y / height,
                    (x + w) / width,
                    (y + h) / height,
                ]
            elif self.region_format == "xywh":
                normalized = [x / width, y / height, w / width, h / height]
            else:
                raise ValueError(f"Unknown region format: {self.region_format}")
            regions_out.append(
                self.format_bounding_box(
                    normalized, self.coordinate_format, self.coord_decimal
                )
            )
        return regions_out

    def _transform_conv(self, conv: str, regions: List[str]) -> str:
        """Replace each region placeholder in ``conv`` with its formatted region.

        A placeholder ``<|bboxN|>`` (or ``<|maskN|>`` for mask regions) is
        replaced with ``regions[N]``. Text without placeholders is returned
        unchanged.

        Raises:
            ValueError: If ``region_type`` has no placeholder pattern.
            IndexError: If a placeholder index is out of range for ``regions``.
        """
        pattern = self._REGION_PATTERNS.get(self.region_type)
        if pattern is None:
            raise ValueError(f"Unknown region type: {self.region_type}")
        # Take the index from the capture group rather than slicing with a
        # hard-coded prefix length.
        return pattern.sub(lambda m: regions[int(m.group(1))], conv)

    def __call__(
        self,
        convs: List[Dict[str, Any]],
        regions: List[Any],
        img_w: float,
        img_h: float,
    ) -> List[Dict[str, Any]]:
        """Format ``regions`` and substitute them into every turn of ``convs``.

        Each dict in ``convs`` must have a ``"value"`` string. It is updated
        IN PLACE, and the same ``convs`` list object is returned.
        """
        # 1. transform regions
        formatted = self._transform_regions(regions, img_w, img_h)
        # 2. add regions to convs
        for conv in convs:
            conv["value"] = self._transform_conv(conv["value"], formatted)
        return convs
|