from numba import njit, prange


@njit(fastmath=True, cache=True)
def score_mask_numba_fast(pred, expected, ground, pixels_on_lines):
    """Score a predicted mask by its overlap with an expected mask.

    A pixel is counted as *predicted* when both ``pred`` and ``ground`` are
    non-zero at that position; a predicted pixel is a *hit* when ``expected``
    is non-zero there as well, otherwise it is a *miss*.

    Args:
        pred: 2-D array; its shape defines the region that is scanned.
        expected: 2-D array covering at least ``pred``'s shape.
        ground: 2-D array covering at least ``pred``'s shape; predictions
            where ``ground`` is zero are ignored.
        pixels_on_lines: Number used as the score denominator (presumably
            the count of non-zero pixels in ``expected`` -- confirm with the
            caller).

    Returns:
        ``0.0`` when there are no predicted pixels, when
        ``pixels_on_lines + misses`` is zero, or when misses make up more
        than 90% of ``pixels_on_lines + misses``. Otherwise
        ``hits / (pixels_on_lines + 1e-8)``.

    Raises:
        ValueError: If ``expected`` or ``ground`` is smaller than ``pred``
            along either axis.
    """
    h, w = pred.shape

    # Numba performs no bounds checking by default, so a smaller companion
    # array would be read out of bounds silently. Fail loudly instead.
    if (expected.shape[0] < h or expected.shape[1] < w
            or ground.shape[0] < h or ground.shape[1] < w):
        raise ValueError("expected and ground must cover the shape of pred")

    predicted = 0
    hits = 0
    # Plain ``range``: this kernel is compiled without ``parallel=True``,
    # under which ``prange`` is only an alias of ``range`` anyway.
    for y in range(h):
        for x in range(w):
            # Branch-free accumulation: the booleans are added as 0/1.
            is_pred = (pred[y, x] != 0) & (ground[y, x] != 0)
            is_expected = expected[y, x] != 0
            predicted += is_pred
            hits += is_pred & is_expected

    if predicted == 0:
        return 0.0

    misses = predicted - hits
    total = pixels_on_lines + misses
    # ``misses * 10 > total * 9`` is ``misses / total > 0.9`` without a
    # division.
    if total == 0 or misses * 10 > total * 9:
        return 0.0

    # The epsilon keeps the division finite when ``pixels_on_lines`` is 0.
    return hits / (pixels_on_lines + 1e-8)


@njit(parallel=True, cache=True)
def spmm_csc_templates_parallel(data, indices, indptr, frames_T, out):
    """Accumulate a sparse (CSC-style) by dense product into ``out`` in place.

    For every column ``j`` and every stored entry ``k`` in
    ``indptr[j]:indptr[j + 1]`` this performs::

        out[:F, j] += frames_T[indices[k], :F] * data[k]

    with ``F = frames_T.shape[1]``. ``out`` is accumulated into, not
    overwritten, so the caller must initialise it (e.g. with zeros) when a
    plain product is wanted.

    Args:
        data: 1-D array of stored values.
        indices: 1-D array giving, for each stored value, the row of
            ``frames_T`` it multiplies. Entries are not range-checked here.
        indptr: 1-D array of column offsets; the number of columns is
            ``len(indptr) - 1``.
        frames_T: 2-D dense array whose rows are selected by ``indices``.
        out: 2-D array of at least ``(frames_T.shape[1], len(indptr) - 1)``
            elements; modified in place.

    Raises:
        ValueError: If ``out`` is smaller than
            ``(frames_T.shape[1], len(indptr) - 1)``.
    """
    F = frames_T.shape[1]
    T = indptr.shape[0] - 1

    # Numba performs no bounds checking by default, so an undersized ``out``
    # would be written out of bounds silently. Fail loudly instead.
    if out.shape[0] < F or out.shape[1] < T:
        raise ValueError("out is too small for the requested product")

    # Each parallel iteration writes only to column ``j`` of ``out``, so the
    # iterations never touch the same output element.
    for j in prange(T):
        for k in range(indptr[j], indptr[j + 1]):
            weight = data[k]
            row = frames_T[indices[k]]
            for i in range(F):
                out[i, j] += row[i] * weight