File size: 3,609 Bytes
210e540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
from tqdm import tqdm
from typing import List, Tuple, Sequence, Any

# Template keypoints in template-canvas pixel coordinates, one (x, y) per
# index 0..31 (the trailing "# N" comments number them 1-based).
# _generate_sparse_template_keypoints treats the canvas as 1045 x 675 and
# scales these entries to the actual frame size.
# (0, 0) appears to be the "unknown position" sentinel: predict_failed_indices
# does not count zero coordinates as valid keypoints. Only entries 15, 16, 31
# and 32 carry real coordinates here.
# NOTE(review): the meaning of each index is not defined in this file —
# confirm against the keypoint specification before editing values.
FOOTBALL_KEYPOINTS: list[tuple[int, int]] = [
    (0, 0),  # 1
    (0, 0),  # 2
    (0, 0),  # 3
    (0, 0),  # 4
    (0, 0),  # 5
    (0, 0),  # 6
    
    (0, 0),  # 7
    (0, 0),  # 8
    (0, 0),  # 9

    (0, 0),  # 10
    (0, 0),  # 11
    (0, 0),  # 12
    (0, 0),  # 13

    (0, 0),  # 14
    (527, 283),  # 15
    (527, 403),  # 16
    (0, 0),  # 17

    (0, 0),  # 18
    (0, 0),  # 19
    (0, 0),  # 20
    (0, 0),  # 21

    (0, 0),  # 22

    (0, 0),  # 23
    (0, 0),  # 24

    (0, 0),  # 25
    (0, 0),  # 26
    (0, 0),  # 27
    (0, 0),  # 28
    (0, 0),  # 29
    (0, 0),  # 30

    (405, 340),  # 31
    (645, 340),  # 32
]

def convert_keypoints_to_val_format(keypoints):
    """Return a new list with every keypoint converted to a tuple of ints.

    Each element of ``keypoints`` is iterated and each coordinate is passed
    through ``int()``, so floats are truncated toward zero (1.9 -> 1,
    -2.7 -> -2). The input is not modified.
    """
    return [tuple(map(int, point)) for point in keypoints]

def predict_failed_indices(
    results_frames: Sequence[Any],
    *,
    max_valid_for_failure: int = 4,
) -> list[int]:
    """Return the indices of frames whose keypoint detection looks failed.

    A keypoint counts as valid only when both of its coordinates are
    non-zero after ``int()`` truncation. A frame is flagged as failed when
    it has ``max_valid_for_failure`` (default 4) or fewer valid keypoints.

    Args:
        results_frames: Per-frame results. Each may expose a ``keypoints``
            attribute holding an iterable of ``(x, y)`` pairs (list, tuple,
            or a NumPy ``(N, 2)`` array). A missing or ``None`` attribute
            is treated as "no keypoints".
        max_valid_for_failure: Frames with at most this many valid
            keypoints are reported as failed.

    Returns:
        Ascending list of failed frame indices (empty for empty input).
    """
    failed_indices: list[int] = []
    for frame_index, frame_result in enumerate(results_frames):
        frame_keypoints = getattr(frame_result, "keypoints", None)
        # Explicit None check instead of ``or []``: truth-testing a NumPy
        # array with more than one element raises ValueError.
        if frame_keypoints is None:
            frame_keypoints = []
        # ``and`` means a point lying on either axis (x == 0 or y == 0) is
        # treated as missing too, not only the (0, 0) sentinel.
        valid_count = sum(
            1 for x, y in frame_keypoints if int(x) != 0 and int(y) != 0
        )
        if valid_count <= max_valid_for_failure:
            failed_indices.append(frame_index)
    return failed_indices

def _generate_sparse_template_keypoints(frame_width: int, frame_height: int) -> list[tuple[int, int]]:
    template_max_x, template_max_y = (1045, 675)
    sx = float(frame_width) / float(template_max_x if template_max_x != 0 else 1)
    sy = float(frame_height) / float(template_max_y if template_max_y != 0 else 1)
    scaled: list[tuple[int, int]] = []
    for i in range(32):
        tx, ty = FOOTBALL_KEYPOINTS[i]
        x_scaled = int(round(tx * sx))
        y_scaled = int(round(ty * sy))
        scaled.append((x_scaled, y_scaled))
    return scaled

def fix_keypoints(
    results_frames: Sequence[Any],
    failed_indices: Sequence[int],
    frame_width: int,
    frame_height: int,
) -> list[Any]:
    """Overwrite every frame's keypoints, repairing failed frames, and return the frames.

    Every frame's ``keypoints`` attribute is replaced by a list of
    int-cast ``(x, y)`` tuples (via ``convert_keypoints_to_val_format``):

    * Successful frames keep their own keypoints, converted.
    * A failed frame takes the keypoints of the most recent preceding
      successful frame. Failed frames before the first success take the
      first successful frame's keypoints.
    * If no frame succeeded, every frame gets the sparse template from
      ``_generate_sparse_template_keypoints``, scaled to
      ``frame_width`` x ``frame_height``.

    Args:
        results_frames: Per-frame result objects that allow assigning a
            ``keypoints`` attribute. A missing or ``None`` ``keypoints`` is
            read as an empty list.
        failed_indices: Int-convertible indices of frames to repair.
            Indices outside the sequence are ignored.
        frame_width: Frame width in pixels (template fallback only).
        frame_height: Frame height in pixels (template fallback only).

    Returns:
        A new list holding the same, now mutated, frame objects.
    """
    frames = list(results_frames)
    if not frames:
        return frames

    failed_set = {int(i) for i in failed_indices}
    first_success = next(
        (i for i in range(len(frames)) if i not in failed_set), None
    )

    if first_success is None:
        # Nothing to propagate: fall back to the template scaled to the frame.
        sparse_template = convert_keypoints_to_val_format(
            _generate_sparse_template_keypoints(frame_width, frame_height)
        )
        for frame_result in frames:
            # Fresh list per frame so frames never share one mutable list.
            frame_result.keypoints = list(sparse_template)
        return frames

    def _read_keypoints(frame_result: Any) -> list[tuple[int, ...]]:
        # Explicit None check instead of ``or []``: truth-testing a NumPy
        # array with more than one element raises ValueError.
        raw = getattr(frame_result, "keypoints", None)
        return convert_keypoints_to_val_format([] if raw is None else raw)

    # Seed with the first success so leading failed frames are back-filled.
    last_success_kps = _read_keypoints(frames[first_success])
    for frame_index, frame_result in enumerate(frames):
        if frame_index in failed_set:
            frame_result.keypoints = list(last_success_kps)
        else:
            current_kps = _read_keypoints(frame_result)
            frame_result.keypoints = list(current_kps)
            last_success_kps = current_kps
    return frames

def run_keypoints_post_processing(results_frames: Sequence[Any], frame_width: int, frame_height: int) -> list[Any]:
    """Flag failed frames and repair their keypoints.

    Runs ``predict_failed_indices`` on ``results_frames`` and passes the
    result, together with the frame size, to ``fix_keypoints``. That call
    assigns each frame's ``keypoints`` attribute in place.

    Returns:
        Whatever ``fix_keypoints`` returns: a list of the processed frames.
    """
    return fix_keypoints(
        results_frames,
        predict_failed_indices(results_frames),
        frame_width,
        frame_height,
    )