| import os | |
| from typing import List | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| from modeling import VideoCLIP_XL | |
| from utils.text_encoder import text_encoder | |
def _frame_from_video(video):
    """Yield frames from an OpenCV-style capture until a read fails.

    Args:
        video: object exposing ``isOpened()`` and ``read()`` (as
            ``cv2.VideoCapture`` does); ``read()`` returns an ``(ok, frame)``
            pair.

    Yields:
        Each frame returned by ``video.read()``, in order. Iteration stops at
        the first unsuccessful read, or as soon as ``isOpened()`` is false.
        The capture is not released here; that is left to the caller.
    """
    while video.isOpened():
        ok, img = video.read()
        if not ok:
            # First failed read ends the stream (end of file or decode error).
            return
        yield img
# Per-channel RGB mean / std (these match the commonly used ImageNet
# normalization values).  The (1, 1, 3) shape lets them broadcast over the
# last axis of an (H, W, 3) frame.
v_mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
v_std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]


def normalize(data):
    """Divide by 255, then standardize each channel with ``v_mean``/``v_std``.

    ``data`` must broadcast against arrays of shape ``(1, 1, 3)``; presumably
    an ``(H, W, 3)`` frame of 8-bit pixel values — confirm at call sites.
    Returns a new floating-point array; ``data`` is not modified.
    """
    unit = data / 255.0
    return (unit - v_mean) / v_std
def video_preprocessing(video_path, fnum=8):
    """Decode a video and turn ``fnum`` sampled frames into a model input.

    Every frame is decoded with OpenCV. ``fnum`` frames are then picked at a
    fixed stride of ``len(frames) // fnum`` starting from frame 0. Each is
    converted BGR -> RGB, resized to 224x224 and normalized with
    :func:`normalize`. Clips shorter than ``fnum`` frames are sampled evenly
    with repetition, so the output always holds exactly ``fnum`` frames.

    Args:
        video_path: path (or any source accepted by ``cv2.VideoCapture``).
        fnum: number of frames to sample; must be positive.

    Returns:
        A tensor of shape ``(1, fnum, 3, 224, 224)``. Its dtype follows
        :func:`normalize` (float64); callers convert with ``.float()``.

    Raises:
        ValueError: if ``fnum`` is not positive, or if no frame could be
            decoded from ``video_path`` (e.g. missing file, unsupported codec).
    """
    if fnum <= 0:
        raise ValueError(f"fnum must be positive, got {fnum}")

    video = cv2.VideoCapture(video_path)
    try:
        frames = list(_frame_from_video(video))
    finally:
        # Always free the decoder handle, even if reading raises.
        video.release()

    total = len(frames)
    if total == 0:
        raise ValueError(f"could not decode any frames from {video_path!r}")

    if total >= fnum:
        # Identical to frames[::step][:fnum]: every step-th frame from 0.
        step = total // fnum
        indices = [i * step for i in range(fnum)]
    else:
        # Not enough frames for a non-zero stride (frames[::0] would raise).
        # Spread the fnum picks evenly, repeating frames as needed.
        indices = [i * total // fnum for i in range(fnum)]

    vid_tube = []
    for idx in indices:
        # Reverse the channel axis (BGR from OpenCV -> RGB). The contiguous
        # copy avoids handing a negative-stride view to cv2.resize.
        fr = np.ascontiguousarray(frames[idx][:, :, ::-1])
        fr = cv2.resize(fr, (224, 224))
        vid_tube.append(normalize(fr))

    # (fnum, 224, 224, 3) -> (fnum, 3, 224, 224) -> add a leading batch axis.
    vid_tube = np.stack(vid_tube, axis=0).transpose(0, 3, 1, 2)[np.newaxis]
    return torch.from_numpy(vid_tube)
# ---------------------------------------------------------------------------
# Demo: score every text prompt against every video with VideoCLIP-XL.
# ---------------------------------------------------------------------------

# Build the model, load the weights on CPU, then move it to the GPU in
# inference (eval) mode.
videoclip_xl = VideoCLIP_XL()
# NOTE(review): torch.load unpickles the checkpoint, which can run arbitrary
# code -- only point it at trusted files. For a plain state dict,
# weights_only=True is safer on PyTorch versions that support it.
state_dict = torch.load("./VideoCLIP-XL.bin", map_location="cpu")
videoclip_xl.load_state_dict(state_dict)
videoclip_xl.cuda().eval()

# Placeholder inputs: replace with real video files and captions.
videos = [
    "/path/to/video-1.mp4",
    "/path/to/video-2.mp4",
]
texts = [
    "text-1",
    "text-2",
    "text-3",
]

with torch.no_grad():
    # One preprocessed tensor per video, concatenated along dim 0.
    video_inputs = torch.cat(
        [video_preprocessing(path) for path in videos], dim=0
    ).float().cuda()
    video_features = videoclip_xl.vision_model.get_vid_features(video_inputs).float()
    # Unit-normalize so the matrix product below yields cosine similarities.
    video_features = video_features / video_features.norm(dim=-1, keepdim=True)

    text_inputs = text_encoder.tokenize(texts, truncate=True).cuda()
    text_features = videoclip_xl.text_model.encode_text(text_inputs).float()
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Fixed scale applied to the cosine similarities.
Tmp = 100.0
# Rows correspond to texts, columns to videos.
sim_matrix = Tmp * (text_features @ video_features.T)
print(sim_matrix)