tanthinhdt committed on
Commit
b808ec1
·
verified ·
1 Parent(s): d13d732

feat: add utils for inference

Browse files
Files changed (1) hide show
  1. utils.py +190 -0
utils.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from mediapipe.python.solutions import (drawing_styles, drawing_utils,
5
+ holistic, pose)
6
+ from torchvision.transforms.v2 import Compose, UniformTemporalSubsample
7
+
8
+
9
def draw_skeleton_on_image(
    image: np.ndarray,
    detection_results,
    resize_to: tuple[int, int] = None,
) -> np.ndarray:
    """
    Render pose and hand skeletons onto a copy of an image.

    Parameters
    ----------
    image : np.ndarray
        Canvas to draw on. It is copied first, so the caller's array is
        left untouched.
    detection_results
        Object exposing ``pose_landmarks``, ``left_hand_landmarks`` and
        ``right_hand_landmarks`` (presumably the result of a MediaPipe
        Holistic detector -- confirm against the caller).
    resize_to : tuple[int, int], optional
        Target size forwarded unchanged to ``cv2.resize`` as ``dsize``.
        OpenCV interprets this as ``(width, height)``. When ``None`` the
        annotated image keeps its original size.

    Returns
    -------
    np.ndarray
        The annotated (and optionally resized) image.
    """
    canvas = np.copy(image)

    # Body pose is drawn with MediaPipe's default landmark style.
    drawing_utils.draw_landmarks(
        canvas,
        detection_results.pose_landmarks,
        holistic.POSE_CONNECTIONS,
        landmark_drawing_spec=drawing_styles.get_default_pose_landmarks_style(),
    )

    # Each hand gets its own colour pair: (landmarks, landmark colour,
    # connection colour). Left hand first, then right hand.
    hand_layers = (
        (detection_results.left_hand_landmarks, (121, 22, 76), (121, 44, 250)),
        (detection_results.right_hand_landmarks, (245, 117, 66), (245, 66, 230)),
    )
    for hand_landmarks, landmark_color, connection_color in hand_layers:
        drawing_utils.draw_landmarks(
            canvas,
            hand_landmarks,
            holistic.HAND_CONNECTIONS,
            drawing_utils.DrawingSpec(color=landmark_color, thickness=2, circle_radius=4),
            drawing_utils.DrawingSpec(color=connection_color, thickness=2, circle_radius=2),
        )

    if resize_to is None:
        return canvas
    return cv2.resize(canvas, resize_to, interpolation=cv2.INTER_AREA)
64
+
65
+
66
def are_hands_down(pose_landmarks) -> bool:
    """
    Check whether both hands are down.

    A hand counts as "down" when its wrist has a larger ``y`` coordinate
    than the elbow of the same arm (in image coordinates ``y`` grows
    downwards).

    Parameters
    ----------
    pose_landmarks
        Pose landmarks object exposing a ``landmark`` sequence indexable
        by ``pose.PoseLandmark`` values, or ``None`` when no pose was
        detected.

    Returns
    -------
    bool
        ``True`` if no pose was detected, or if both elbows and both
        wrists have a positive visibility and both wrists are below
        their elbows. ``False`` otherwise.
    """
    if pose_landmarks is None:
        # No detected pose is treated as "hands down".
        return True

    landmarks = pose_landmarks.landmark
    joints = pose.PoseLandmark
    left_elbow = landmarks[joints.LEFT_ELBOW.value]
    left_wrist = landmarks[joints.LEFT_WRIST.value]
    right_elbow = landmarks[joints.RIGHT_ELBOW.value]
    right_wrist = landmarks[joints.RIGHT_WRIST.value]

    # Use each joint's own visibility. The previous code read the shoulder
    # visibility for every joint, so the elbows and wrists that are actually
    # compared below were never checked.
    is_visible = all(
        joint.visibility > 0
        for joint in (left_elbow, left_wrist, right_elbow, right_wrist)
    )
    return (
        is_visible
        and left_wrist.y > left_elbow.y
        and right_wrist.y > right_elbow.y
    )
109
+
110
+
111
def get_predictions(
    inputs: dict,
    model,
    k: int = 3,
) -> list:
    """
    Run the model and return its top-k predictions.

    Parameters
    ----------
    inputs : dict
        Keyword arguments forwarded to ``model(**inputs)``, or ``None``
        when there is nothing to classify.
    model
        Callable whose output exposes ``logits`` of shape
        ``(batch, num_classes)`` and which has a ``config.id2label``
        mapping from class index to label.
    k : int, optional
        Number of predictions to return. It is capped at the number of
        classes.

    Returns
    -------
    list
        One ``{'label': ..., 'score': ...}`` dict per prediction, ordered
        from highest to lowest score. Empty when ``inputs`` is ``None``.
        Only the first sample of the batch is scored.

    Notes
    -----
    Scores are a softmax over the top-k logits only, so they sum to 1
    across the returned predictions rather than across all classes.
    """
    if inputs is None:
        return []

    # Inference only: avoid building an autograd graph, and make sure the
    # logits can be converted to NumPy even when parameters require grad.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Select the single sample explicitly instead of relying on squeeze(),
    # which collapses to a 0-d array (and breaks indexing) when k == 1.
    sample_logits = logits[0]
    k = min(k, sample_logits.shape[-1])

    topk_logits, topk_indices = torch.topk(sample_logits, k)
    topk_scores = torch.nn.functional.softmax(topk_logits, dim=0).cpu().numpy()
    topk_indices = topk_indices.cpu().numpy()

    return [
        {
            # Plain int keys match the id2label mapping.
            'label': model.config.id2label[int(index)],
            'score': score,
        }
        for index, score in zip(topk_indices, topk_scores)
    ]
134
+
135
+
136
def preprocess(
    model_num_frames: int,
    keypoints_detector,
    source: str,
    data_height: int,
    data_width: int,
    model_input_height: int,
    model_input_width: int,
    device: str,
    transform: Compose,
) -> dict:
    """
    Build model inputs from the first signing segment of a video.

    Frames are read from ``source``; sampling starts at the first frame
    in which the hands are not down and stops at the first later frame in
    which they are down again. Each sampled frame is replaced by a
    skeleton drawn on a black canvas.

    Parameters
    ----------
    model_num_frames : int
        Number of frames the model expects. Also the minimum number of
        sampled frames required to produce inputs.
    keypoints_detector
        Object with a ``process(frame)`` method whose result is accepted
        by ``draw_skeleton_on_image`` and exposes ``pose_landmarks``
        (presumably a MediaPipe Holistic instance -- confirm against the
        caller).
    source : str
        Video source passed to ``cv2.VideoCapture``.
    data_height : int
        Height of the black canvas the skeleton is drawn on.
    data_width : int
        Width of the black canvas the skeleton is drawn on.
    model_input_height : int
        Height each skeleton frame is resized to.
    model_input_width : int
        Width each skeleton frame is resized to.
    device : str
        Device the resulting tensor is moved to.
    transform : Compose
        Transform applied to each ``(channels, height, width)`` frame
        tensor.

    Returns
    -------
    dict
        ``{'pixel_values': tensor}`` with a leading batch dimension of 1,
        or ``None`` when fewer than ``model_num_frames`` frames were
        sampled.
    """
    skeleton_video = []
    did_sample_start = False

    cap = cv2.VideoCapture(source)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Detect keypoints.
            detection_results = keypoints_detector.process(frame)

            # Extract sign video: skip leading hands-down frames and stop at
            # the first hands-down frame after sampling has started.
            if are_hands_down(detection_results.pose_landmarks):
                if did_sample_start:
                    break
                continue
            did_sample_start = True

            # Only frames that are kept are drawn and transformed.
            skeleton_frame = draw_skeleton_on_image(
                image=np.zeros((data_height, data_width, 3), dtype=np.uint8),
                detection_results=detection_results,
                # cv2.resize expects (width, height).
                resize_to=(model_input_width, model_input_height),
            )

            # (height, width, channels) -> (channels, height, width)
            skeleton_frame = transform(torch.tensor(skeleton_frame).permute(2, 0, 1))
            skeleton_video.append(skeleton_frame)
    finally:
        # Release the capture even if detection or drawing raises.
        cap.release()

    if len(skeleton_video) < model_num_frames:
        return None

    skeleton_video = torch.stack(skeleton_video)
    skeleton_video = UniformTemporalSubsample(model_num_frames)(skeleton_video)
    inputs = {
        'pixel_values': skeleton_video.unsqueeze(0),
    }
    return {name: tensor.to(device) for name, tensor in inputs.items()}