vjepa-docker-image / test_fun.py
charsin's picture
Upload test_fun.py with huggingface_hub
3d2ff6a verified
Raw
History Blame Contribute Delete
6.19 kB
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import json
import os
import subprocess
import time
import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoModel, AutoVideoProcessor
import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.attentive_pooler import AttentiveClassifier
from src.models.vision_transformer import vit_giant_xformers_rope, vit_base
# Per-channel (3-value) mean and standard deviation handed to
# video_transforms.Normalize in build_pt_video_transform. The names indicate
# these are the usual ImageNet normalisation statistics.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    """Load the V-JEPA encoder weights stored in a checkpoint into ``model``.

    The checkpoint is deserialised on CPU with ``weights_only=True`` and its
    ``"encoder"`` entry is used as the state dict. Every occurrence of
    ``"module."`` and then of ``"backbone."`` is removed from the key names
    before loading. Loading is non-strict, so missing or unexpected keys do
    not raise; they are reported in the message printed at the end.

    Args:
        model: Object exposing ``load_state_dict`` (e.g. a ``torch.nn.Module``).
        pretrained_weights: Checkpoint location accepted by ``torch.load``.

    Returns:
        None. ``model`` is updated in place.
    """
    checkpoint = torch.load(pretrained_weights, weights_only=True, map_location="cpu")
    state = checkpoint["encoder"]
    # Strip the wrapper markers one pass at a time, in this order, so the
    # resulting key set matches two sequential rewrites exactly.
    for marker in ("module.", "backbone."):
        state = {name.replace(marker, ""): tensor for name, tensor in state.items()}
    load_report = model.load_state_dict(state, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, load_report))
def build_pt_video_transform(img_size):
    """Build the deterministic evaluation transform for a video clip.

    The pipeline resizes to a short side of ``int(256 / 224 * img_size)``
    with bilinear interpolation, centre-crops to ``img_size x img_size``,
    converts the clip to a tensor and normalises it with
    ``IMAGENET_DEFAULT_MEAN`` / ``IMAGENET_DEFAULT_STD``. No random cropping
    or flipping is applied.

    Args:
        img_size: Side length of the square centre crop.

    Returns:
        The composed ``video_transforms.Compose`` object.
    """
    # Resize target keeps the 256:224 ratio relative to the final crop size.
    short_side_size = int(256.0 / 224 * img_size)
    steps = [
        video_transforms.Resize(short_side_size, interpolation="bilinear"),
        video_transforms.CenterCrop(size=(img_size, img_size)),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    return video_transforms.Compose(steps)
def get_video(sample_video_path, num_frames=80):
    """Decode a subset of frames from a video file into one NumPy array.

    When the video holds at least ``num_frames`` frames, ``num_frames``
    indices evenly spaced from the first to the last frame (inclusive) are
    read. Shorter videos are read with a stride of 2 (every other frame), so
    fewer than ``num_frames`` frames come back.

    Args:
        sample_video_path: Path handed to ``VideoReader``.
        num_frames: Number of frames to sample when enough are available.

    Returns:
        ``VideoReader.get_batch(...).asnumpy()`` for the chosen indices
        (callers treat it as T x H x W x C).
    """
    reader = VideoReader(sample_video_path)
    frame_count = len(reader)
    if frame_count >= num_frames:
        # Enough material: spread the samples uniformly over the whole clip.
        indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
    else:
        # Short clip: fall back to taking every second frame.
        indices = np.arange(0, frame_count, 2)
    return reader.get_batch(indices).asnumpy()
def forward_vjepa_video(model_pt, pt_transform, sample_video_path):
    """Run one forward pass of ``model_pt`` on a single video.

    The video is read with ``get_video``, permuted to frame-major
    channel-first layout, passed through ``pt_transform`` (whose first
    returned element is used), moved to the GPU and given a leading batch
    dimension. Both intermediate shapes are printed. Everything runs under
    ``torch.inference_mode``.

    Args:
        model_pt: Callable model; it receives the batched CUDA tensor.
        pt_transform: Preprocessing callable whose result is indexable.
        sample_video_path: Path of the video to read.

    Returns:
        Whatever ``model_pt`` returns for the clip (the original author
        describes it as patch-wise features from the last layer).
    """
    with torch.inference_mode():
        # Read and pre-process the video.
        frames = get_video(sample_video_path)  # T x H x W x C
        video = torch.from_numpy(frames).permute(0, 3, 1, 2)  # T x C x H x W
        print(video.shape)

        first_view = pt_transform(video)[0]
        x_pt = first_view.cuda().unsqueeze(0)  # prepend the batch dimension
        print(x_pt.shape)

        # Extract the patch-wise features from the last layer.
        out_patch_features_pt = model_pt(x_pt)
    return out_patch_features_pt
def run_single_sample_inference():
    """Load the giant V-JEPA encoder from the local hub and run it on one clip.

    Loads ``vjepa2_1_vit_giant_384`` and ``vjepa2_preprocessor`` from the
    local hub directory, moves the encoder to the GPU in eval mode, runs
    ``forward_vjepa_video`` on a hard-coded sample video and prints the
    output shape.
    """
    hub_dir = '/workspace/vjepa'
    # Alternative sample clip:
    # "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/FR/FR_ARCF004_20221029151346.mp4"
    sample_video_path = "/workspace/vjepa/videos/-WH-lxmGJVY_000005_000015.mp4"

    # The hub entry yields (encoder, predictor); only the encoder is used here.
    encoder, _predictor = torch.hub.load(hub_dir, 'vjepa2_1_vit_giant_384', source='local')
    encoder.cuda().eval()
    hf_transform = torch.hub.load(hub_dir, 'vjepa2_preprocessor', source='local')
    print('Successfully loaded VJEPA2 model and preprocessor from local PyTorch Hub.')

    # Inference on video
    out_patch_features_pt = forward_vjepa_video(encoder, hf_transform, sample_video_path)

    print(
        f"""
Inference results on video:
PyTorch output shape: {out_patch_features_pt.shape}
"""
    )
def load_and_transform(video_path, pt_transform, num_frames=80):
    """Read a video, sample up to ``num_frames`` frames and preprocess them.

    Videos with at least ``num_frames`` frames are sampled at ``num_frames``
    evenly spaced indices from the first to the last frame. Shorter videos
    contribute every frame, so the result then has fewer than ``num_frames``
    frames.

    Args:
        video_path: Path handed to ``VideoReader``.
        pt_transform: Preprocessing callable; it returns an indexable result
            (possibly several views) of which the first element is kept.
        num_frames: Target number of frames.

    Returns:
        The first element of ``pt_transform(video)`` (noted by the original
        author as C x T x H x W).
    """
    reader = VideoReader(video_path)
    frame_count = len(reader)
    if frame_count >= num_frames:
        indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
    else:
        # In this branch frame_count // num_frames is 0, so the stride is
        # always 1 and every available frame is selected.
        stride = max(1, frame_count // num_frames)
        indices = np.arange(0, frame_count, stride)

    frames = reader.get_batch(indices).asnumpy()  # T x H x W x C
    video = torch.from_numpy(frames).permute(0, 3, 1, 2)  # T x C x H x W
    # Keep only the first view returned by the transform.
    return pt_transform(video)[0]  # C x T x H x W
def forward_vjepa_multiview(encoder, pt_transform, video_paths, num_frames=80):
    """Encode several camera views in a single batched forward pass.

    Each path is loaded with ``load_and_transform``; the resulting tensors are
    stacked along a new leading dimension (one entry per view), moved to the
    GPU and fed to ``encoder``, which is itself moved to the GPU and switched
    to eval mode. ``torch.stack`` requires every view to have the same shape,
    so clips that yield different frame counts cannot be batched together.

    Args:
        encoder: Model called on the stacked batch (B x C x T x H x W per the
            original author's note).
        pt_transform: Preprocessing callable forwarded to ``load_and_transform``.
        video_paths: List of video paths, e.g. [FR, LF, RF].
        num_frames: Target number of frames per view.

    Returns:
        The encoder output, with one entry per view.
    """
    with torch.inference_mode():
        per_view = [
            load_and_transform(path, pt_transform, num_frames=num_frames)
            for path in video_paths
        ]
        # Stack into one batch: (V, C, T, H, W)
        batch = torch.stack(per_view, dim=0).cuda()
        encoder = encoder.cuda()
        encoder.eval()
        return encoder(batch)
def run_multi_sample_inference():
    """Run the base V-JEPA encoder on three camera views of one recording.

    Loads ``vjepa2_1_vit_base_384`` and ``vjepa2_preprocessor`` from the local
    hub directory, encodes the LF, FR and RF clips (in that order) as one
    batch via ``forward_vjepa_multiview`` and prints the output shape.
    """
    hub_dir = '/workspace/vjepa'
    FR_video_path = "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/FR/FR_ARCF004_20221029151346.mp4"
    LF_video_path = "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/LF/LF_ARCF004_20221029151346.mp4"
    RF_video_path = "/mnt/data2/tzx/workspace/auto/pipeline/drivelaw/data/escape_data/danger_8hz/RF/RF_ARCF004_20221029151346.mp4"

    # Load from the local hub; the entry returns (encoder, predictor) and only
    # the encoder is used below.
    encoder, _predictor = torch.hub.load(hub_dir, 'vjepa2_1_vit_base_384', source='local')
    print('Successfully loaded encoder and predictor from local hub.')

    # Preprocessing: the hub-provided preprocessor is used directly as the
    # per-video transform.
    hf_transform = torch.hub.load(hub_dir, 'vjepa2_preprocessor', source='local')

    # Run all three views through the encoder together.
    video_paths = [LF_video_path, FR_video_path, RF_video_path]
    out = forward_vjepa_multiview(encoder, hf_transform, video_paths, num_frames=80)
    print(f"Encoder output for {len(video_paths)} views: {out.shape}")
# Script entry point: runs the single-video demo only; the multi-view demo
# (run_multi_sample_inference) is defined above but not invoked here.
if __name__ == "__main__":
    # Run with: `python -m notebooks.vjepa2_demo`
    # NOTE(review): that module path looks inherited from another file and may
    # not match this script's location — confirm the right invocation.
    run_single_sample_inference()