# VideoCLIP-XL / demo.py
# Uploaded by zarus03 via huggingface_hub (commit e2eafc5).
import logging
import os
from typing import List
import cv2
import numpy as np
import torch
from torch import nn
from PIL import Image
import torch.nn.functional as F
from tqdm import tqdm
import threading
from torch._utils import ExceptionWrapper
from utils.text_encoder.simple_tokenizer import SimpleTokenizer as ClipTokenizer
from modeling import VideoCLIP_XL
from utils.text_encoder import text_encoder
import argparse
from data_dataloaders import DATALOADER_DICT
args_parser = argparse.ArgumentParser()
args_parser.add_argument("--datatype", type=str, default="msvd", help="dataset name")
args_parser.add_argument("--local-rank", default=0, type=int, help="distributed training")
args_parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='train split CSV path')
args_parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='validation split CSV path')
args_parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path')
args_parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path')
args = args_parser.parse_args()
# The evaluation results are reported through logger.info; without configuring
# logging, only WARNING and above would reach the console and the metrics
# would never be shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_a_var(obj):
    """Return the first ``torch.Tensor`` reachable inside ``obj``, or ``None``.

    The search is depth-first: a tensor is returned as-is, lists and tuples are
    scanned element by element, and dicts are scanned over their
    ``(key, value)`` pairs. Any other object yields ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.items()
    else:
        return None
    for child in children:
        found = get_a_var(child)
        if isinstance(found, torch.Tensor):
            return found
    return None
def parallel_apply(fct, model, inputs, device_ids):
    """Run ``fct(replica, *input)`` on one replica of ``model`` per device, in parallel.

    Adapted from ``torch.nn.parallel.parallel_apply``: instead of calling each
    replica's ``forward``, it calls the arbitrary function ``fct`` with the
    replica as its first argument.

    Args:
        fct: callable invoked as ``fct(module_replica, *input)``.
        model: ``nn.Module`` replicated onto ``device_ids``.
        inputs: one argument tuple per device; a non-list/tuple entry is
            wrapped into a 1-tuple so a single tensor is not sliced.
        device_ids: device indices passed to ``nn.parallel.replicate``.

    Returns:
        list of the per-replica outputs, in the same order as ``inputs``.

    Raises:
        ValueError: if the number of replicas differs from ``len(inputs)``.
        Any exception raised in a worker is re-raised in the calling thread.
    """
    modules = nn.parallel.replicate(model, device_ids)
    if len(modules) != len(inputs):
        raise ValueError(
            "parallel_apply: got {} replicas for {} inputs".format(len(modules), len(inputs)))
    lock = threading.Lock()
    results = {}
    # Grad mode is thread-local, so the caller's setting is re-applied in each worker.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, inp):
        torch.set_grad_enabled(grad_enabled)
        device = None
        try:
            # Resolve the device inside the try: if no tensor is found, the error
            # is recorded for this replica instead of killing the thread and
            # leaving results[i] missing.
            device = get_a_var(inp).get_device()
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `inp` if it is a Tensor
                if not isinstance(inp, (list, tuple)):
                    inp = (inp,)
                output = fct(module, *inp)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker, args=(i, module, inp))
                   for i, (module, inp) in enumerate(zip(modules, inputs))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
def tensor_video_to_text_sim(sim_tensor):
    """Collapse a padded multi-sentence similarity tensor into a 2-D matrix.

    ``sim_tensor`` is indexed ``[group, sentence, column]``, as built by the
    multi-sentence branch of ``eval_epoch``: one group per video, with its
    sentence rows padded with ``-inf``. NaNs are treated as ``-inf``. For each
    group and column, the best score over the group's sentences is kept, and
    the result is transposed.

    Args:
        sim_tensor: 3-D tensor or array-like; it is not modified.

    Returns:
        tensor of shape ``(sim_tensor.shape[2], sim_tensor.shape[0])``.
    """
    if not torch.is_tensor(sim_tensor):
        sim_tensor = torch.tensor(sim_tensor)
    # masked_fill returns a new tensor, so the caller's data is left untouched.
    sim_tensor = sim_tensor.masked_fill(torch.isnan(sim_tensor), float('-inf'))
    values, _ = torch.max(sim_tensor, dim=1, keepdim=True)
    # Squeeze only the reduced axis so a single group (or single column) does
    # not collapse the result to 1-D.
    return values.squeeze(1).T
def tensor_text_to_video_metrics(sim_tensor, top_k=(1, 5, 10)):
    """Text-to-video retrieval metrics for a padded multi-sentence similarity tensor.

    ``sim_tensor[v, s, c]`` is the score of sentence ``s`` of group ``v``
    against video column ``c``. The ground-truth column for group ``v`` is
    ``c == v``, so the expected shape is ``(V, S, V)``. Sentences whose
    ground-truth score is ``inf`` or ``NaN`` (e.g. the ``-inf`` padding rows
    added in ``eval_epoch``) are ignored. Tied scores are ordered arbitrarily
    by ``argsort``.

    Args:
        sim_tensor: 3-D tensor or array-like.
        top_k: iterable of cut-offs ``k`` for which ``R{k}`` is reported.

    Returns:
        dict with:
        - ``R{k}``: percent of valid sentences whose video ranks in the top ``k``.
        - ``MedianR`` / ``MR``: ``torch.median`` of the 1-based ranks, which is
          the lower median for an even count.
        - ``MeanR`` and ``Std_Rank`` (population std) of the 1-based ranks.
    """
    if not torch.is_tensor(sim_tensor):
        sim_tensor = torch.tensor(sim_tensor)
    # Permute so each [s] slice is an ordinary (group x column) similarity
    # matrix. A double argsort then gives, at [s, v, v], the 0-based rank of
    # the ground-truth column.
    stacked_sim_matrices = sim_tensor.permute(1, 0, 2)
    first_argsort = torch.argsort(stacked_sim_matrices, dim=-1, descending=True)
    second_argsort = torch.argsort(first_argsort, dim=-1, descending=False)
    ranks = torch.flatten(torch.diagonal(second_argsort, dim1=1, dim2=2))
    # Same (sentence-major) ordering as `ranks`: the ground-truth scores
    # sim_tensor[v, s, v], used to drop padded / invalid entries.
    ground_truth_scores = torch.flatten(torch.diagonal(sim_tensor, dim1=0, dim2=2))
    valid_ranks = ranks[torch.isfinite(ground_truth_scores)]
    n_valid = len(valid_ranks)
    results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / n_valid) for k in top_k}
    one_based = valid_ranks.numpy() + 1
    results["MedianR"] = float(torch.median(valid_ranks + 1))
    results["MeanR"] = float(np.mean(one_based))
    results["Std_Rank"] = float(np.std(one_based))
    results['MR'] = results["MedianR"]
    return results
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
sim_matrix = []
for idx1, b1 in enumerate(batch_list_t):
input_mask, segment_ids, *_tmp = b1
sequence_output = batch_sequence_output_list[idx1]
each_row = []
for idx2, b2 in enumerate(batch_list_v):
video_mask, *_tmp = b2
visual_output = batch_visual_output_list[idx2]
b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask,
loose_type=model.loose_type)
b1b2_logits = b1b2_logits.cpu().detach().numpy()
each_row.append(b1b2_logits)
each_row = np.concatenate(tuple(each_row), axis=-1)
sim_matrix.append(each_row)
return sim_matrix
def compute_metrics(x):
    """Retrieval metrics for a similarity matrix whose diagonal holds the true matches.

    Row ``i`` is a query and ``x[i, i]`` is the score of its ground-truth item.
    The 0-based rank of row ``i`` is the number of entries in that row strictly
    greater than ``x[i, i]``, so:
    - every row contributes exactly one rank;
    - ties resolve to the best rank;
    - NaN entries never count as greater.
    Rows whose diagonal is not finite are skipped. The previous sort-and-match
    formulation also dropped them, because ``inf - inf`` is NaN and never
    matched.

    Args:
        x: 2-D NumPy array or CPU tensor of scores.

    Returns:
        dict with:
        - ``'R1'``, ``'R5'``, ``'R10'``: percentages.
        - ``'MR'`` / ``'MedianR'``: median 1-based rank.
        - ``'MeanR'``: mean 1-based rank.
        - ``'cols'``: the 0-based rank of each kept row.

    Raises:
        ValueError: if no row has a finite diagonal score.
    """
    x = np.asarray(x)
    diag = np.diag(x)
    keep = np.isfinite(diag)
    # Counting strictly-greater scores avoids the old np.where(sorted == diag)
    # lookup, which returned several matches for one row when scores tied.
    ind = np.sum(x[keep] > diag[keep, np.newaxis], axis=1)
    if ind.size == 0:
        raise ValueError("compute_metrics: no row has a finite diagonal score")
    n = len(ind)
    metrics = {}
    metrics['R1'] = float(np.sum(ind == 0)) * 100 / n
    metrics['R5'] = float(np.sum(ind < 5)) * 100 / n
    metrics['R10'] = float(np.sum(ind < 10)) * 100 / n
    metrics['MR'] = np.median(ind) + 1
    metrics["MedianR"] = metrics['MR']
    metrics["MeanR"] = np.mean(ind) + 1
    metrics["cols"] = [int(i) for i in ind]
    return metrics
def eval_epoch(args, model, test_dataloader, device, n_gpu):
    """Encode the evaluation set, build the text-video similarity matrix, and log retrieval metrics.

    Args:
        args: parsed command-line namespace (not read inside this function).
        model: retrieval model, optionally wrapped in an object exposing
            ``.module``. It must provide:
            - ``get_sequence_output``, ``get_visual_output`` and
              ``get_sequence_visual_output``;
            - ``get_similarity_logits`` and ``loose_type``, through
              ``_run_on_single_gpu``.
        test_dataloader: yields batches unpacked as
            ``(input_ids, input_mask, segment_ids, video, video_mask)``. Its
            ``dataset`` may set ``multi_sentence_per_video`` together with
            ``cut_off_points``, ``sentence_num`` and ``video_num``.
        device: device the model and every batch are moved to.
        n_gpu: when > 1, the similarity computation is split across
            ``cuda:0 .. cuda:{n_gpu - 1}`` with ``parallel_apply``.

    Returns:
        Text-to-video R@1 (percentage).
    """
    # Unwrap a DataParallel-style wrapper if present.
    if hasattr(model, 'module'):
        model = model.module.to(device)
    else:
        model = model.to(device)
    # #################################################################
    ## below variables are used to multi-sentences retrieval
    # multi_sentence_: important tag for eval
    # cut_off_points: used to tag the label when calculate the metric
    # sentence_num: used to cut the sentence representation
    # video_num: used to cut the video representation
    # #################################################################
    multi_sentence_ = False
    cut_off_points_, sentence_num_, video_num_ = [], -1, -1
    if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') \
            and test_dataloader.dataset.multi_sentence_per_video:
        multi_sentence_ = True
        cut_off_points_ = test_dataloader.dataset.cut_off_points
        sentence_num_ = test_dataloader.dataset.sentence_num
        video_num_ = test_dataloader.dataset.video_num
        # After the -1, each entry is the row index of the LAST sentence of a
        # group: the reshape below adds 1 back and uses it as a slice end.
        cut_off_points_ = [itm - 1 for itm in cut_off_points_]
    if multi_sentence_:
        logger.warning("Eval under the multi-sentence per video clip setting.")
        logger.warning("sentence num: {}, video num: {}".format(sentence_num_, video_num_))
    model.eval()
    with torch.no_grad():
        batch_list_t = []
        batch_list_v = []
        batch_sequence_output_list, batch_visual_output_list = [], []
        # Running count of dataloader rows seen so far. In multi-sentence mode
        # these rows are sentences, despite the name.
        total_video_num = 0
        # ----------------------------
        # 1. cache the features
        # ----------------------------
        for bid, batch in enumerate(tqdm(test_dataloader)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, video, video_mask = batch
            if multi_sentence_:
                # multi-sentences retrieval means: one clip has two or more descriptions.
                b, *_t = video.shape
                sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask)
                batch_sequence_output_list.append(sequence_output)
                batch_list_t.append((input_mask, segment_ids,))
                s_, e_ = total_video_num, total_video_num + b
                # Keep only the rows of this batch that sit on a cut-off point,
                # so each group's video is encoded exactly once.
                filter_inds = [itm - s_ for itm in cut_off_points_ if itm >= s_ and itm < e_]
                if len(filter_inds) > 0:
                    video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...]
                    visual_output = model.get_visual_output(video, video_mask)
                    batch_visual_output_list.append(visual_output)
                    batch_list_v.append((video_mask,))
                total_video_num += b
            else:
                sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask)
                batch_sequence_output_list.append(sequence_output)
                batch_list_t.append((input_mask, segment_ids,))
                batch_visual_output_list.append(visual_output)
                batch_list_v.append((video_mask,))
            # NOTE(review): duplicates the tqdm progress bar above.
            print("{}/{}\r".format(bid, len(test_dataloader)), end="")
        # ----------------------------------
        # 2. calculate the similarity
        # ----------------------------------
        if n_gpu > 1:
            # Text batches are split into contiguous chunks, one per device.
            # Every device receives all video batches. Device 0 keeps its
            # tensors where they already are (on `device`); the other chunks
            # are copied to cuda:<dev_id>.
            device_ids = list(range(n_gpu))
            batch_list_t_splits = []
            batch_list_v_splits = []
            batch_t_output_splits = []
            batch_v_output_splits = []
            bacth_len = len(batch_list_t)
            split_len = (bacth_len + n_gpu - 1) // n_gpu
            for dev_id in device_ids:
                s_, e_ = dev_id * split_len, (dev_id + 1) * split_len
                if dev_id == 0:
                    batch_list_t_splits.append(batch_list_t[s_:e_])
                    batch_list_v_splits.append(batch_list_v)
                    batch_t_output_splits.append(batch_sequence_output_list[s_:e_])
                    batch_v_output_splits.append(batch_visual_output_list)
                else:
                    devc = torch.device('cuda:{}'.format(str(dev_id)))
                    devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_t[s_:e_]]
                    batch_list_t_splits.append(devc_batch_list)
                    devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_v]
                    batch_list_v_splits.append(devc_batch_list)
                    devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]]
                    batch_t_output_splits.append(devc_batch_list)
                    devc_batch_list = [b.to(devc) for b in batch_visual_output_list]
                    batch_v_output_splits.append(devc_batch_list)
            parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id],
                                      batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids]
            parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids)
            # Re-assemble the per-device row blocks in device order.
            sim_matrix = []
            for idx in range(len(parallel_outputs)):
                sim_matrix += parallel_outputs[idx]
            sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
        else:
            sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list)
            sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
    if multi_sentence_:
        # Regroup the (sentences x videos) matrix into a 3-D array
        # (groups, max_length, videos). Each group's sentence rows are padded
        # with -inf rows up to the longest group.
        logger.info("before reshape, sim matrix size: {} x {}".format(sim_matrix.shape[0], sim_matrix.shape[1]))
        cut_off_points2len_ = [itm + 1 for itm in cut_off_points_]
        max_length = max([e_-s_ for s_, e_ in zip([0]+cut_off_points2len_[:-1], cut_off_points2len_)])
        sim_matrix_new = []
        for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_):
            sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_],
                                                  np.full((max_length-e_+s_, sim_matrix.shape[1]), -np.inf)), axis=0))
        sim_matrix = np.stack(tuple(sim_matrix_new), axis=0)
        logger.info("after reshape, sim matrix size: {} x {} x {}".
                    format(sim_matrix.shape[0], sim_matrix.shape[1], sim_matrix.shape[2]))
        tv_metrics = tensor_text_to_video_metrics(sim_matrix)
        vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix))
    else:
        logger.info("sim matrix size: {}, {}".format(sim_matrix.shape[0], sim_matrix.shape[1]))
        tv_metrics = compute_metrics(sim_matrix)
        vt_metrics = compute_metrics(sim_matrix.T)
        logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0])))
    # NOTE(review): these logger.info lines are only visible if logging is
    # configured at INFO level or lower.
    logger.info("Text-to-Video:")
    logger.info('\t>>> R@1: {:.1f} - R@5: {:.1f} - R@10: {:.1f} - Median R: {:.1f} - Mean R: {:.1f}'.
                format(tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR']))
    logger.info("Video-to-Text:")
    logger.info('\t>>> V2T$R@1: {:.1f} - V2T$R@5: {:.1f} - V2T$R@10: {:.1f} - V2T$Median R: {:.1f} - V2T$Mean R: {:.1f}'.
                format(vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR']))
    R1 = tv_metrics['R1']
    return R1
def _frame_from_video(video):
while video.isOpened():
success, frame = video.read()
if success:
yield frame
else:
break
# Per-channel mean / std, shaped (1, 1, 3) so they broadcast over an
# H x W x 3 frame. The values match the commonly used ImageNet RGB statistics.
v_mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
v_std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]


def normalize(data):
    """Map 0-255 pixel values to [0, 1], then standardize each channel with v_mean / v_std."""
    scaled = data / 255.0
    centered = scaled - v_mean
    return centered / v_std
def video_preprocessing(video_path, fnum=8):
    """Read a video file and turn it into a normalized clip tensor.

    Frames are sampled with stride ``max(1, len(frames) // fnum)`` and the
    first ``fnum`` sampled frames are kept. Clips shorter than ``fnum`` are
    padded by repeating their last frame. Each frame has its channel order
    reversed (OpenCV BGR -> RGB), is resized to 224x224, and is standardized
    with ``normalize``.

    Args:
        video_path: path to a video readable by ``cv2.VideoCapture``.
        fnum: number of frames in the returned clip.

    Returns:
        float64 tensor of shape ``(1, fnum, 3, 224, 224)``.

    Raises:
        ValueError: if no frame can be decoded from ``video_path``.
    """
    video = cv2.VideoCapture(video_path)
    try:
        frames = list(_frame_from_video(video))
    finally:
        # Always free the decoder handle, even if reading fails.
        video.release()
    if not frames:
        raise ValueError("could not read any frame from {!r}".format(video_path))
    # max(1, ...) prevents a zero slice step for videos with fewer than fnum frames.
    step = max(1, len(frames) // fnum)
    frames = frames[::step][:fnum]
    # Pad short clips so every clip has exactly fnum frames and clips from
    # different videos can be concatenated along the batch axis.
    frames += [frames[-1]] * (fnum - len(frames))
    vid_tube = []
    for fr in frames:
        fr = fr[:, :, ::-1]
        fr = cv2.resize(fr, (224, 224))
        vid_tube.append(np.expand_dims(normalize(fr), axis=(0, 1)))
    # (1, fnum, H, W, C) -> (1, fnum, C, H, W)
    vid_tube = np.transpose(np.concatenate(vid_tube, axis=1), (0, 1, 4, 2, 3))
    return torch.from_numpy(vid_tube)
videoclip_xl = VideoCLIP_XL()
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted
# files (newer PyTorch versions also accept weights_only=True).
state_dict = torch.load("./VideoCLIP-XL.bin", map_location="cpu")
videoclip_xl.load_state_dict(state_dict)
# Only move to the GPU when one exists; the evaluation `device` selection
# further below already falls back to CPU.
if torch.cuda.is_available():
    videoclip_xl.cuda()
videoclip_xl.eval()

# Placeholder inputs for the standalone text/video similarity demo.
videos = [
    "/path/to/video-1.mp4",
    "/path/to/video-2.mp4",
]
texts = [
    "text-1",
    "text-2",
    "text-3",
]
# Standalone demo: embed `videos` and `texts` and score text-to-video retrieval.
# compute_metrics treats the diagonal as the ground truth, so use one text per
# video (a square similarity matrix).
# with torch.no_grad():
#     video_inputs = torch.cat([video_preprocessing(video) for video in videos], 0).float().cuda()
#     video_features = videoclip_xl.vision_model.get_vid_features(video_inputs).float()
#     video_features = video_features / video_features.norm(dim=-1, keepdim=True)
#     text_inputs = text_encoder.tokenize(texts, truncate=True).cuda()
#     text_features = videoclip_xl.text_model.encode_text(text_inputs).float()
#     text_features = text_features / text_features.norm(dim=-1, keepdim=True)
#     Tmp = 100.
#     sim_matrix = (text_features @ video_features.T) * Tmp
#     print(f"{type(sim_matrix)=}")
#     tv_metrics = compute_metrics(sim_matrix.cpu().numpy())
#     print("Text-to-Video:")
#     print(f"\t>>> R@1: {tv_metrics['R1']:.1f} - R@5: {tv_metrics['R5']:.1f} - R@10: {tv_metrics['R10']:.1f} - Median R: {tv_metrics['MR']:.1f} - Mean R: {tv_metrics['MeanR']:.1f}")
tokenizer = ClipTokenizer()
dataloader_factories = DATALOADER_DICT[args.datatype]
test_dataloader, test_length = None, 0
if dataloader_factories["test"] is not None:
    test_dataloader, test_length = dataloader_factories["test"](args, tokenizer)
if dataloader_factories["val"] is not None:
    val_dataloader, val_length = dataloader_factories["val"](args, tokenizer, subset="val")
else:
    val_dataloader, val_length = test_dataloader, test_length
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", args.local_rank)
n_gpu = torch.cuda.device_count()
# eval_epoch dereferences `test_dataloader.dataset`, so it needs a real loader.
# Use the validation split when the dataset defines no test split.
eval_dataloader = test_dataloader if test_dataloader is not None else val_dataloader
if eval_dataloader is None:
    raise SystemExit("no test or val dataloader available for datatype {!r}".format(args.datatype))
if args.local_rank == 0:
    eval_epoch(args, videoclip_xl, eval_dataloader, device, n_gpu)