#!/usr/bin/env python3
"""
Extract video features using VideoMAE (pretrained on Kinetics-400).
Process 16-frame video clips to capture temporal dynamics.
Output: per-frame feature vectors aligned to 100Hz sensor data.
"""
import os
import glob
import argparse
import time

import numpy as np
import cv2
import torch
from scipy.interpolate import interp1d

# ${PULSE_ROOT} must be set in the environment; expandvars resolves it at import time.
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")
MODEL_NAME = os.path.expandvars("${PULSE_ROOT}/models/videomae-base-kinetics")


class VideoMAEFeatureExtractor:
    """Extract features using VideoMAE-Base (16-frame clips). Multi-GPU enabled."""

    def __init__(self, device='cpu'):
        from transformers import VideoMAEModel, VideoMAEImageProcessor
        import torch.nn as nn

        self.device = device
        self.processor = VideoMAEImageProcessor.from_pretrained(MODEL_NAME)
        model = VideoMAEModel.from_pretrained(MODEL_NAME).to(device)
        model.eval()

        # Wrap with DataParallel if multiple GPUs are available.
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            self.n_gpus = torch.cuda.device_count()
            print(f"  Using DataParallel across {self.n_gpus} GPUs")
            self.model = nn.DataParallel(model)
        else:
            self.n_gpus = 1
            self.model = model
        # Read config from the unwrapped model (DataParallel hides .config).
        self.num_frames = model.config.num_frames
        self.feat_dim = model.config.hidden_size

    @torch.no_grad()
    def extract_clip(self, frames):
        """Extract a feature from a single 16-frame clip.

        Args:
            frames: list of 16 RGB numpy arrays (H, W, 3)
        Returns:
            feature: numpy array (feat_dim,) - mean-pooled patch tokens
        """
        # Pad by repeating the last frame, or uniformly subsample,
        # to reach exactly num_frames.
        if len(frames) < self.num_frames:
            frames = frames + [frames[-1]] * (self.num_frames - len(frames))
        elif len(frames) > self.num_frames:
            indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
            frames = [frames[i] for i in indices]

        inputs = self.processor(frames, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device)
        outputs = self.model(pixel_values)
        # Average-pool over all patch tokens.
        feature = outputs.last_hidden_state.mean(dim=1).squeeze(0)  # (768,)
        return feature.cpu().numpy()

    @torch.no_grad()
    def extract_clip_batch(self, clips):
        """Extract features from a batch of clips.

        Args:
            clips: list of clips, each a list of 16 RGB frames
        Returns:
            features: numpy array (B, feat_dim)
        """
        # Normalize each clip to num_frames, then stack into one batch.
        all_pixel_values = []
        for frames in clips:
            if len(frames) < self.num_frames:
                frames = frames + [frames[-1]] * (self.num_frames - len(frames))
            elif len(frames) > self.num_frames:
                indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
                frames = [frames[i] for i in indices]
            inputs = self.processor(frames, return_tensors="pt")
            all_pixel_values.append(inputs["pixel_values"])

        batch = torch.cat(all_pixel_values, dim=0).to(self.device)
        outputs = self.model(batch)
        features = outputs.last_hidden_state.mean(dim=1)  # (B, 768)
        return features.cpu().numpy()


def find_scene_video(scenario_dir, vol, scenario):
    pattern = os.path.join(scenario_dir, f"trimmed_{vol}{scenario}*Scene Cam.mp4")
    matches = glob.glob(pattern)
    return matches[0] if matches else None
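
# Hedged usage sketch (illustration only, never called by the pipeline): it
# demonstrates the input/output contract of extract_clip. The 224x224
# synthetic frames and the `device` default are assumptions; any
# VideoMAE-Base checkpoint should yield feat_dim == 768.
def _extractor_smoke_test(device='cpu'):
    extractor = VideoMAEFeatureExtractor(device=device)
    # 16 black RGB frames stand in for one decoded ~1-second clip.
    frames = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(16)]
    feature = extractor.extract_clip(frames)
    assert feature.shape == (extractor.feat_dim,)
    return feature
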
def extract_features_for_video(extractor, video_path, target_fps=100,
                               clip_stride_sec=0.5, batch_size=4):
    """Extract VideoMAE features from a video.

    Strategy (fast):
    - Sequentially decode the video ONCE, downsample to ~16fps, and keep the
      frames in RAM.
    - Build clips by indexing into the in-memory frame array (no random seeks).
    """
    t0 = time.time()

    cap = cv2.VideoCapture(video_path)
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if video_fps <= 0 or total_frames <= 0:
        # Unreadable or empty video; avoid a division by zero below.
        cap.release()
        return None
    duration = total_frames / video_fps

    # Read all frames sequentially, downsampling to ~16fps
    # (keep every video_fps/16-th frame).
    decode_fps = 16  # rate at which we sample frames from the video
    decode_stride = max(1, int(round(video_fps / decode_fps)))
    print(f"  Video: {total_frames} frames @ {video_fps:.1f}fps = {duration:.1f}s")
    print(f"  Decoding sequentially with stride {decode_stride} "
          f"(~{video_fps/decode_stride:.1f}fps)...")

    # Pre-resize to the model input size (224x224 for VideoMAE) during
    # decoding to save memory.
    target_size = 224
    decoded_frames = []  # list of (H, W, 3) uint8 RGB arrays
    decoded_times = []   # corresponding timestamps in seconds
    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_idx % decode_stride == 0:
            # Resize early to save memory.
            resized = cv2.resize(frame, (target_size, target_size),
                                 interpolation=cv2.INTER_AREA)
            rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
            decoded_frames.append(rgb)
            decoded_times.append(frame_idx / video_fps)
        frame_idx += 1
    cap.release()

    decoded_frames = np.array(decoded_frames)  # (N, 224, 224, 3)
    decoded_times = np.array(decoded_times)
    decode_time = time.time() - t0
    print(f"  Decoded {len(decoded_frames)} frames in {decode_time:.1f}s")

    # Build clips: each clip = 16 consecutive frames from the in-memory
    # array, spanning ~1 second at ~16fps.
    frames_per_clip = 16
    n_decoded = len(decoded_frames)
    if n_decoded < 4:
        return None

    clip_centers_sec = np.arange(0.5, duration - 0.5, clip_stride_sec)
    n_clips = len(clip_centers_sec)
    print(f"  Building {n_clips} clips (stride={clip_stride_sec}s, "
          f"{frames_per_clip} frames each)")

    all_features = []
    clip_times = []
    batch_clips = []
    batch_times = []
    t1 = time.time()
    for center_sec in clip_centers_sec:
        # Take frames_per_clip consecutive decoded frames centered on
        # center_sec (a ~±0.5s window at ~16fps), clamped to array bounds.
        center_idx = np.searchsorted(decoded_times, center_sec)
        half = frames_per_clip // 2
        start = max(0, center_idx - half)
        end = min(n_decoded, start + frames_per_clip)
        start = max(0, end - frames_per_clip)
        if end - start < 4:
            continue
        clip = list(decoded_frames[start:end])
        # Pad by repeating the last frame if needed.
        if len(clip) < frames_per_clip:
            clip = clip + [clip[-1]] * (frames_per_clip - len(clip))

        batch_clips.append(clip)
        batch_times.append(center_sec)
        if len(batch_clips) >= batch_size:
            feats = extractor.extract_clip_batch(batch_clips)
            all_features.append(feats)
            clip_times.extend(batch_times)
            batch_clips = []
            batch_times = []

    # Flush the final partial batch.
    if batch_clips:
        feats = extractor.extract_clip_batch(batch_clips)
        all_features.append(feats)
        clip_times.extend(batch_times)

    inference_time = time.time() - t1
    print(f"  Inference time: {inference_time:.1f}s ({len(clip_times)} clips)")

    if not all_features:
        return None
    features = np.concatenate(all_features, axis=0)  # (N_clips, 768)
    clip_times = np.array(clip_times[:features.shape[0]])

    # Interpolate from the sparse clip timeline to target_fps (100Hz).
    target_times = np.arange(0, duration, 1.0 / target_fps)
    n_target = len(target_times)
    if len(clip_times) < 2:
        # Too few clips to interpolate; repeat the single feature.
        interpolated = np.tile(features[0], (n_target, 1)).astype(np.float32)
    else:
        interp_func = interp1d(
            clip_times, features, axis=0, kind='linear',
            fill_value='extrapolate'
        )
        interpolated = interp_func(target_times).astype(np.float32)

    print(f"  Output: {interpolated.shape} @ {target_fps}Hz")
    return interpolated
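
# Minimal sketch of the interpolation step above, on synthetic data: sparse
# per-clip features (one vector per clip center) are linearly upsampled to a
# dense 100Hz timeline. The shapes here (feat_dim=2, a 2-second video) are
# illustrative assumptions, not pipeline values.
def _interpolation_example():
    clip_times = np.array([0.5, 1.0, 1.5])  # sparse clip centers (seconds)
    features = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])  # (3, 2)
    f = interp1d(clip_times, features, axis=0, kind='linear',
                 fill_value='extrapolate')
    dense = f(np.arange(0, 2.0, 0.01)).astype(np.float32)  # (200, 2) @ 100Hz
    return dense
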
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--clip_stride', type=float, default=0.5,
                        help='Clip extraction stride in seconds (default: 0.5)')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_name', type=str,
                        default='video_features_videomae_100hz.npy')
    args = parser.parse_args()

    # Fall back to CPU if CUDA was requested but is unavailable.
    device = args.device if torch.cuda.is_available() and args.device == 'cuda' else 'cpu'
    print(f"Device: {device}")

    print(f"Loading VideoMAE from {MODEL_NAME}...")
    extractor = VideoMAEFeatureExtractor(device=device)
    print(f"Feature dim: {extractor.feat_dim}, num frames per clip: {extractor.num_frames}")

    processed = 0
    skipped = 0
    for vol_dir in sorted(glob.glob(f"{DATASET_DIR}/v*")):
        vol = os.path.basename(vol_dir)
        for scenario_dir in sorted(glob.glob(f"{vol_dir}/s*")):
            scenario = os.path.basename(scenario_dir)
            output_path = os.path.join(scenario_dir, args.output_name)
            if os.path.exists(output_path):
                print(f"[{vol}/{scenario}] exists, skip")
                skipped += 1
                continue
            video_path = find_scene_video(scenario_dir, vol, scenario)
            if video_path is None:
                print(f"[{vol}/{scenario}] no video, skip")
                skipped += 1
                continue

            print(f"\n[{vol}/{scenario}]")
            features = extract_features_for_video(
                extractor, video_path,
                clip_stride_sec=args.clip_stride,
                batch_size=args.batch_size,
            )
            if features is not None:
                np.save(output_path, features)
                print(f"  Saved: {output_path} ({features.shape})")
                processed += 1
            else:
                print("  FAILED")

    print(f"\nDone! Processed: {processed}, Skipped: {skipped}")


if __name__ == '__main__':
    main()
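
# Example invocation (a sketch; the script filename is hypothetical, and
# PULSE_ROOT is assumed to be exported with per-scenario videos named
# "trimmed_<vol><scenario>*Scene Cam.mp4" as matched above):
#   PULSE_ROOT=/data/pulse python extract_videomae_features.py \
#       --device cuda --batch_size 8 --clip_stride 0.5
# Each scenario directory gains one video_features_videomae_100hz.npy of
# roughly (duration_sec * 100, 768) float32.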