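# Usage sketch (all paths are illustrative; assuming this script is saved as
# long_video_demo.py):
#   python long_video_demo.py ${CONFIG} ${CHECKPOINT} ${VIDEO} ${LABEL} \
#       ${OUT_FILE} --input-step 3 --device cpu --threshold 0.2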
import argparse
import json
import random
from collections import deque
from operator import itemgetter

import cv2
import mmengine
import numpy as np
import torch
from mmengine import Config, DictAction
from mmengine.dataset import Compose

from mmaction.apis import inference_recognizer, init_recognizer

FONTFACE = cv2.FONT_HERSHEY_COMPLEX_SMALL
FONTSCALE = 1
THICKNESS = 1
LINETYPE = 1

# Decoding steps to strip from the test pipeline: frames are read from the
# video manually, so the pipeline must not init/decode them a second time.
EXCLUDED_STEPS = [
    'OpenCVInit', 'OpenCVDecode', 'DecordInit', 'DecordDecode', 'PyAVInit',
    'PyAVDecode', 'RawFrameDecode'
]


def parse_args():
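    """Parse command-line arguments for the long video demo."""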
    parser = argparse.ArgumentParser(
        description='MMAction2 long video demo: predict different labels '
        'throughout a long video')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file/url')
    parser.add_argument('video_path', help='video file/url')
    parser.add_argument('label', help='label file')
    parser.add_argument('out_file', help='output result file, video or json')
    parser.add_argument(
        '--input-step',
        type=int,
        default=1,
        help='input step for sampling frames')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--threshold',
        type=float,
        default=0.01,
        help='recognition score threshold')
    parser.add_argument(
        '--stride',
        type=float,
        default=0,
        help=('the prediction stride equals stride * sample_length '
              '(sample_length indicates the size of the temporal window '
              'from which you sample frames, which equals '
              'clip_len x frame_interval); if set to 0, the '
              'prediction stride is 1'))
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--label-color',
        nargs='+',
        type=int,
        default=(255, 255, 255),
        help='font color (B, G, R) of the labels in the output video')
    parser.add_argument(
        '--msg-color',
        nargs='+',
        type=int,
        default=(128, 128, 128),
        help='font color (B, G, R) of the messages in the output video')
    args = parser.parse_args()
    return args


def show_results_video(result_queue,
                       text_info,
                       thr,
                       msg,
                       frame,
                       video_writer,
                       label_color=(255, 255, 255),
                       msg_color=(128, 128, 128)):
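    """Overlay the latest recognition results on ``frame`` and write it.

    Returns the (possibly refreshed) ``text_info`` dict mapping text
    locations to label strings, so the same results keep being drawn
    until a new prediction arrives.
    """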
    if len(result_queue) != 0:
        text_info = {}
        results = result_queue.popleft()
        for i, result in enumerate(results):
            selected_label, score = result
            if score < thr:
                break
            location = (0, 40 + i * 20)
            text = selected_label + ': ' + str(round(score, 2))
            text_info[location] = text
            cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
                        label_color, THICKNESS, LINETYPE)
    elif len(text_info):
        # no new result: redraw the last known labels
        for location, text in text_info.items():
            cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
                        label_color, THICKNESS, LINETYPE)
    else:
        # nothing predicted yet: show the preparation message
        cv2.putText(frame, msg, (0, 40), FONTFACE, FONTSCALE, msg_color,
                    THICKNESS, LINETYPE)
    video_writer.write(frame)
    return text_info


def get_results_json(result_queue, text_info, thr, msg, ind, out_json):
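    """Record the recognition results for frame ``ind`` into ``out_json``."""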
    if len(result_queue) != 0:
        text_info = {}
        results = result_queue.popleft()
        for i, result in enumerate(results):
            selected_label, score = result
            if score < thr:
                break
            text_info[i + 1] = selected_label + ': ' + str(round(score, 2))
        out_json[ind] = text_info
    elif len(text_info):
        # no new result: reuse the last known labels
        out_json[ind] = text_info
    else:
        out_json[ind] = msg
    return text_info, out_json


def show_results(model, data, label, args):
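    """Read the video frame by frame, run sliding-window inference and dump
    the predictions either as an annotated video or as a json file."""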
    frame_queue = deque(maxlen=args.sample_length)
    result_queue = deque(maxlen=1)

    cap = cv2.VideoCapture(args.video_path)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    msg = 'Preparing action recognition ...'
    text_info = {}
    out_json = {}
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    frame_size = (frame_width, frame_height)

    ind = 0
    video_writer = None if args.out_file.endswith('.json') \
        else cv2.VideoWriter(args.out_file, fourcc, fps, frame_size)
    prog_bar = mmengine.ProgressBar(num_frames)
    backup_frames = []

    while ind < num_frames:
        ind += 1
        prog_bar.update()
        ret, frame = cap.read()
        if frame is None:
            # drop the frame when decoding fails
            continue
        backup_frames.append(np.array(frame)[:, :, ::-1])
        if ind == args.sample_length:
            # provide a quick show at the beginning
            frame_queue.extend(backup_frames)
            backup_frames = []
        elif ((len(backup_frames) == args.input_step
               and ind > args.sample_length) or ind == num_frames):
            # pick a frame from the backup
            # when the backup is full or reaching the last frame
            chosen_frame = random.choice(backup_frames)
            backup_frames = []
            frame_queue.append(chosen_frame)

        ret, scores = inference(model, data, args, frame_queue)

        if ret:
            num_selected_labels = min(len(label), 5)
            scores_tuples = tuple(zip(label, scores))
            scores_sorted = sorted(
                scores_tuples, key=itemgetter(1), reverse=True)
            results = scores_sorted[:num_selected_labels]
            result_queue.append(results)

        if args.out_file.endswith('.json'):
            text_info, out_json = get_results_json(result_queue, text_info,
                                                   args.threshold, msg, ind,
                                                   out_json)
        else:
            text_info = show_results_video(result_queue, text_info,
                                           args.threshold, msg, frame,
                                           video_writer, args.label_color,
                                           args.msg_color)

    cap.release()
    if video_writer:
        video_writer.release()
    cv2.destroyAllWindows()
    if args.out_file.endswith('.json'):
        with open(args.out_file, 'w') as js:
            json.dump(out_json, js)


def inference(model, data, args, frame_queue):
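    """Run the recognizer on the current temporal window.

    Returns ``(False, None)`` until the frame queue is full, otherwise
    ``(True, scores)`` with one score per label.
    """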
    if len(frame_queue) != args.sample_length:
        # do no inference when there are not enough video frames
        return False, None

    cur_windows = list(np.array(frame_queue))
    if data['img_shape'] is None:
        data['img_shape'] = frame_queue[0].shape[:2]

    cur_data = data.copy()
    cur_data.update(
        dict(
            array=cur_windows,
            modality='RGB',
            frame_inds=np.arange(args.sample_length)))

    result = inference_recognizer(
        model, cur_data, test_pipeline=args.test_pipeline)
    scores = result.pred_score.tolist()

    if args.stride > 0:
        pred_stride = int(args.sample_length * args.stride)
        for _ in range(pred_stride):
            frame_queue.popleft()

    # for the case ``args.stride == 0``, the deque automatically pops
    # one frame on the left whenever a new one is appended

    return True, scores


def main():
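    """Build the recognizer, derive the sampling window from the test
    pipeline, then run the demo."""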
    args = parse_args()

    args.device = torch.device(args.device)

    cfg = Config.fromfile(args.config)
    cfg.merge_from_dict(args.cfg_options)

    model = init_recognizer(cfg, args.checkpoint, device=args.device)
    data = dict(img_shape=None, modality='RGB', label=-1)
    with open(args.label, 'r') as f:
        label = [line.strip() for line in f]

    # prepare the test pipeline from the non-camera pipeline
    cfg = model.cfg
    sample_length = 0
    pipeline = cfg.test_pipeline
    pipeline_ = pipeline.copy()
    for step in pipeline:
        if 'SampleFrames' in step['type']:
            sample_length = step['clip_len'] * step['num_clips']
            data['num_clips'] = step['num_clips']
            data['clip_len'] = step['clip_len']
            pipeline_.remove(step)
        if step['type'] in EXCLUDED_STEPS:
            # remove the step to decode frames
            pipeline_.remove(step)
    pipeline_.insert(1, dict(type='ArrayDecode'))
    test_pipeline = Compose(pipeline_)

    assert sample_length > 0
    args.sample_length = sample_length
    args.test_pipeline = test_pipeline

    show_results(model, data, label, args)


if __name__ == '__main__':
    main()