|
|
from typing import Dict, List, Any, Union |
|
|
import torch |
|
|
import numpy as np |
|
|
import base64 |
|
|
import io |
|
|
import tempfile |
|
|
import os |
|
|
import transformers |
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
# Report the installed transformers version at import time so that it shows
# up in the endpoint's startup output.
print("transformers version ", transformers.__version__)

# Configure root logging once for the process, then create this module's
# logger; everything below logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Custom HuggingFace Inference Endpoint Handler for V-JEPA2 Video Embeddings.

    This handler processes videos and returns pooled embeddings suitable for
    similarity search and vector databases like LanceDB.

    Features:
    - Batch processing support for efficient inference (optional ``batch_size``
      request field limits how many videos go through the model at once)
    - Handles variable-length videos via uniform frame sampling
    - Supports video URLs and base64-encoded videos
    - Returns one mean-pooled embedding per video (the fallback hidden size
      used when the model config does not declare one is 1408)
    """

    def __init__(self, path: str = ""):
        """
        Initialize the V-JEPA2 model and processor.

        Args:
            path: Path to the model weights (provided by HF Inference Endpoints)

        Raises:
            Exception: Any error raised while importing dependencies or loading
                the model/processor is logged and re-raised unchanged.
        """
        try:
            from transformers import AutoVideoProcessor, AutoModel
            # Not used in this method; presumably imported here so that a
            # missing torchcodec install fails at startup rather than on the
            # first request.
            from torchcodec.decoders import VideoDecoder  # noqa: F401

            logger.info(f"Loading V-JEPA2 model from {path}")

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            self.model = AutoModel.from_pretrained(path).to(self.device)
            self.processor = AutoVideoProcessor.from_pretrained(path)

            # Inference only: switch off dropout and other training behavior.
            self.model.eval()

            # Fall back to 64 frames / 1408 dims when the config lacks them.
            self.frames_per_clip = getattr(self.model.config, 'frames_per_clip', 64)
            self.hidden_size = getattr(self.model.config, 'hidden_size', 1408)

            logger.info(f"Model loaded successfully. Frames per clip: {self.frames_per_clip}, Hidden size: {self.hidden_size}")

        except Exception as e:
            logger.exception(f"Error initializing model: {str(e)}")
            raise

    @staticmethod
    def _sample_frame_indices(total_frames: int, num_frames: int) -> List[int]:
        """
        Choose which frame indices to decode so that exactly ``num_frames`` are returned.

        Videos shorter than ``num_frames`` are looped from the start
        (0, 1, ..., total-1, 0, 1, ...); longer videos are sampled uniformly
        between the first and last frame.

        Args:
            total_frames: Number of frames available in the video
            num_frames: Number of frames the model expects

        Returns:
            List of ``num_frames`` integer frame indices in ``[0, total_frames)``

        Raises:
            ValueError: If ``total_frames`` is not positive (nothing to sample)
        """
        if total_frames <= 0:
            # Without this guard the repeat computation below divides by zero.
            raise ValueError("Video contains no decodable frames")

        if total_frames < num_frames:
            repeats = (num_frames // total_frames) + 1
            indices = np.tile(np.arange(total_frames), repeats)[:num_frames]
        else:
            indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

        # Plain Python ints keep the decoder call independent of numpy dtypes.
        return indices.tolist()

    def _decode_sampled_frames(self, source: str) -> torch.Tensor:
        """
        Open ``source`` with torchcodec and decode ``self.frames_per_clip`` sampled frames.

        Args:
            source: Anything ``VideoDecoder`` accepts; the callers in this
                class pass a URL or a local file path

        Returns:
            The ``.data`` tensor of the decoded frame batch

        Raises:
            ValueError: If the video has no frames
        """
        from torchcodec.decoders import VideoDecoder

        decoder = VideoDecoder(source)
        total_frames = len(decoder)

        if 0 < total_frames < self.frames_per_clip:
            logger.warning(f"Video has only {total_frames} frames, less than required {self.frames_per_clip}. Repeating frames.")

        frame_indices = self._sample_frame_indices(total_frames, self.frames_per_clip)
        return decoder.get_frames_at(indices=frame_indices).data

    def _load_video_from_url(self, video_url: str) -> torch.Tensor:
        """
        Load video from URL and sample frames.

        Args:
            video_url: URL to the video file

        Returns:
            Video tensor with shape (frames, channels, height, width)

        Raises:
            Exception: Any decoding error is logged and re-raised.
        """
        try:
            return self._decode_sampled_frames(video_url)
        except Exception as e:
            logger.exception(f"Error loading video from URL {video_url}: {str(e)}")
            raise

    def _load_video_from_base64(self, video_b64: str) -> torch.Tensor:
        """
        Load video from base64-encoded data.

        The decoded bytes are written to a temporary ``.mp4`` file that is
        removed again whether or not decoding succeeds.

        Args:
            video_b64: Base64-encoded video data

        Returns:
            Video tensor with shape (frames, channels, height, width)

        Raises:
            ValueError: If the payload decodes to zero bytes or has no frames
            Exception: Any other decoding error is logged and re-raised.
        """
        try:
            video_bytes = base64.b64decode(video_b64)
            if not video_bytes:
                raise ValueError("Base64 payload decoded to zero bytes")

            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            tmp_path = tmp_file.name
            try:
                # Close the handle before decoding so the data is flushed.
                with tmp_file:
                    tmp_file.write(video_bytes)
                return self._decode_sampled_frames(tmp_path)
            finally:
                # Covers write failures too, so the temp file never leaks.
                os.unlink(tmp_path)

        except Exception as e:
            logger.exception(f"Error loading video from base64: {str(e)}")
            raise

    def _extract_embeddings(self, videos: List[torch.Tensor]) -> np.ndarray:
        """
        Extract pooled embeddings from a batch of videos.

        Args:
            videos: List of video tensors

        Returns:
            Numpy array of shape (batch_size, hidden_size) containing pooled embeddings

        Raises:
            Exception: Any preprocessing or model error is logged and re-raised.
        """
        try:
            inputs = self.processor(videos, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            # Only the final layer is consumed; mean over dim 1 collapses the
            # token axis into one vector per video.
            last_hidden_state = outputs.last_hidden_state
            pooled_embeddings = last_hidden_state.mean(dim=1)

            return pooled_embeddings.cpu().numpy()

        except Exception as e:
            logger.exception(f"Error extracting embeddings: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference request.

        Expected input formats:
        1. Single video URL:
           {"inputs": "https://example.com/video.mp4"}

        2. Batch of video URLs:
           {"inputs": ["url1", "url2", "url3"]}

        3. Base64-encoded video(s):
           {"inputs": "base64_encoded_string", "encoding": "base64"}

        4. Batch with a cap on how many videos are sent through the model at once:
           {"inputs": [...], "batch_size": 4}

        ``encoding`` applies to every entry of ``inputs``; any value other than
        "base64" is treated as a URL. ``batch_size`` is optional and must be a
        positive integer; when omitted all loaded videos are embedded in a
        single forward pass.

        Returns:
            One dictionary per input, in input order. Successful entries look like
            {"embedding": [...], "shape": [dim], "status": "success"}; videos
            that could not be loaded yield
            {"embedding": None, "shape": None, "status": "error", "error": "..."}.
            If the request as a whole fails, a single-element list
            [{"error": "...", "status": "error"}] is returned instead.
        """
        try:
            inputs = data.get("inputs")
            encoding = data.get("encoding", "url")
            batch_size = data.get("batch_size")

            if inputs is None:
                raise ValueError("No 'inputs' provided in request data")

            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                raise ValueError(f"'inputs' must be a string or list, got {type(inputs)}")

            if batch_size is not None and (
                isinstance(batch_size, bool)
                or not isinstance(batch_size, int)
                or batch_size < 1
            ):
                raise ValueError(f"'batch_size' must be a positive integer, got {batch_size!r}")

            logger.info(f"Processing {len(inputs)} video(s)")

            if encoding == "base64":
                load_video = self._load_video_from_base64
            else:
                load_video = self._load_video_from_url

            # Load every input independently so one bad video does not sink
            # the whole request; remember why each failure happened.
            valid_videos = []
            valid_indices = []
            load_errors = {}
            for idx, inp in enumerate(inputs):
                try:
                    if not isinstance(inp, str):
                        raise TypeError(f"each input must be a string, got {type(inp)}")
                    video = load_video(inp)
                except Exception as e:
                    logger.error(f"Error loading video {idx}: {str(e)}")
                    load_errors[idx] = str(e)
                else:
                    valid_videos.append(video)
                    valid_indices.append(idx)

            if not valid_videos:
                raise ValueError("No valid videos could be loaded")

            # Run the model in chunks of `batch_size` (or all at once).
            chunk_size = batch_size if batch_size is not None else len(valid_videos)
            embedding_chunks = [
                self._extract_embeddings(valid_videos[start:start + chunk_size])
                for start in range(0, len(valid_videos), chunk_size)
            ]
            embeddings = np.concatenate(embedding_chunks, axis=0)

            # Re-align embeddings with the original input positions.
            results = [None] * len(inputs)
            for valid_idx, embedding in zip(valid_indices, embeddings):
                results[valid_idx] = {
                    "embedding": embedding.tolist(),
                    "shape": list(embedding.shape),
                    "status": "success"
                }

            for idx, reason in load_errors.items():
                results[idx] = {
                    "embedding": None,
                    "shape": None,
                    "status": "error",
                    "error": f"Failed to load video: {reason}"
                }

            logger.info(f"Successfully processed {len(valid_videos)}/{len(inputs)} videos")

            return results

        except Exception as e:
            # Request-level failure: report it in the response instead of
            # letting the exception escape the handler.
            logger.error(f"Error in __call__: {str(e)}")
            return [{"error": str(e), "status": "error"}]
|
|
|