P-DFD / extract_video.py
mrneuralnet's picture
Modify inference script args
f4c3cd9
import os
from os.path import join
import argparse
import numpy as np
import cv2
import torch
from tqdm import tqdm
from data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
from models.retinaface import RetinaFace
from utils.box_utils import decode
np.random.seed(0)
def check_keys(model, pretrained_state_dict):
    """Report how a checkpoint's parameter names overlap with a model's.

    Prints the number of model keys missing from the checkpoint, the number
    of checkpoint keys the model does not have, and the number of keys
    present in both.

    Args:
        model: object exposing ``state_dict()`` that returns a mapping.
        pretrained_state_dict: mapping loaded from the checkpoint.

    Returns:
        True when at least one key is shared.

    Raises:
        AssertionError: if the checkpoint shares no key with the model.
    """
    checkpoint_names = set(pretrained_state_dict.keys())
    model_names = set(model.state_dict().keys())

    shared = model_names.intersection(checkpoint_names)
    only_in_checkpoint = checkpoint_names.difference(model_names)
    only_in_model = model_names.difference(checkpoint_names)

    print('Missing keys:{}'.format(len(only_in_model)))
    print('Unused checkpoint keys:{}'.format(len(only_in_checkpoint)))
    print('Used keys:{}'.format(len(shared)))

    # Nothing in common means the checkpoint cannot belong to this model.
    assert len(shared) > 0, 'load NONE from pretrained checkpoint'
    return True
def remove_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from its keys.

    Old-style checkpoints store every parameter name under a common prefix
    such as ``'module.'``. Only a leading occurrence is removed, and keys
    that do not start with the prefix are kept unchanged.
    """
    print("remove prefix '{}'".format(prefix))
    renamed = {}
    for name, value in state_dict.items():
        if name.startswith(prefix):
            # The prefix sits at position 0, so everything after the first
            # match is the remainder of the name.
            name = name.partition(prefix)[2]
        renamed[name] = value
    return renamed
def load_model(model, pretrained_path, load_to_cpu=False):
    """Load pretrained weights into ``model`` and move it to ``device``.

    Relies on the module-level ``device`` set by the script entry point.

    Args:
        model: network whose parameters are to be filled in.
        pretrained_path: path of the checkpoint file.
        load_to_cpu: when True the checkpoint tensors are deserialized onto
            the CPU; otherwise they are mapped straight onto ``device``.

    Returns:
        The same ``model`` object, with weights loaded (non-strictly) and
        moved to ``device``.

    Raises:
        AssertionError: propagated from ``check_keys`` when the checkpoint
            shares no parameter name with the model.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))

    # NOTE(review): torch.load unpickles the file -- only use checkpoints
    # from a trusted source.
    # map_location accepts a device (or device string) directly, so no
    # storage-remapping lambda is needed.
    target = 'cpu' if load_to_cpu else device
    pretrained_dict = torch.load(pretrained_path, map_location=target)

    # Some checkpoints wrap the weights in a {'state_dict': ...} container.
    if "state_dict" in pretrained_dict:
        pretrained_dict = pretrained_dict['state_dict']
    pretrained_dict = remove_prefix(pretrained_dict, 'module.')

    check_keys(model, pretrained_dict)
    # strict=False: keys missing on either side are tolerated; check_keys
    # has already printed the counts.
    model.load_state_dict(pretrained_dict, strict=False)
    model.to(device)
    return model
def detect(img_list, output_path, resize=1):
    """Run the face detector on every frame and save one crop per hit.

    Relies on the module-level ``args``, ``cfg``, ``device`` and ``net``
    set by the script entry point.

    Args:
        img_list: non-empty list of equally-shaped frame tensors indexed as
            (height, width, channels), as produced by ``extract_frames``.
        output_path: directory that receives the crops; created if missing.
        resize: divisor applied to the decoded box coordinates.

    Crops are written as ``0000.png``, ``0001.png``, ... numbered by the
    count of crops saved so far; frames without a surviving detection are
    skipped and do not consume a number.
    """
    os.makedirs(output_path, exist_ok=True)

    frame_h, frame_w, _ = img_list[0].shape
    # Per-coordinate (x, y, x, y) multiplier built from the frame size.
    box_scale = torch.Tensor([frame_w, frame_h, frame_w, frame_h])
    # Stack to (N, H, W, C) and reorder to (N, C, H, W) for the network.
    frames_nchw = torch.stack(img_list, dim=0).permute([0, 3, 1, 2])
    box_scale = box_scale.to(device)

    # Forward the frames through the network in chunks of args.bs.
    chunk_size = args.bs
    total = frames_nchw.shape[0]
    n_chunks = (total + chunk_size - 1) // chunk_size
    loc_parts = []
    conf_parts = []
    for k in range(n_chunks):
        # Slicing clamps at the end, so the last chunk may be shorter.
        chunk = frames_nchw[k * chunk_size:(k + 1) * chunk_size]
        chunk = chunk.to(device).float()
        # The third network output is not used here.
        loc_out, conf_out, _unused = net(chunk)
        loc_parts.append(loc_out)
        conf_parts.append(conf_out)
    locs = torch.cat(loc_parts, dim=0)
    confs = torch.cat(conf_parts, dim=0)

    prior_data = PriorBox(
        cfg, image_size=(frame_h, frame_w)).forward().to(device).data

    frames_hwc = frames_nchw.permute([0, 2, 3, 1]).cpu().numpy()
    saved = 0
    for frame, loc, conf in zip(frames_hwc, locs, confs):
        boxes = decode(loc.data, prior_data, cfg['variance'])
        boxes = (boxes * box_scale / resize).cpu().numpy()
        scores = conf.data.cpu().numpy()[:, 1]

        # Drop candidates at or below the confidence threshold.
        confident = np.where(scores > args.confidence_threshold)[0]
        boxes, scores = boxes[confident], scores[confident]

        # Keep the args.top_k highest-scoring candidates before NMS.
        best_first = scores.argsort()[::-1][:args.top_k]
        boxes, scores = boxes[best_first], scores[best_first]

        # Non-maximum suppression on [x1, y1, x2, y2, score] rows.
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False)
        kept = py_cpu_nms(dets, args.nms_threshold)
        dets = dets[kept, :]
        dets = dets[:args.keep_top_k, :]
        if len(dets) == 0:
            continue

        # Only the first remaining detection is cropped.
        first = list(map(int, dets[0]))
        x0, y0, crop_w, crop_h = get_boundingbox(
            first, frame.shape[1], frame.shape[0])
        # Add back the per-channel constants extract_frames subtracted.
        crop = frame[y0:y0 + crop_h, x0:x0 + crop_w, :] + (104, 117, 123)
        cv2.imwrite(join(output_path, '{:04d}.png'.format(saved)), crop)
        saved += 1
def extract_frames(data_path, interval=1):
    """Read a video (or a single image) into mean-subtracted frame tensors.

    Relies on the module-level ``args`` (``args.max_frames``) for videos.

    Args:
        data_path: path to the input. A path whose text after the last
            ``.`` is exactly ``mp4`` is read as a video; anything else is
            read as one image.
        interval: keep every ``interval``-th decoded frame of a video.

    Returns:
        A list of tensors, each the decoded frame minus the per-channel
        constants (104, 117, 123). For an image the list has one element.

    Raises:
        ValueError: if the image at ``data_path`` cannot be read.
    """
    # The same constants are added back by detect() before saving crops.
    mean = torch.tensor([104, 117, 123])

    if data_path.split('.')[-1] != "mp4":
        image = cv2.imread(data_path)
        if image is None:
            raise ValueError(f"could not read image: {data_path}")
        # Frames keep the channel order OpenCV decoded them in: the former
        # cv2.cvtColor call discarded its result and so changed nothing.
        return [torch.tensor(image) - mean]

    reader = cv2.VideoCapture(data_path)
    frames = []
    frame_num = 0
    try:
        while reader.isOpened():
            success, image = reader.read()
            if not success:
                break
            # Convert only the frames that are actually kept.
            if frame_num % interval == 0:
                frames.append(torch.tensor(image) - mean)
            frame_num += 1
            # Stop as soon as the cap is exceeded (at most max_frames + 1
            # frames are ever collected).
            if len(frames) > args.max_frames:
                break
    finally:
        # Release the capture even if decoding or conversion raised.
        reader.release()

    if len(frames) > args.max_frames:
        # NOTE(review): the sampled indices are not sorted, so the returned
        # frames are not in temporal order -- confirm this is intended.
        samples = np.random.choice(
            np.arange(0, len(frames)), size=args.max_frames, replace=False)
        return [frames[_] for _ in samples]
    return frames
def get_boundingbox(bbox, width, height, scale=1.8, minsize=None):
    """Enlarge a detection box around its centre and fit it to the image.

    Args:
        bbox: sequence whose first four items are x1, y1, x2, y2.
        width: image width, used to limit the crop horizontally.
        height: image height, used to limit the crop vertically.
        scale: factor applied to the box width and height.
        minsize: optional lower bound for each side of the enlarged box.

    Returns:
        ``(x, y, size_x, size_y)``: top-left corner (clamped to be
        non-negative) and crop size (limited so the crop ends inside the
        image).
    """
    left, top = bbox[0], bbox[1]
    right, bottom = bbox[2], bbox[3]

    box_w = int((right - left) * scale)
    box_h = int((bottom - top) * scale)
    if minsize:
        box_w = max(box_w, minsize)
        box_h = max(box_h, minsize)

    mid_x = (left + right) // 2
    mid_y = (top + bottom) // 2

    # Top-left corner, kept inside the image on the low side.
    origin_x = max(int(mid_x - box_w // 2), 0)
    origin_y = max(int(mid_y - box_h // 2), 0)

    # Shrink the size if the box would run past the right/bottom edge.
    fitted_w = min(width - origin_x, box_w)
    fitted_h = min(height - origin_y, box_h)
    return origin_x, origin_y, fitted_w, fitted_h
def extract_method_videos(data_path, interval):
    """Extract frames from one input file and save the detected face crops.

    Crops go to ``<directory of data_path>/images/<name>``, where ``<name>``
    is the file name up to its first ``.``. Any exception raised while
    extracting or detecting is appended to ``failure.txt`` in the current
    working directory instead of being propagated.

    Args:
        data_path: path of the video or image to process.
        interval: frame interval forwarded to ``extract_frames``.
    """
    # os.path.split handles root-level files and the platform separator.
    result_path, video = os.path.split(data_path)
    images_path = join(result_path, 'images')
    image_folder = video.split('.')[0]
    try:
        image_list = extract_frames(data_path, interval)
        detect(image_list, join(images_path, image_folder))
    except Exception as ex:
        # Best-effort processing: record the failure and carry on. The
        # context manager closes the log even if the write itself fails.
        with open("failure.txt", "a", encoding="utf-8") as log:
            log.write(image_folder +
                      f" Exception for {image_folder}: {ex}\n")
if __name__ == '__main__':
    # Script entry point: parse options, build the detector, process one
    # input. `args`, `cfg`, `device` and `net` are read as module globals
    # by the functions above, so their names must not change.
    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Required: without a path the script previously failed later on
    # `None.split(...)` instead of reporting a usage error.
    p.add_argument('--data_path', '-p', type=str, required=True,
                   help='path to the data')
    p.add_argument('--confidence_threshold', default=0.05,
                   type=float, help='confidence threshold')
    p.add_argument('--top_k', default=5, type=int, help='top_k')
    p.add_argument('--nms_threshold', default=0.4,
                   type=float, help='nms threshold')
    p.add_argument('--keep_top_k', default=1, type=int, help='keep_top_k')
    p.add_argument('--bs', default=32, type=int, help='batch size')
    p.add_argument('--frame_interval', '-fi', default=1, type=int,
                   help='frame interval')
    p.add_argument('--device', "-d", default="cuda:0", type=str,
                   help='device')
    p.add_argument('--max_frames', default=100, type=int,
                   help='maximum frames per video')
    args = p.parse_args()

    # Inference only: no autograd bookkeeping.
    torch.set_grad_enabled(False)

    # use resnet-50
    cfg = cfg_re50
    pretrained_weights = './weights/Resnet50_Final.pth'
    torch.backends.cudnn.benchmark = True
    device = torch.device(args.device)
    print(device)

    # net and model
    print('loading the model...')
    net = RetinaFace(cfg=cfg, phase='test')
    # The checkpoint is deserialized on the CPU and load_model then moves
    # the network to `device`. This is what already happened: the former
    # `device == 'cuda'` test compared a torch.device with a str, which is
    # never equal, so load_to_cpu was always True.
    net = load_model(net, pretrained_weights, load_to_cpu=True)
    net.eval()
    print('Finished loading model!')

    extract_method_videos(args.data_path, args.frame_interval)