"""Demo script: run V-JEPA 2 encoder inference on one video or on several camera views."""

import warnings

# Installed before the remaining imports so the filter is already active while they load.
warnings.filterwarnings("ignore", category=FutureWarning)

import json
import os
import subprocess
import time

import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoModel, AutoVideoProcessor

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.attentive_pooler import AttentiveClassifier
from src.models.vision_transformer import vit_giant_xformers_rope, vit_base

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Defaults used by the demo entry points; override them through the function arguments.
DEFAULT_HUB_DIR = "/workspace/vjepa"
# Alternative sample:
# "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/FR/FR_ARCF004_20221029151346.mp4"
DEFAULT_SAMPLE_VIDEO_PATH = "/workspace/vjepa/videos/-WH-lxmGJVY_000005_000015.mp4"
# Order matches the original demo: left-front, front, right-front (names taken from the paths).
DEFAULT_MULTIVIEW_VIDEO_PATHS = (
    "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/LF/LF_ARCF004_20221029151346.mp4",
    "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/FR/FR_ARCF004_20221029151346.mp4",
    "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/RF/RF_ARCF004_20221029151346.mp4",
)


def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    """Load the ``"encoder"`` entry of a checkpoint into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        pretrained_weights: Path (or file-like object) accepted by ``torch.load``.

    The ``"module."`` and ``"backbone."`` substrings are stripped from every key
    before loading. Loading is non-strict, so missing/unexpected keys are only
    reported in the printed message, not raised.
    """
    # Load weights of the VJEPA2 encoder
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["encoder"]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def build_pt_video_transform(img_size):
    """Build the deterministic evaluation transform for a clip.

    Args:
        img_size: Side length of the square center crop.

    Returns:
        A ``video_transforms.Compose`` that resizes (short side scaled by
        256/224 relative to ``img_size``), center-crops, converts to a tensor
        and normalizes with the ImageNet mean/std.
    """
    short_side_size = int(256.0 / 224 * img_size)
    # Eval transform has no random cropping nor flip
    eval_transform = video_transforms.Compose(
        [
            video_transforms.Resize(short_side_size, interpolation="bilinear"),
            video_transforms.CenterCrop(size=(img_size, img_size)),
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
    return eval_transform


def _read_sampled_frames(video_path, num_frames, short_clip_stride):
    """Decode ``video_path`` and return sampled frames as a NumPy array.

    When the video has at least ``num_frames`` frames, ``num_frames`` indices
    are spread evenly from the first to the last frame. Otherwise every
    ``short_clip_stride``-th frame is taken.

    Raises:
        ValueError: If the video contains no frames.
    """
    vr = VideoReader(video_path)
    total_frames = len(vr)
    if total_frames == 0:
        # Fail early with a clear message instead of sending an empty index list to the decoder.
        raise ValueError(f"Video has no frames: {video_path}")

    if total_frames < num_frames:
        frame_idx = np.arange(0, total_frames, short_clip_stride)
    else:
        frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    return vr.get_batch(frame_idx).asnumpy()


def get_video(sample_video_path, num_frames=80):
    """Read up to ``num_frames`` evenly spaced frames from a video file.

    Args:
        sample_video_path: Path of the video to decode.
        num_frames: Number of frames to sample when the video is long enough.

    Returns:
        NumPy array of decoded frames (T x H x W x C per the caller's usage).
        Videos shorter than ``num_frames`` are sampled at every second frame.

    Raises:
        ValueError: If the video contains no frames.
    """
    # Choose evenly spaced frames, limited by available frames
    return _read_sampled_frames(sample_video_path, num_frames, short_clip_stride=2)


def forward_vjepa_video(model_pt, pt_transform, sample_video_path):
    """Run the encoder on a single video and return its output.

    Args:
        model_pt: Encoder module; expected to already be on the CUDA device.
        pt_transform: Callable whose result is indexed with ``[0]`` to obtain
            the clip tensor.
        sample_video_path: Path of the video to process.

    Returns:
        Whatever ``model_pt`` returns for a batch containing the one clip.
    """
    # Run a sample inference with VJEPA
    with torch.inference_mode():
        # Read and pre-process the video
        video = get_video(sample_video_path)  # T x H x W x C
        video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
        print(video.shape)
        # Add the batch dimension after moving the clip to the GPU.
        x_pt = pt_transform(video)[0].cuda().unsqueeze(0)
        print(x_pt.shape)
        # Extract the patch-wise features from the last layer
        out_patch_features_pt = model_pt(x_pt)

    return out_patch_features_pt


def run_single_sample_inference(sample_video_path=DEFAULT_SAMPLE_VIDEO_PATH, hub_dir=DEFAULT_HUB_DIR):
    """Load the giant encoder from a local hub checkout and run it on one video.

    Args:
        sample_video_path: Video to run inference on.
        hub_dir: Local directory passed to ``torch.hub.load(..., source='local')``.
    """
    encoder, _predictor = torch.hub.load(hub_dir, 'vjepa2_1_vit_giant_384', source='local')
    encoder.cuda().eval()
    hf_transform = torch.hub.load(hub_dir, 'vjepa2_preprocessor', source='local')
    print('Successfully loaded VJEPA2 model and preprocessor from local PyTorch Hub.')

    # Inference on video
    out_patch_features_pt = forward_vjepa_video(
        encoder, hf_transform, sample_video_path
    )

    print(
        f"""
    Inference results on video:
    PyTorch output shape: {out_patch_features_pt.shape}
    """
    )


def load_and_transform(video_path, pt_transform, num_frames=80):
    """Read a video, sample ``num_frames`` frames (or as close as possible) and transform it.

    Args:
        video_path: Path of the video to decode.
        pt_transform: Callable returning a sequence; only its first element is used.
        num_frames: Number of frames to sample when the video is long enough.

    Returns:
        The first element of ``pt_transform(video)`` (C x T x H x W according
        to the original author's note).

    Raises:
        ValueError: If the video contains no frames.
    """
    # A video shorter than num_frames keeps every frame (stride 1).
    video = _read_sampled_frames(video_path, num_frames, short_clip_stride=1)  # T x H x W x C
    video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
    # pt_transform returns a list (possibly several views); take the first one
    # (adjust here if another view is needed).
    tensor = pt_transform(video)[0]  # C x T x H x W
    return tensor


def forward_vjepa_multiview(encoder, pt_transform, video_paths, num_frames=80):
    """Run the encoder on several camera views stacked into one batch.

    Args:
        encoder: Encoder module; it is moved to CUDA and put in eval mode here.
        pt_transform: Preprocessing callable forwarded to ``load_and_transform``.
        video_paths: List of video paths, e.g. [FR, LF, RF] (the bundled demo
            passes [LF, FR, RF]).
        num_frames: Number of frames sampled per view.

    Returns:
        The encoder output for the stacked batch (one entry per view).
    """
    with torch.inference_mode():
        views = [load_and_transform(p, pt_transform, num_frames=num_frames) for p in video_paths]
        # Stack into a batch: (V, C, T, H, W)
        views = torch.stack(views, dim=0).cuda()

        encoder = encoder.cuda()
        encoder.eval()
        out = encoder(views)  # the encoder takes B x C x T x H x W
    return out


def run_multi_sample_inference(video_paths=None, hub_dir=DEFAULT_HUB_DIR, num_frames=80):
    """Load the base encoder from a local hub checkout and run it on several views.

    Args:
        video_paths: Paths of the view videos; defaults to the LF, FR and RF
            sample clips in ``DEFAULT_MULTIVIEW_VIDEO_PATHS``.
        hub_dir: Local directory passed to ``torch.hub.load(..., source='local')``.
        num_frames: Number of frames sampled per view.
    """
    if video_paths is None:
        video_paths = list(DEFAULT_MULTIVIEW_VIDEO_PATHS)

    # Load from the local hub (returns encoder, predictor)
    encoder, _predictor = torch.hub.load(hub_dir, 'vjepa2_1_vit_base_384', source='local')
    print('Successfully loaded encoder and predictor from local hub.')

    # Preprocessing: use the preprocessor exposed by the hub entry point as the transform
    hf_transform = torch.hub.load(hub_dir, 'vjepa2_preprocessor', source='local')

    # Run all views in a single forward pass
    out = forward_vjepa_multiview(encoder, hf_transform, video_paths, num_frames=num_frames)
    print(f"Encoder output for {len(video_paths)} views: {out.shape}")


if __name__ == "__main__":
    # Run with: `python -m notebooks.vjepa2_demo`
    run_single_sample_inference()