# turbovision15 / keypoint_utils.py
# gloriforge's picture
# Upload folder using huggingface_hub
# 8477400 verified
from numba import njit, prange
@njit(fastmath=True, cache=True)
def score_mask_numba_fast(pred, expected, ground, pixels_on_lines):
    """Score a predicted mask against an expected mask, restricted to ``ground``.

    A pixel is *active* when both ``pred`` and ``ground`` are non-zero at
    that position. An active pixel is a *hit* when ``expected`` is also
    non-zero there, and a *miss* otherwise.

    Args:
        pred: 2-D array; its shape defines the region that is scanned.
        expected: Array indexed with the same ``[y, x]`` positions as ``pred``.
        ground: Array indexed with the same ``[y, x]`` positions as ``pred``.
        pixels_on_lines: Number used as the denominator of the score
            (presumably the count of non-zero pixels in ``expected`` --
            TODO confirm against the caller).

    Returns:
        ``0.0`` when there are no active pixels, when
        ``pixels_on_lines + misses`` equals zero, or when
        ``10 * misses > 9 * (pixels_on_lines + misses)``; otherwise
        ``hits / (pixels_on_lines + 1e-8)``.
    """
    rows, cols = pred.shape
    active = 0
    hits = 0

    # NOTE: the decorator does not pass parallel=True, so this prange loop
    # is executed like an ordinary range loop.
    for r in prange(rows):
        for c in range(cols):
            pred_on = pred[r, c] != 0
            ground_on = ground[r, c] != 0
            expected_on = expected[r, c] != 0
            if pred_on and ground_on:
                active += 1
                if expected_on:
                    hits += 1

    # Nothing predicted inside the ground region: no score.
    if active == 0:
        return 0.0

    misses = active - hits
    total = pixels_on_lines + active - hits
    # Reject degenerate totals and predictions dominated by misses.
    if total == 0 or misses * 10 > total * 9:
        return 0.0

    # The small epsilon keeps the division well defined.
    return hits / (pixels_on_lines + 1e-8)
@njit(parallel=True, cache=True)
def spmm_csc_templates_parallel(data, indices, indptr, frames_T, out):
    """Accumulate a sparse-by-dense product into ``out``, one column per template.

    ``data``, ``indices`` and ``indptr`` are read in a CSC-style layout:
    for column ``j`` the stored entries are ``indptr[j]`` up to (but not
    including) ``indptr[j + 1]``, with ``indices[k]`` selecting a row of
    ``frames_T`` and ``data[k]`` giving its weight.

    For every stored entry ``k`` of column ``j`` and every ``i`` in
    ``range(frames_T.shape[1])`` this performs
    ``out[i, j] += frames_T[indices[k], i] * data[k]``.

    ``out`` is modified in place and is never cleared here, so results are
    added on top of whatever it already contains. Nothing is returned.
    """
    n_frames = frames_T.shape[1]
    n_templates = indptr.shape[0] - 1

    # Each parallel iteration only writes to its own column of ``out``.
    for col in prange(n_templates):
        for nz in range(indptr[col], indptr[col + 1]):
            weight = data[nz]
            frame_row = frames_T[indices[nz]]
            for f in range(n_frames):
                out[f, col] += frame_row[f] * weight