| | import torch |
| | import cv2 |
| | import numpy as np |
| | from PIL import Image |
| | from vljepa.config import Config |
| | from vljepa.models import VLJepa |
| | from vljepa.utils import nms |
| |
|
def load_model(checkpoint_path, device="cpu"):
    """Build a ``VLJepa`` model and restore its saved weights.

    A fresh ``Config`` is created, its ``device`` attribute is set to
    *device*, and a ``VLJepa`` is constructed from it. Only two sub-modules
    are restored from the checkpoint: ``model.predictor`` and
    ``model.y_encoder.projection``.

    Args:
        checkpoint_path: Path handed to ``torch.load``. The loaded object
            must provide the keys ``"predictor_state_dict"`` and
            ``"y_projection_state_dict"``.
        device: Device string stored on the config and used as the
            ``map_location`` when reading the checkpoint.

    Returns:
        A ``(model, config)`` tuple; ``model.eval()`` has been called on
        the model before it is returned.

    Raises:
        KeyError: If either expected state-dict key is absent from the
            checkpoint.
    """
    cfg = Config()
    cfg.device = device
    net = VLJepa(cfg)

    print(f"Loading weights from {checkpoint_path}...")
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)

    predictor_weights = state["predictor_state_dict"]
    net.predictor.load_state_dict(predictor_weights)

    projection_weights = state["y_projection_state_dict"]
    net.y_encoder.projection.load_state_dict(projection_weights)

    net.eval()
    return net, cfg
| |
|
def extract_frames(video_path, num_frames=16):
    """Sample evenly spaced RGB frames from a video file.

    Frame indices are chosen with ``np.linspace`` between the first and the
    last frame index reported by OpenCV, truncated to integers. When the
    video has fewer frames than *num_frames*, the same index can be selected
    more than once, so the result may contain repeated frames.

    Args:
        video_path: Path (or other source) accepted by ``cv2.VideoCapture``.
        num_frames: Number of frame positions to sample.

    Returns:
        A list of frames converted from BGR to RGB with ``cv2.cvtColor``.
        The list is empty when the reported frame count is not positive
        (including when the capture could not be opened), and may be
        shorter than *num_frames* because frames that fail to read are
        skipped.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture handle is released on every exit
    # path, including the early returns and any error while decoding.
    try:
        if not cap.isOpened():
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return []

        indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
        frames = []
        for idx in indices:
            # Seek explicitly; pass a plain Python int rather than a NumPy scalar.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if ret:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        cap.release()
| |
|
def main():
    """Load a checkpoint and print the text embedding of a sample query.

    This is a smoke-test style demo: it picks CUDA when available, restores
    the model through ``load_model``, encodes one hard-coded query with
    ``model.encode_text`` under ``torch.no_grad()``, and prints the resulting
    embedding shape. It does not process any video.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_path = "best.pth"
    query = "a person is opening a door"

    model, config = load_model(checkpoint_path, device)

    print(f"Ready for inference on {device}.")
    print(f"Model architecture: {config.clip_model} + {config.predictor_model} (LoRA) + {config.text_model}")

    # encode_text is given the raw query string directly, so no separate
    # tokenization result is needed here.
    with torch.no_grad():
        text_embedding = model.encode_text([query], device=device)

    print(f"Query: '{query}'")
    print(f"Text embedding shape: {text_embedding.shape}")
    print("\nTo perform full temporal localization, use the infer.py script which implements sliding window and NMS.")
| |
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|