#!/usr/bin/env python3
"""
Extract video features using VideoMAE (pretrained on Kinetics-400).
Process 16-frame video clips to capture temporal dynamics.
Output: clip-level feature vectors interpolated to 100 Hz to align with the sensor data.
"""
import os
import sys
import json
import glob
import argparse
import numpy as np
import cv2
import torch
DATASET_DIR = "${PULSE_ROOT}/dataset"
MODEL_NAME = "${PULSE_ROOT}/models/videomae-base-kinetics"
class VideoMAEFeatureExtractor:
"""Extract features using VideoMAE-Base (16-frame clips). Multi-GPU enabled."""
def __init__(self, device='cpu'):
from transformers import VideoMAEModel, VideoMAEImageProcessor
import torch.nn as nn
self.device = device
self.processor = VideoMAEImageProcessor.from_pretrained(MODEL_NAME)
model = VideoMAEModel.from_pretrained(MODEL_NAME).to(device)
model.eval()
        # Wrap with DataParallel if multiple GPUs are available
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            self.n_gpus = torch.cuda.device_count()
            print(f" Using DataParallel across {self.n_gpus} GPUs")
            self.model = nn.DataParallel(model)
        else:
            self.n_gpus = 1
            self.model = model
        self.num_frames = model.config.num_frames
        self.feat_dim = model.config.hidden_size
@torch.no_grad()
def extract_clip(self, frames):
"""Extract feature from a single 16-frame clip.
Args:
frames: list of 16 RGB numpy arrays (H, W, 3)
Returns:
feature: numpy array (feat_dim,) - mean-pooled patch tokens
"""
# Pad/truncate to exactly num_frames
if len(frames) < self.num_frames:
frames = frames + [frames[-1]] * (self.num_frames - len(frames))
elif len(frames) > self.num_frames:
# uniform sampling
indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
frames = [frames[i] for i in indices]
inputs = self.processor(frames, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(self.device)
outputs = self.model(pixel_values)
# Average pool over all patch tokens
feature = outputs.last_hidden_state.mean(dim=1).squeeze(0) # (768,)
return feature.cpu().numpy()
@torch.no_grad()
def extract_clip_batch(self, clips):
"""Extract features from a batch of clips.
Args:
clips: list of clips, each is a list of 16 RGB frames
Returns:
features: numpy array (B, feat_dim)
"""
# Process each clip
all_pixel_values = []
for frames in clips:
if len(frames) < self.num_frames:
frames = frames + [frames[-1]] * (self.num_frames - len(frames))
elif len(frames) > self.num_frames:
indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
frames = [frames[i] for i in indices]
inputs = self.processor(frames, return_tensors="pt")
all_pixel_values.append(inputs["pixel_values"])
batch = torch.cat(all_pixel_values, dim=0).to(self.device)
outputs = self.model(batch)
features = outputs.last_hidden_state.mean(dim=1) # (B, 768)
return features.cpu().numpy()
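# Minimal standalone sketch of the extractor API (illustrative only; the dummy frames and
# CPU device here are assumptions, not part of the pipeline below):
#   extractor = VideoMAEFeatureExtractor(device='cpu')
#   clip = [np.zeros((224, 224, 3), dtype=np.uint8)] * 16
#   feat = extractor.extract_clip(clip)             # (feat_dim,), e.g. (768,) for videomae-base
#   feats = extractor.extract_clip_batch([clip])    # (1, feat_dim)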
def find_scene_video(scenario_dir, vol, scenario):
pattern = os.path.join(scenario_dir, f"trimmed_{vol}{scenario}*Scene Cam.mp4")
matches = glob.glob(pattern)
return matches[0] if matches else None
def extract_features_for_video(extractor, video_path, target_fps=100,
clip_stride_sec=0.5, batch_size=4):
"""Extract VideoMAE features from a video.
Strategy (fast):
    - Sequentially decode the video ONCE, downsample to ~16 fps, and store frames in RAM
- Build clips by indexing into the in-memory frame array (no random seeks)
"""
import time
t0 = time.time()
cap = cv2.VideoCapture(video_path)
video_fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / video_fps
    # Read all frames sequentially, keeping roughly every (video_fps / 16)-th frame (~16 fps)
decode_fps = 16 # we sample frames at this rate from the video
decode_stride = max(1, int(round(video_fps / decode_fps)))
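    # e.g. a 30 fps source gives a stride of round(30 / 16) = 2, i.e. ~15 fps effective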
print(f" Video: {total_frames} frames @ {video_fps:.1f}fps = {duration:.1f}s")
print(f" Decoding sequentially with stride {decode_stride} (~{video_fps/decode_stride:.1f}fps)...")
# Pre-resize to model input size during decoding to save memory
# VideoMAE expects 224x224
target_size = 224
decoded_frames = [] # list of (H, W, 3) uint8 RGB arrays
decoded_times = [] # corresponding timestamps in seconds
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
if frame_idx % decode_stride == 0:
# Resize early to save memory
resized = cv2.resize(frame, (target_size, target_size), interpolation=cv2.INTER_AREA)
rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
decoded_frames.append(rgb)
decoded_times.append(frame_idx / video_fps)
frame_idx += 1
cap.release()
decoded_frames = np.array(decoded_frames) # (N, 224, 224, 3)
decoded_times = np.array(decoded_times)
decode_time = time.time() - t0
print(f" Decoded {len(decoded_frames)} frames in {decode_time:.1f}s")
# Build clips: each clip = 16 frames spanning ~1 second
# Sample 16 consecutive frames from in-memory array
frames_per_clip = 16
n_decoded = len(decoded_frames)
if n_decoded < 4:
return None
# Each clip occupies 16 frames at ~16fps = 1 second
clip_centers_sec = np.arange(0.5, duration - 0.5, clip_stride_sec)
n_clips = len(clip_centers_sec)
print(f" Building {n_clips} clips (stride={clip_stride_sec}s, {frames_per_clip} frames each)")
all_features = []
clip_times = []
batch_clips = []
batch_times = []
t1 = time.time()
for center_sec in clip_centers_sec:
        # Take frames_per_clip consecutive decoded frames centered on center_sec (~±0.5 s window)
center_idx = np.searchsorted(decoded_times, center_sec)
half = frames_per_clip // 2
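        # Clamp the window to the decoded-frame array, keeping frames_per_clip frames when possible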
start = max(0, center_idx - half)
end = min(n_decoded, start + frames_per_clip)
start = max(0, end - frames_per_clip)
if end - start < 4:
continue
clip = list(decoded_frames[start:end])
# Pad if needed
if len(clip) < frames_per_clip:
clip = clip + [clip[-1]] * (frames_per_clip - len(clip))
batch_clips.append(clip)
batch_times.append(center_sec)
if len(batch_clips) >= batch_size:
feats = extractor.extract_clip_batch(batch_clips)
all_features.append(feats)
clip_times.extend(batch_times)
batch_clips = []
batch_times = []
if batch_clips:
feats = extractor.extract_clip_batch(batch_clips)
all_features.append(feats)
clip_times.extend(batch_times)
inference_time = time.time() - t1
print(f" Inference time: {inference_time:.1f}s ({len(clip_times)} clips)")
if not all_features:
return None
features = np.concatenate(all_features, axis=0) # (N_clips, 768)
clip_times = np.array(clip_times[:features.shape[0]])
# Interpolate to target_fps (100Hz)
target_times = np.arange(0, duration, 1.0 / target_fps)
n_target = len(target_times)
from scipy.interpolate import interp1d
if len(clip_times) < 2:
interpolated = np.tile(features[0], (n_target, 1))
else:
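        # Linear interpolation between clip-center times; timestamps before the first
        # or after the last clip center are linearly extrapolated (fill_value='extrapolate')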
interp_func = interp1d(
clip_times, features, axis=0,
kind='linear', fill_value='extrapolate'
)
interpolated = interp_func(target_times).astype(np.float32)
print(f" Output: {interpolated.shape} @ {target_fps}Hz")
return interpolated
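# Each row of the saved array covers 10 ms of video (row i ~ t = i / 100 s), so it can be
# concatenated with the scenario's 100 Hz sensor streams, assuming both start at t = 0.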
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--clip_stride', type=float, default=0.5,
help='Clip extraction stride in seconds (default: 0.5)')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--output_name', type=str, default='video_features_videomae_100hz.npy')
args = parser.parse_args()
device = args.device if torch.cuda.is_available() and args.device == 'cuda' else 'cpu'
print(f"Device: {device}")
print(f"Loading VideoMAE from {MODEL_NAME}...")
extractor = VideoMAEFeatureExtractor(device=device)
print(f"Feature dim: {extractor.feat_dim}, num frames per clip: {extractor.num_frames}")
processed = 0
skipped = 0
for vol_dir in sorted(glob.glob(f"{DATASET_DIR}/v*")):
vol = os.path.basename(vol_dir)
for scenario_dir in sorted(glob.glob(f"{vol_dir}/s*")):
scenario = os.path.basename(scenario_dir)
output_path = os.path.join(scenario_dir, args.output_name)
if os.path.exists(output_path):
print(f"[{vol}/{scenario}] exists, skip")
skipped += 1
continue
video_path = find_scene_video(scenario_dir, vol, scenario)
if video_path is None:
print(f"[{vol}/{scenario}] no video, skip")
skipped += 1
continue
print(f"\n[{vol}/{scenario}]")
features = extract_features_for_video(
extractor, video_path,
clip_stride_sec=args.clip_stride,
batch_size=args.batch_size,
)
if features is not None:
np.save(output_path, features)
print(f" Saved: {output_path} ({features.shape})")
processed += 1
else:
print(f" FAILED")
print(f"\nDone! Processed: {processed}, Skipped: {skipped}")
if __name__ == '__main__':
main()