# echotracker/utils.py — deploy 1.0, commit 21494a2 (author: riponazad)
import torch
import cv2
import numpy as np
from skimage.color import gray2rgb
def points_to_tensor(points: list, qt: int, orig_H: int, orig_W: int, target: int = 256) -> torch.Tensor:
    """
    Pack query points into a [1, n, 3] float32 tensor of (qt, x, y) rows,
    rescaling x/y from the original frame resolution to `target`.

    Args:
        points : list of (x, y) tuples or np.array([x, y])
        qt : single int, same for all points
        orig_H : original frame height
        orig_W : original frame width
        target : target resolution (default 256)
    Returns:
        tensor of shape [1, n, 3], dtype float32
    """
    sx = target / orig_W
    sy = target / orig_H
    # Fill a preallocated (n, 3) buffer: column 0 is the shared qt value,
    # columns 1-2 are the rescaled coordinates.
    rows = np.empty((len(points), 3), dtype=np.float32)
    for idx, p in enumerate(points):
        rows[idx, 0] = qt
        rows[idx, 1] = p[0] * sx
        rows[idx, 2] = p[1] * sy
    return torch.tensor(rows)[None]  # add batch dim -> (1, n, 3)
def visualize_tracking(
    frames: np.ndarray,
    points: np.ndarray,
    tracking_quality: np.ndarray = None,
    vis_color='random',
    color_map: np.ndarray = None,
    gray: bool = False,
    alpha: float = 1.0,
    track_length: int = 0,
    thickness: int = 2,
) -> np.ndarray:
    """Render tracked points (dots plus optional fading trails) onto a video.

    Args:
        frames: video array; height/width are read from ``frames.shape[1:3]``,
            so a frame-first layout (T, H, W[, C]) is assumed — TODO confirm.
        points: per-point, per-frame coordinates indexed as
            ``points[i, t] -> (x, y)``. Coordinates are multiplied by
            width/height before drawing, so they are presumably normalized
            to [0, 1] — verify against callers.
        tracking_quality: optional ``(num_points, num_frames)`` int codes;
            0/1/2 map to red/yellow/green, anything else to white.
            When given, it overrides every other color source.
        vis_color: 'random', 'red', 'yellow', 'green', or an explicit
            RGB-like sequence used for all points. Ignored when
            ``tracking_quality`` or ``color_map`` is provided.
        color_map: optional per-point colors, indexed as ``color_map[i]``.
        gray: if True and frames lack a 3-channel last axis, replicate to RGB.
        alpha: weight of the drawn overlay in the final composite.
        track_length: number of past frames to draw as a trail (0 = dots only).
        thickness: line thickness for trail segments.

    Returns:
        np.ndarray: a copy of ``frames`` with the overlay composited in;
        the input array is not modified.

    Raises:
        ValueError: if ``vis_color`` is an unrecognized string.
    """
    num_points, num_frames = points.shape[:2]
    height, width = frames.shape[1:3]
    if gray and frames.shape[-1] != 3:
        # squeeze drops a trailing singleton channel (T, H, W, 1) -> (T, H, W)
        # before gray2rgb replicates to 3 channels.
        frames = gray2rgb(frames.squeeze())
    # Dot radius scales with frame size but never drops below 6 px.
    radius = max(6, int(0.006 * min(height, width)))
    quality_colors = {
        0: np.array([255, 0, 0]),
        1: np.array([255, 255, 0]),
        2: np.array([0, 255, 0]),
    }
    video = frames.copy()
    # One fixed random color per point, drawn once so the color is stable
    # across frames. Only needed when no other color source applies.
    if vis_color == 'random' and tracking_quality is None and color_map is None:
        rand_colors = np.random.randint(0, 256, size=(num_points, 3))
    for t in range(num_frames):
        # Draw onto a black overlay, then blend into the frame at the end.
        overlay = np.zeros_like(video[t], dtype=np.uint8)
        # Trail starts at most track_length frames back; the floor of 1
        # keeps the tt-1 index below from going negative.
        t_start = max(1, t - track_length)
        for i in range(num_points):
            # -------------------------------------------------
            # Resolve the point's color once per (frame, point).
            # Priority: tracking_quality > color_map > explicit
            # sequence > named/random vis_color.
            # -------------------------------------------------
            if tracking_quality is not None:
                color = quality_colors.get(
                    int(tracking_quality[i, t]),
                    np.array([255, 255, 255])
                )
            elif color_map is not None:
                color = np.asarray(color_map[i])
            elif isinstance(vis_color, (list, tuple, np.ndarray)):
                color = np.asarray(vis_color)
            else:
                if vis_color == 'random':
                    color = rand_colors[i]
                elif vis_color == 'red':
                    color = quality_colors[0]
                elif vis_color == 'yellow':
                    color = quality_colors[1]
                elif vis_color == 'green':
                    color = quality_colors[2]
                else:
                    raise ValueError(f"Unknown vis_color: {vis_color}")
            color = color.astype(np.uint8)
            # -------------------------------------------------
            # Draw trail segments, fading toward the oldest one.
            # With the default track_length=0 this range is empty.
            # -------------------------------------------------
            for tt in range(t_start, t):
                # fade grows from ~0 (oldest segment) to 1 (newest);
                # max(1, ...) guards against division by zero.
                fade = (tt - t_start + 1) / max(1, (t - t_start))
                x0n, y0n = points[i, tt - 1]
                x1n, y1n = points[i, tt]
                # Clip segment endpoints so the line stays on-canvas.
                x0 = int(np.clip(x0n * width, 0, width - 1))
                y0 = int(np.clip(y0n * height, 0, height - 1))
                x1 = int(np.clip(x1n * width, 0, width - 1))
                y1 = int(np.clip(y1n * height, 0, height - 1))
                faded_color = (color * fade).astype(np.uint8)
                cv2.line(
                    overlay,
                    (x0, y0),
                    (x1, y1),
                    faded_color.tolist(),
                    thickness=thickness,
                    lineType=cv2.LINE_AA
                )
            # -------------------------------------------------
            # Draw dot (current position). NOTE(review): unlike the
            # trail endpoints, the dot center is NOT clipped, so an
            # off-frame point simply draws off-canvas — confirm intended.
            # -------------------------------------------------
            xc = int(points[i, t, 0] * width)
            yc = int(points[i, t, 1] * height)
            cv2.circle(
                overlay,
                (xc, yc),
                radius=radius,
                color=color.tolist(),
                thickness=-1
            )
        # Additive blend: frame keeps full weight (1.0) and the overlay is
        # added scaled by alpha, so drawn pixels brighten rather than
        # alpha-blend; result saturates at the dtype maximum.
        video[t] = cv2.addWeighted(video[t], 1.0, overlay, alpha, 0)
    return video