File size: 10,067 Bytes
e013b9e 5578f24 e013b9e 5578f24 e013b9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
from typing import Dict, List, Any, Union
import torch
import numpy as np
import base64
import io
import tempfile
import os
import transformers
import logging
from pathlib import Path
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Report the installed transformers version through logging rather than a bare
# print, so it lands in the same stream as the handler's other messages.
logger.info("transformers version %s", transformers.__version__)
class EndpointHandler:
    """
    Custom HuggingFace Inference Endpoint Handler for V-JEPA2 Video Embeddings.

    This handler decodes videos, samples a fixed number of frames from each,
    runs them through the model and returns one mean-pooled embedding per
    video, suitable for similarity search and vector databases like LanceDB.

    Features:
    - Batch processing, optionally chunked via the ``batch_size`` request field
    - Handles variable-length videos via uniform frame sampling
      (videos shorter than a clip are padded by cycling through their frames)
    - Supports video URLs and base64-encoded videos (plain or data URIs)
    - Returns pooled embeddings whose length is the model's hidden size
      (1408 for the checkpoint this handler was written for)
    """

    def __init__(self, path: str = ""):
        """
        Initialize the V-JEPA2 model and processor.

        Args:
            path: Path to the model weights (provided by HF Inference Endpoints).

        Raises:
            Exception: Any error while importing dependencies or loading the
                model/processor is logged and re-raised.
        """
        try:
            from transformers import AutoModel, AutoVideoProcessor
            # Imported here only to fail fast at startup if torchcodec is
            # missing; the decoding helper imports it where it is used.
            from torchcodec.decoders import VideoDecoder  # noqa: F401

            logger.info(f"Loading V-JEPA2 model from {path}")
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # AutoModel (instead of AutoModelForVideoClassification) loads the
            # model without the classification head, so outputs are features.
            self.model = AutoModel.from_pretrained(path).to(self.device)
            self.processor = AutoVideoProcessor.from_pretrained(path)
            self.model.eval()

            # Fall back to defaults when the config lacks these attributes.
            self.frames_per_clip = getattr(self.model.config, "frames_per_clip", 64)
            self.hidden_size = getattr(self.model.config, "hidden_size", 1408)
            logger.info(
                f"Model loaded successfully. Frames per clip: {self.frames_per_clip}, "
                f"Hidden size: {self.hidden_size}"
            )
        except Exception as e:
            logger.exception(f"Error initializing model: {str(e)}")
            raise

    @staticmethod
    def _sample_frame_indices(total_frames: int, num_frames: int) -> List[int]:
        """
        Choose the frame indices to decode so exactly ``num_frames`` are returned.

        Videos with at least ``num_frames`` frames are sampled uniformly from
        the first to the last frame. Shorter videos are padded by cycling
        through their frames from the start (0, 1, ..., total_frames - 1, 0, ...).

        Args:
            total_frames: Number of frames in the video.
            num_frames: Number of indices to return.

        Returns:
            A list of ``num_frames`` frame indices.

        Raises:
            ValueError: If the video has no frames (the cycling branch would
                otherwise divide by zero).
        """
        if total_frames <= 0:
            raise ValueError("Video contains no decodable frames")
        if total_frames < num_frames:
            indices = np.arange(num_frames) % total_frames
        else:
            indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        return indices.tolist()

    @staticmethod
    def _parse_batch_size(value: Any) -> Union[int, None]:
        """
        Validate the optional ``batch_size`` request field.

        Args:
            value: Raw value from the request; ``None`` means "no chunking".

        Returns:
            ``None`` when not given, otherwise the value as a positive int.

        Raises:
            ValueError: If the value cannot be read as a positive integer.
        """
        if value is None:
            return None
        try:
            batch_size = int(value)
        except (TypeError, ValueError):
            batch_size = 0
        if batch_size < 1:
            raise ValueError(f"'batch_size' must be a positive integer, got {value!r}")
        return batch_size

    def _decode_frames(self, source: str) -> torch.Tensor:
        """
        Decode ``self.frames_per_clip`` sampled frames from a video source.

        Args:
            source: Video location passed straight to torchcodec's
                ``VideoDecoder`` (this handler passes URLs and file paths).

        Returns:
            The ``.data`` tensor of the decoded frame batch, with shape
            (frames, channels, height, width).
        """
        from torchcodec.decoders import VideoDecoder

        decoder = VideoDecoder(source)
        total_frames = len(decoder)
        frame_indices = self._sample_frame_indices(total_frames, self.frames_per_clip)
        if total_frames < self.frames_per_clip:
            logger.warning(
                f"Video has only {total_frames} frames, less than required "
                f"{self.frames_per_clip}. Repeating frames."
            )
        return decoder.get_frames_at(indices=frame_indices).data

    def _load_video_from_url(self, video_url: str) -> torch.Tensor:
        """
        Load a video from a URL and sample ``frames_per_clip`` frames.

        Args:
            video_url: URL to the video file.

        Returns:
            Video tensor with shape (frames, channels, height, width).
        """
        try:
            return self._decode_frames(video_url)
        except Exception as e:
            logger.error(f"Error loading video from URL {video_url}: {str(e)}")
            raise

    def _load_video_from_base64(self, video_b64: str) -> torch.Tensor:
        """
        Load a video from base64-encoded data and sample ``frames_per_clip`` frames.

        Args:
            video_b64: Base64-encoded video data, optionally as a data URI
                (``data:<mime>;base64,<payload>``).

        Returns:
            Video tensor with shape (frames, channels, height, width).
        """
        try:
            if isinstance(video_b64, str) and video_b64.startswith("data:"):
                # Keep only the payload after the data-URI header.
                video_b64 = video_b64.partition(",")[2]
            video_bytes = base64.b64decode(video_b64)

            # The decoder is given a file path, so spill the bytes to a temp
            # file. The name is known before writing, so it is always cleaned up.
            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
            try:
                with tmp_file:
                    tmp_file.write(video_bytes)
                return self._decode_frames(tmp_file.name)
            finally:
                try:
                    os.unlink(tmp_file.name)
                except OSError as cleanup_error:
                    # Never let cleanup failure mask the decode result/error.
                    logger.warning(
                        f"Could not remove temporary file {tmp_file.name}: {cleanup_error}"
                    )
        except Exception as e:
            logger.error(f"Error loading video from base64: {str(e)}")
            raise

    def _embed_batch(self, videos: List[torch.Tensor]) -> np.ndarray:
        """
        Run one forward pass over ``videos`` and mean-pool the final hidden states.

        Returns:
            Float32 numpy array with one pooled embedding per video.
        """
        inputs = self.processor(videos, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # Only last_hidden_state is used, so per-layer hidden states are
            # not requested (they would only cost memory).
            outputs = self.model(**inputs)
        # Mean over the sequence dimension, assuming last_hidden_state is
        # (batch_size, sequence_length, hidden_size).
        pooled = outputs.last_hidden_state.mean(dim=1)
        # .float() so half/bfloat16 models still convert to numpy.
        return pooled.float().cpu().numpy()

    def _extract_embeddings(
        self, videos: List[torch.Tensor], batch_size: Union[int, None] = None
    ) -> np.ndarray:
        """
        Extract pooled embeddings from a list of videos.

        Args:
            videos: List of video tensors.
            batch_size: Maximum videos per forward pass; ``None`` processes
                all videos in a single pass.

        Returns:
            Numpy array of shape (len(videos), hidden_size) with pooled embeddings.
        """
        if not videos:
            return np.empty((0, self.hidden_size), dtype=np.float32)
        step = batch_size or len(videos)
        try:
            chunks = [
                self._embed_batch(videos[start:start + step])
                for start in range(0, len(videos), step)
            ]
            return np.concatenate(chunks, axis=0)
        except Exception as e:
            logger.error(f"Error extracting embeddings: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference request.

        Expected input formats:
        1. Single video URL:
           {"inputs": "https://example.com/video.mp4"}
        2. Batch of video URLs:
           {"inputs": ["url1", "url2", "url3"]}
        3. Base64-encoded video(s), plain or as a data URI:
           {"inputs": "base64_encoded_string", "encoding": "base64"}
        4. Batch embedded in chunks of ``batch_size`` videos per forward pass:
           {"inputs": [...], "batch_size": 4}

        ``encoding`` applies to every item; any value other than "base64" is
        treated as "url".

        Returns:
            One dict per input, in input order. Successes look like
            {"embedding": [...], "shape": [hidden_size], "status": "success"};
            inputs that fail to load yield
            {"embedding": None, "shape": None, "status": "error", "error": "..."}.
            If the request is invalid or no video loads at all, a single
            [{"error": "...", "status": "error"}] is returned instead.
        """
        try:
            inputs = data.get("inputs")
            encoding = data.get("encoding", "url")
            batch_size = self._parse_batch_size(data.get("batch_size"))

            if inputs is None:
                raise ValueError("No 'inputs' provided in request data")
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                raise ValueError(f"'inputs' must be a string or list, got {type(inputs)}")
            if not inputs:
                raise ValueError("'inputs' must not be empty")

            logger.info(f"Processing {len(inputs)} video(s)")
            load_video = (
                self._load_video_from_base64
                if encoding == "base64"
                else self._load_video_from_url
            )

            # Per-input result slots; failed loads are filled immediately,
            # successes after embedding.
            results: List[Any] = [None] * len(inputs)
            valid_videos = []
            valid_indices = []
            for idx, inp in enumerate(inputs):
                try:
                    video = load_video(inp)
                except Exception as e:
                    logger.error(f"Error loading video {idx}: {str(e)}")
                    results[idx] = {
                        "embedding": None,
                        "shape": None,
                        "status": "error",
                        "error": f"Failed to load video: {e}",
                    }
                    continue
                valid_videos.append(video)
                valid_indices.append(idx)

            if not valid_videos:
                raise ValueError("No valid videos could be loaded")

            embeddings = self._extract_embeddings(valid_videos, batch_size=batch_size)
            for idx, embedding in zip(valid_indices, embeddings):
                results[idx] = {
                    "embedding": embedding.tolist(),
                    "shape": list(embedding.shape),
                    "status": "success",
                }

            logger.info(f"Successfully processed {len(valid_videos)}/{len(inputs)} videos")
            return results
        except Exception as e:
            logger.exception(f"Error in __call__: {str(e)}")
            return [{"error": str(e), "status": "error"}]
|