| import warnings |
|
|
| warnings.filterwarnings("ignore", category=FutureWarning) |
| import json |
| import os |
| import subprocess |
| import time |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from decord import VideoReader |
| from transformers import AutoModel, AutoVideoProcessor |
|
|
| import src.datasets.utils.video.transforms as video_transforms |
| import src.datasets.utils.video.volume_transforms as volume_transforms |
| from src.models.attentive_pooler import AttentiveClassifier |
| from src.models.vision_transformer import vit_giant_xformers_rope, vit_base |
|
|
# Per-channel normalization statistics (the standard ImageNet mean/std),
# consumed by the Normalize step in build_pt_video_transform.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
|
|
|
def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    """Load the ``"encoder"`` state dict of a checkpoint into ``model``.

    Wrapper prefixes ``module.`` and ``backbone.`` are removed from the
    *front* of every key (repeatedly and in any order, so both
    ``module.backbone.x`` and ``backbone.module.x`` become ``x``). Loading is
    non-strict: mismatched keys are reported in the printed/returned message
    rather than raised.

    Args:
        model: ``torch.nn.Module`` that receives the weights.
        pretrained_weights: path to a ``torch.save`` checkpoint containing an
            ``"encoder"`` entry.

    Returns:
        The result of ``model.load_state_dict`` (missing / unexpected keys).
    """
    wrapper_prefixes = ("module.", "backbone.")

    def _strip_wrappers(key):
        # Strip leading prefixes only. A plain str.replace would also mangle
        # keys that merely contain the text, e.g. "submodule.weight" ->
        # "subweight", which strict=False would then silently skip.
        while key.startswith(wrapper_prefixes):
            # Each prefix is a single dotted segment, so drop up to the first ".".
            key = key.split(".", 1)[1]
        return key

    checkpoint = torch.load(pretrained_weights, weights_only=True, map_location="cpu")
    pretrained_dict = {_strip_wrappers(k): v for k, v in checkpoint["encoder"].items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
    return msg
|
|
|
|
def build_pt_video_transform(img_size):
    """Build the evaluation-time video preprocessing pipeline.

    The pipeline applies, in order:

    1. ``video_transforms.Resize`` to ``int(256 / 224 * img_size)``
       (presumably the short side, per the original variable name).
    2. A center crop to ``img_size`` x ``img_size``.
    3. Tensor conversion via ``volume_transforms.ClipToTensor``.
    4. Normalization with the ImageNet mean/std constants.

    Args:
        img_size: side length of the square output crop.

    Returns:
        A ``video_transforms.Compose`` of the steps above.
    """
    resize_target = int(256.0 / 224 * img_size)
    crop_hw = (img_size, img_size)
    steps = [
        video_transforms.Resize(resize_target, interpolation="bilinear"),
        video_transforms.CenterCrop(size=crop_hw),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    return video_transforms.Compose(steps)
|
|
|
|
def get_video(sample_video_path, num_frames=80):
    """Decode a subset of frames from ``sample_video_path`` with decord.

    Frame sampling:

    - If the video has at least ``num_frames`` frames, ``num_frames`` evenly
      spaced indices are taken, from the first frame to the last.
    - Shorter videos keep every second frame instead.

    Args:
        sample_video_path: path readable by ``decord.VideoReader``.
        num_frames: target number of frames for long-enough videos.

    Returns:
        The ``asnumpy()`` result of ``VideoReader.get_batch`` for the chosen
        indices. Its layout is assumed to be (T, H, W, C), since callers
        permute it with (0, 3, 1, 2); confirm against the decord docs.
    """
    reader = VideoReader(sample_video_path)
    n_available = len(reader)
    if n_available >= num_frames:
        indices = np.linspace(0, n_available - 1, num_frames, dtype=int)
    else:
        # Short clip: stride-2 subsampling over the whole video.
        indices = np.arange(0, n_available, 2)
    return reader.get_batch(indices).asnumpy()
|
|
|
|
def forward_vjepa_video(model_pt, pt_transform, sample_video_path):
    """Run ``model_pt`` on a single video under ``torch.inference_mode``.

    Processing steps:

    1. Frames from ``get_video`` are permuted with (0, 3, 1, 2).
    2. The permuted frames are passed through ``pt_transform``.
    3. Element ``[0]`` of the transform result is moved to CUDA.
    4. A leading batch dimension of size 1 is added.

    The shapes before and after preprocessing are printed for debugging.

    Returns:
        Whatever ``model_pt`` returns for the single-item batch.
    """
    with torch.inference_mode():
        raw_frames = torch.from_numpy(get_video(sample_video_path))
        frames = raw_frames.permute(0, 3, 1, 2)
        print(frames.shape)

        clip = pt_transform(frames)[0]
        batch = clip.cuda().unsqueeze(0)
        print(batch.shape)

        features = model_pt(batch)
    return features
|
|
|
|
|
|
def run_single_sample_inference():
    """Demo: encode one hard-coded video and print the output shape.

    Uses the ``vjepa2_1_vit_giant_384`` encoder and the
    ``vjepa2_preprocessor`` entries of the local PyTorch Hub repo at
    ``/workspace/vjepa``. The predictor returned by the hub is unused.
    """
    sample_video_path = "/workspace/vjepa/videos/-WH-lxmGJVY_000005_000015.mp4"
    hub_repo = '/workspace/vjepa'

    encoder, _predictor = torch.hub.load(hub_repo, 'vjepa2_1_vit_giant_384', source='local')
    encoder.cuda().eval()

    hf_transform = torch.hub.load(hub_repo, 'vjepa2_preprocessor', source='local')
    print('Successfully loaded VJEPA2 model and preprocessor from local PyTorch Hub.')

    features = forward_vjepa_video(encoder, hf_transform, sample_video_path)

    print(
        f"""
Inference results on video:
PyTorch output shape: {features.shape}
"""
    )
|
|
|
|
def load_and_transform(video_path, pt_transform, num_frames=80):
    """Decode frames from ``video_path`` and apply ``pt_transform``.

    Frame sampling:

    - If the video has at least ``num_frames`` frames, ``num_frames`` evenly
      spaced indices are taken, from the first frame to the last.
    - Shorter videos keep every frame.

    Args:
        video_path: path readable by ``decord.VideoReader``.
        pt_transform: callable applied to the permuted frame tensor. Element
            ``[0]`` of its result is returned.
        num_frames: target number of sampled frames.

    Returns:
        ``pt_transform(video)[0]``, where ``video`` is the decoded batch
        permuted with (0, 3, 1, 2). This assumes decord yields
        (T, H, W, C), which gives (T, C, H, W); confirm.

    Raises:
        ValueError: if the video reports zero frames.
    """
    vr = VideoReader(video_path)
    total_frames = len(vr)
    if total_frames == 0:
        # Fail with a clear message instead of an opaque error from get_batch
        # or the transform on an empty batch.
        raise ValueError(f"Video has no frames: {video_path}")

    if total_frames < num_frames:
        # Short clip: keep every frame (total_frames // num_frames is 0 here,
        # so any derived stride would be 1 anyway).
        frame_idx = np.arange(total_frames)
    else:
        frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int)

    video = vr.get_batch(frame_idx).asnumpy()
    video = torch.from_numpy(video).permute(0, 3, 1, 2)
    return pt_transform(video)[0]
|
|
def forward_vjepa_multiview(encoder, pt_transform, video_paths, num_frames=80):
    """Encode several camera views in a single batched forward pass.

    Each path is preprocessed with ``load_and_transform``. The resulting
    tensors are stacked along a new dim 0, in ``video_paths`` order, and
    moved to CUDA. The encoder is moved to CUDA and set to eval mode before
    the call. Everything runs under ``torch.inference_mode``.

    Args:
        encoder: model invoked on the stacked batch.
        pt_transform: preprocessing callable forwarded to ``load_and_transform``.
        video_paths: list of paths, e.g. [FR, LF, RF].
        num_frames: frames sampled per view.

    Returns:
        The encoder output, with one entry per view along the batch
        dimension of the input. ``torch.stack`` requires every view's tensor
        to have the same shape.
    """
    with torch.inference_mode():
        per_view = [load_and_transform(path, pt_transform, num_frames=num_frames) for path in video_paths]
        batch = torch.stack(per_view, dim=0).cuda()

        encoder = encoder.cuda()
        encoder.eval()
        result = encoder(batch)
    return result
|
|
def run_multi_sample_inference():
    """Demo: encode three camera views of one clip and print the output shape.

    Uses the ``vjepa2_1_vit_base_384`` encoder and the
    ``vjepa2_preprocessor`` entries of the local PyTorch Hub repo at
    ``/workspace/vjepa``. Views are passed in the order LF, FR, RF. The
    predictor returned by the hub is unused.
    """
    fr_path = "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/FR/FR_ARCF004_20221029151346.mp4"
    lf_path = "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/LF/LF_ARCF004_20221029151346.mp4"
    rf_path = "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/RF/RF_ARCF004_20221029151346.mp4"
    hub_repo = '/workspace/vjepa'

    encoder, _predictor = torch.hub.load(hub_repo, 'vjepa2_1_vit_base_384', source='local')
    print('Successfully loaded encoder and predictor from local hub.')

    hf_transform = torch.hub.load(hub_repo, 'vjepa2_preprocessor', source='local')

    ordered_views = [lf_path, fr_path, rf_path]
    features = forward_vjepa_multiview(encoder, hf_transform, ordered_views, num_frames=80)

    print(f"Encoder output for {len(ordered_views)} views: {features.shape}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: only the single-video demo runs by default;
    # run_multi_sample_inference() is defined but not invoked here.
    run_single_sample_inference()
|
|