ayushexel commited on
Commit
e013b9e
·
verified ·
1 Parent(s): 627a1cf

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +265 -0
  2. requirements.txt +15 -0
handler.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Union
2
+ import torch
3
+ import numpy as np
4
+ import base64
5
+ import io
6
+ import tempfile
7
+ import os
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class EndpointHandler:
    """
    Custom HuggingFace Inference Endpoint Handler for V-JEPA2 Video Embeddings.

    This handler processes videos and returns pooled embeddings suitable for
    similarity search and vector databases like LanceDB.

    Features:
    - Batch processing support for efficient inference
    - Handles variable-length videos via uniform frame sampling
    - Supports video URLs and base64-encoded videos
    - Returns 1408-dimensional pooled embeddings
    """

    def __init__(self, path: str = ""):
        """
        Initialize the V-JEPA2 model and processor.

        Args:
            path: Path to the model weights (provided by HF Inference Endpoints)

        Raises:
            Exception: Re-raised after logging if model/processor loading fails.
        """
        try:
            # Lazy imports keep module import cheap; VideoDecoder is imported
            # here only as a fail-fast check that torchcodec is installed.
            from transformers import AutoVideoProcessor, AutoModel
            from torchcodec.decoders import VideoDecoder  # noqa: F401

            logger.info(f"Loading V-JEPA2 model from {path}")

            # Determine device
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # Load model without the classification head to get embeddings.
            # We use AutoModel instead of AutoModelForVideoClassification.
            self.model = AutoModel.from_pretrained(path).to(self.device)
            self.processor = AutoVideoProcessor.from_pretrained(path)

            # Inference only — disable dropout etc.
            self.model.eval()

            # Store model config (V-JEPA2 defaults used as fallbacks).
            self.frames_per_clip = getattr(self.model.config, 'frames_per_clip', 64)
            self.hidden_size = getattr(self.model.config, 'hidden_size', 1408)

            logger.info(f"Model loaded successfully. Frames per clip: {self.frames_per_clip}, Hidden size: {self.hidden_size}")

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    @staticmethod
    def _sample_frame_indices(total_frames: int, frames_per_clip: int) -> np.ndarray:
        """
        Compute exactly ``frames_per_clip`` frame indices for a video.

        Uses uniform sampling when the video is long enough; otherwise tiles
        (repeats) the available frames until the required count is reached.

        Args:
            total_frames: Number of decodable frames in the video (> 0)
            frames_per_clip: Number of frames the model expects

        Returns:
            Integer numpy array of length ``frames_per_clip``

        Raises:
            ValueError: If the video contains no decodable frames (would
                otherwise cause a ZeroDivisionError in the tiling branch).
        """
        if total_frames <= 0:
            raise ValueError("Video contains no decodable frames")
        if total_frames < frames_per_clip:
            # Repeat frames cyclically to reach the required count.
            repeats = (frames_per_clip // total_frames) + 1
            return np.tile(np.arange(total_frames), repeats)[:frames_per_clip]
        # Uniform sampling across the whole video.
        return np.linspace(0, total_frames - 1, frames_per_clip, dtype=int)

    def _load_video_from_url(self, video_url: str) -> torch.Tensor:
        """
        Load video from URL and sample frames.

        Args:
            video_url: URL to the video file

        Returns:
            Video tensor with shape (frames, channels, height, width).
            NOTE: torchcodec's ``get_frames_at(...).data`` is a torch.Tensor,
            not an np.ndarray as the original annotation claimed.

        Raises:
            Exception: Re-raised after logging if decoding fails.
        """
        from torchcodec.decoders import VideoDecoder

        try:
            vr = VideoDecoder(video_url)
            total_frames = len(vr)

            if total_frames < self.frames_per_clip:
                logger.warning(f"Video has only {total_frames} frames, less than required {self.frames_per_clip}. Repeating frames.")

            frame_indices = self._sample_frame_indices(total_frames, self.frames_per_clip)
            return vr.get_frames_at(indices=frame_indices).data

        except Exception as e:
            logger.error(f"Error loading video from URL {video_url}: {str(e)}")
            raise

    def _load_video_from_base64(self, video_b64: str) -> torch.Tensor:
        """
        Load video from base64-encoded data.

        Args:
            video_b64: Base64-encoded video data

        Returns:
            Video tensor with shape (frames, channels, height, width).

        Raises:
            Exception: Re-raised after logging if decoding fails.
        """
        from torchcodec.decoders import VideoDecoder

        try:
            # Decode base64
            video_bytes = base64.b64decode(video_b64)

            # Save to temporary file (torchcodec requires file path)
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
                tmp_file.write(video_bytes)
                tmp_path = tmp_file.name

            try:
                vr = VideoDecoder(tmp_path)
                total_frames = len(vr)

                if total_frames < self.frames_per_clip:
                    logger.warning(f"Video has only {total_frames} frames, less than required {self.frames_per_clip}. Repeating frames.")

                frame_indices = self._sample_frame_indices(total_frames, self.frames_per_clip)
                return vr.get_frames_at(indices=frame_indices).data
            finally:
                # Clean up temporary file even on decode failure.
                os.unlink(tmp_path)

        except Exception as e:
            logger.error(f"Error loading video from base64: {str(e)}")
            raise

    def _extract_embeddings(self, videos: List[torch.Tensor]) -> np.ndarray:
        """
        Extract pooled embeddings from a batch of videos.

        Args:
            videos: List of video tensors, each (frames, channels, height, width)

        Returns:
            Numpy array of shape (batch_size, hidden_size) containing pooled embeddings

        Raises:
            Exception: Re-raised after logging if inference fails.
        """
        try:
            # Process videos through the processor
            inputs = self.processor(videos, return_tensors="pt").to(self.device)

            # Run inference. last_hidden_state is always returned, so there is
            # no need to request output_hidden_states (saves memory).
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Shape: (batch_size, sequence_length, hidden_size)
            last_hidden_state = outputs.last_hidden_state

            # Mean pooling across sequence dimension -> (batch_size, hidden_size)
            pooled_embeddings = last_hidden_state.mean(dim=1)

            return pooled_embeddings.cpu().numpy()

        except Exception as e:
            logger.error(f"Error extracting embeddings: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference request.

        Expected input formats:
        1. Single video URL:
           {"inputs": "https://example.com/video.mp4"}

        2. Batch of video URLs:
           {"inputs": ["url1", "url2", "url3"]}

        3. Base64-encoded video:
           {"inputs": "base64_encoded_string", "encoding": "base64"}

        4. Batch with mixed formats:
           {"inputs": [...], "batch_size": 4}

        Returns:
            List of dictionaries containing embeddings:
            [{"embedding": [1408-dim vector], "shape": [1408]}]
            Failed videos get {"embedding": None, "status": "error", ...} at
            their original position; a top-level failure returns a single
            [{"error": ..., "status": "error"}] entry.
        """
        try:
            # Extract inputs
            inputs = data.get("inputs")
            encoding = data.get("encoding", "url")

            if inputs is None:
                raise ValueError("No 'inputs' provided in request data")

            # Handle single input vs batch
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                raise ValueError(f"'inputs' must be a string or list, got {type(inputs)}")

            logger.info(f"Processing {len(inputs)} video(s)")

            # Load videos; a per-video failure is recorded as None so one bad
            # video does not fail the whole batch.
            videos = []
            for idx, inp in enumerate(inputs):
                try:
                    if encoding == "base64":
                        video = self._load_video_from_base64(inp)
                    else:  # Default to URL
                        video = self._load_video_from_url(inp)
                    videos.append(video)
                except Exception as e:
                    logger.error(f"Error loading video {idx}: {str(e)}")
                    videos.append(None)

            # Filter out failed videos and track their original indices.
            valid_videos = []
            valid_indices = []
            for idx, video in enumerate(videos):
                if video is not None:
                    valid_videos.append(video)
                    valid_indices.append(idx)

            if not valid_videos:
                raise ValueError("No valid videos could be loaded")

            # Extract embeddings for valid videos
            embeddings = self._extract_embeddings(valid_videos)

            # Scatter results back to the original input positions.
            results = [None] * len(inputs)
            for valid_idx, embedding in zip(valid_indices, embeddings):
                results[valid_idx] = {
                    "embedding": embedding.tolist(),
                    "shape": list(embedding.shape),
                    "status": "success"
                }

            # Fill in errors for failed videos
            for idx in range(len(inputs)):
                if results[idx] is None:
                    results[idx] = {
                        "embedding": None,
                        "shape": None,
                        "status": "error",
                        "error": "Failed to load video"
                    }

            logger.info(f"Successfully processed {len(valid_videos)}/{len(inputs)} videos")

            return results

        except Exception as e:
            logger.error(f"Error in __call__: {str(e)}")
            return [{"error": str(e), "status": "error"}]
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # V-JEPA2 Inference Endpoint Requirements
2
+ # Install latest transformers from git for V-JEPA2 support
3
+ git+https://github.com/huggingface/transformers
4
+
5
+ # Core dependencies
6
+ torch>=2.0.0
7
+ torchvision>=0.15.0
8
+ numpy>=1.24.0
9
+
10
+ # Video processing
11
+ torchcodec>=0.1.0
12
+
13
+ # Additional utilities
14
+ Pillow>=10.0.0
15
+ requests>=2.31.0