# IgorSlinko's picture
# Initial commit: Squircle corners prediction app with Gemini and YOLO
# e97f848
import math
from dataclasses import dataclass
# Square root of two. Dividing a radius by this gives the horizontal/vertical
# offset of the point that lies at 45 degrees on the circle (r * cos(45 deg)),
# which is how the crop-bounds computation below uses it.
SQRT2 = math.sqrt(2)
@dataclass
class Circle:
    """A circle given by its centre coordinates and its radius.

    All three values are declared as integers; the unit is presumably image
    pixels (TODO confirm against callers).
    """

    # Horizontal coordinate of the centre.
    x: int
    # Vertical coordinate of the centre.
    y: int
    # Radius of the circle.
    r: int
@dataclass
class Corners:
    """Four circles, one for each corner of a rectangular region.

    Each circle presumably models the rounded corner of a squircle-like shape
    in an image (TODO confirm); the class itself only stores the circles and
    derives geometry from their centres and radii.
    """

    top_left: Circle
    top_right: Circle
    bottom_left: Circle
    bottom_right: Circle

    @classmethod
    def from_opposite_corners(cls, top_left: Circle, bottom_right: Circle) -> "Corners":
        """Build all four corners from the top-left and bottom-right circles.

        Both circles end up with the larger of the two radii. The circle whose
        radius was smaller has its centre moved towards the other centre, along
        the line joining them, by the radius difference. Centres are then
        rounded to integers; the remaining two corners reuse the resulting x
        and y coordinates.

        Args:
            top_left: Circle for the top-left corner.
            bottom_right: Circle for the bottom-right corner.

        Returns:
            A ``Corners`` instance whose four circles share one radius.
        """

        def _pull_towards(px, py, qx, qy, amount):
            # Move (px, py) by ``amount`` along the direction to (qx, qy).
            # Coincident points have no direction, so they are left untouched.
            off_x = qx - px
            off_y = qy - py
            span = math.sqrt(off_x * off_x + off_y * off_y)
            if span > 0:
                return px + amount * off_x / span, py + amount * off_y / span
            return px, py

        radius = max(top_left.r, bottom_right.r)
        tl_x, tl_y = float(top_left.x), float(top_left.y)
        br_x, br_y = float(bottom_right.x), float(bottom_right.y)

        # At most one of these fires: at least one circle already has `radius`.
        if top_left.r < radius:
            tl_x, tl_y = _pull_towards(tl_x, tl_y, br_x, br_y, radius - top_left.r)
        if bottom_right.r < radius:
            br_x, br_y = _pull_towards(br_x, br_y, tl_x, tl_y, radius - bottom_right.r)

        left, top, right, bottom = (
            int(round(value)) for value in (tl_x, tl_y, br_x, br_y)
        )

        return cls(
            top_left=Circle(left, top, radius),
            top_right=Circle(right, top, radius),
            bottom_left=Circle(left, bottom, radius),
            bottom_right=Circle(right, bottom, radius),
        )

    @classmethod
    def from_boxes(cls, boxes: list[tuple[int, int, int, int]]) -> "Corners | None":
        """Build corners from ``(x1, y1, x2, y2)`` boxes.

        Only the first four boxes are considered. They are ordered by the y of
        their centre; the two with the smallest y form the top pair and the
        other two the bottom pair, and each pair is ordered left-to-right by
        centre x. Every box becomes a circle centred on the box vertex that
        faces the interior (e.g. ``(x2, y2)`` for the top-left box), with
        radius ``(width + height) // 2``.

        Args:
            boxes: Boxes as ``(x1, y1, x2, y2)`` tuples.

        Returns:
            The resulting ``Corners``, or ``None`` when fewer than four boxes
            are supplied.
        """
        if len(boxes) < 4:
            return None

        # Each entry: (centre x, centre y, original box coordinates).
        entries = [
            ((x1 + x2) / 2, (y1 + y2) / 2, (x1, y1, x2, y2))
            for x1, y1, x2, y2 in boxes[:4]
        ]
        entries.sort(key=lambda entry: entry[1])
        upper = sorted(entries[:2], key=lambda entry: entry[0])
        lower = sorted(entries[2:], key=lambda entry: entry[0])

        def _inner_vertex_circle(entry: tuple, take_right: bool, take_bottom: bool) -> Circle:
            # Choose the box vertex on the requested sides as the circle centre.
            x1, y1, x2, y2 = entry[2]
            radius = (x2 - x1 + y2 - y1) // 2
            centre_x = x2 if take_right else x1
            centre_y = y2 if take_bottom else y1
            return Circle(centre_x, centre_y, radius)

        return cls(
            top_left=_inner_vertex_circle(upper[0], True, True),
            top_right=_inner_vertex_circle(upper[1], False, True),
            bottom_left=_inner_vertex_circle(lower[0], True, False),
            bottom_right=_inner_vertex_circle(lower[1], False, False),
        )

    def to_dict(self) -> dict:
        """Return the corners as a nested dict of ``x``/``y``/``r`` values.

        The outer keys are ``top_left``, ``top_right``, ``bottom_left`` and
        ``bottom_right``, in that order.
        """
        labelled = (
            ("top_left", self.top_left),
            ("top_right", self.top_right),
            ("bottom_left", self.bottom_left),
            ("bottom_right", self.bottom_right),
        )
        return {
            label: {"x": circle.x, "y": circle.y, "r": circle.r}
            for label, circle in labelled
        }

    def get_crop_bounds(self) -> tuple[int, int, int, int]:
        """Return ``(x_left, y_top, x_right, y_bottom)`` crop bounds.

        Each circle contributes its centre offset outwards by
        ``int(r / SQRT2)`` on both axes. For every side the more inward of the
        two candidate values is kept: the maximum for left/top and the minimum
        for right/bottom.
        """
        tl, tr = self.top_left, self.top_right
        bl, br = self.bottom_left, self.bottom_right

        # Truncated 45-degree offset for each corner circle.
        tl_off, tr_off, bl_off, br_off = (
            int(circle.r / SQRT2) for circle in (tl, tr, bl, br)
        )

        x_left = max(tl.x - tl_off, bl.x - bl_off)
        x_right = min(tr.x + tr_off, br.x + br_off)
        y_top = max(tl.y - tl_off, tr.y - tr_off)
        y_bottom = min(bl.y + bl_off, br.y + br_off)

        return x_left, y_top, x_right, y_bottom