peterproofpath commited on
Commit
79053cb
·
verified ·
1 Parent(s): 2966b9e

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +210 -0
  2. requirements.txt +17 -0
handler.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ V-JEPA 2 Custom Inference Handler for Hugging Face Inference Endpoints
3
+ Model: facebook/vjepa2-vitl-fpc64-256 (Large variant - good balance of performance/resources)
4
+
5
+ For ProofPath video assessment - extracts motion features from skill demonstration videos.
6
+ """
7
+
8
+ from typing import Dict, List, Any, Optional
9
+ import torch
10
+ import numpy as np
11
+ import base64
12
+ import io
13
+ import tempfile
14
+ import os
15
+
16
+
17
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for V-JEPA 2.

    Loads ``facebook/vjepa2-vitl-fpc64-256`` (or the checkpoint found at
    ``path``) and exposes ``__call__``, which decodes a video from the
    request payload, samples frames, and returns pooled encoder (and
    optionally predictor) features for ProofPath video assessment.
    """

    def __init__(self, path: str = ""):
        """Initialize the V-JEPA 2 processor and model.

        Args:
            path: Model directory provided by HF Inference Endpoints.
                Falls back to the hub id when empty.
        """
        from transformers import AutoVideoProcessor, AutoModel

        # Use the model path provided by the endpoint, or default to HF hub.
        model_id = path if path else "facebook/vjepa2-vitl-fpc64-256"

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = AutoVideoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(
            model_id,
            # fp16 halves GPU memory; stay in fp32 on CPU for stability.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            attn_implementation="sdpa",  # scaled dot-product attention for efficiency
        )

        # device_map="auto" already placed the model on GPU; only move on CPU.
        if not torch.cuda.is_available():
            self.model = self.model.to(self.device)

        self.model.eval()

        # V-JEPA 2 is pretrained on 64-frame clips.
        self.default_num_frames = 64

    @staticmethod
    def _decoder_from_bytes(video_bytes: bytes):
        """Build a torchcodec VideoDecoder from raw video bytes.

        The bytes are spilled to a temporary ``.mp4`` file which is always
        unlinked — even when decoding fails — so failed requests do not
        leak temp files. NOTE(review): unlinking immediately after the
        decoder is constructed matches the original behavior; it assumes
        the decoder holds an open file handle (safe on POSIX) — confirm
        torchcodec does not reopen the path lazily.
        """
        from torchcodec.decoders import VideoDecoder

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
            f.write(video_bytes)
            temp_path = f.name
        try:
            return VideoDecoder(temp_path)
        finally:
            os.unlink(temp_path)

    def _decode_video(self, video_data: Any):
        """Create a torchcodec VideoDecoder from the request payload.

        Supports:
            - http(s) URL (torchcodec reads it directly)
            - ``data:`` URL with a base64 payload
            - bare base64-encoded string
            - raw bytes

        Returns:
            A ``torchcodec.decoders.VideoDecoder`` instance.

        Raises:
            ValueError: For unsupported input types.
        """
        from torchcodec.decoders import VideoDecoder

        if isinstance(video_data, str):
            if video_data.startswith(("http://", "https://")):
                # URL — torchcodec can handle URLs directly.
                return VideoDecoder(video_data)
            if video_data.startswith("data:"):
                # Data URL format: strip the "data:...;base64," header.
                _, encoded = video_data.split(",", 1)
                return self._decoder_from_bytes(base64.b64decode(encoded))
            # Otherwise assume a bare base64-encoded video.
            return self._decoder_from_bytes(base64.b64decode(video_data))
        if isinstance(video_data, bytes):
            return self._decoder_from_bytes(video_data)
        raise ValueError(f"Unsupported video input type: {type(video_data)}")

    def _sample_frames(
        self,
        video_decoder,
        num_frames: int = 64,
        sampling_strategy: str = "uniform",
    ):
        """Sample frame indices and fetch the frames from the decoder.

        Args:
            video_decoder: torchcodec VideoDecoder instance.
            num_frames: Number of frames to sample.
            sampling_strategy: "uniform" or "random"; any other value
                falls back to sequential-from-start.

        Returns:
            Frames tensor of shape T x C x H x W.
        """
        metadata = video_decoder.metadata
        # Fall back to a guess when the container exposes no frame count.
        total_frames = metadata.num_frames if hasattr(metadata, "num_frames") else 1000

        if sampling_strategy == "uniform":
            # Spread samples evenly across the whole clip.
            if total_frames <= num_frames:
                frame_idx = np.arange(total_frames)
            else:
                frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        elif sampling_strategy == "random":
            frame_idx = np.sort(
                np.random.choice(total_frames, min(num_frames, total_frames), replace=False)
            )
        else:
            # Default: sequential from the start of the clip.
            frame_idx = np.arange(min(num_frames, total_frames))

        # get_frames_at returns a batch whose .data is T x C x H x W.
        return video_decoder.get_frames_at(indices=frame_idx.tolist()).data

    @staticmethod
    def _pool(features: torch.Tensor, pooling: str) -> torch.Tensor:
        """Pool [batch, seq, hidden] features per the requested strategy.

        "mean" averages over the sequence, "cls" takes the first token,
        anything else returns the features unpooled.
        """
        if pooling == "mean":
            return features.mean(dim=1)  # [batch, hidden]
        if pooling == "cls":
            return features[:, 0, :]  # [batch, hidden]
        return features  # [batch, seq, hidden]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process a video request and extract V-JEPA 2 features.

        Expected input format::

            {
                "inputs": <base64_video_string or video_url>,
                "parameters": {
                    "num_frames": 64,               # optional
                    "sampling_strategy": "uniform", # optional: "uniform"/"random"
                    "return_predictor": true,       # optional
                    "pooling": "mean"               # optional: "mean"/"cls"/"none"
                }
            }

        Returns:
            A dict with "encoder_features" and "feature_shape" (plus
            "predictor_features"/"predictor_shape" when requested), or
            {"error", "error_type"} on failure.

        Raises:
            ValueError: When no video input is provided.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video")
        if inputs is None:
            raise ValueError("No video input provided. Use 'inputs' or 'video' key.")

        params = data.get("parameters", {})
        num_frames = params.get("num_frames", self.default_num_frames)
        sampling_strategy = params.get("sampling_strategy", "uniform")
        return_predictor = params.get("return_predictor", False)
        pooling = params.get("pooling", "mean")

        try:
            # Decode and sample the video.
            video_decoder = self._decode_video(inputs)
            frames = self._sample_frames(video_decoder, num_frames, sampling_strategy)

            # Preprocess, then move to the model device AND cast floating
            # tensors to the model dtype — the model may be fp16 on GPU
            # while the processor emits fp32 pixel values.
            processed = self.processor(frames, return_tensors="pt")
            model_dtype = next(self.model.parameters()).dtype
            processed = {
                k: (
                    v.to(self.model.device, dtype=model_dtype)
                    if torch.is_tensor(v) and v.is_floating_point()
                    else (v.to(self.model.device) if torch.is_tensor(v) else v)
                )
                for k, v in processed.items()
            }

            with torch.no_grad():
                outputs = self.model(**processed)

            # Encoder features: [batch, seq, hidden] before pooling.
            encoder_pooled = self._pool(outputs.last_hidden_state, pooling)

            result = {
                "encoder_features": encoder_pooled.cpu().numpy().tolist(),
                "feature_shape": list(encoder_pooled.shape),
            }

            # Optionally include predictor features.
            if return_predictor and hasattr(outputs, "predictor_output"):
                predictor_pooled = self._pool(
                    outputs.predictor_output.last_hidden_state, pooling
                )
                result["predictor_features"] = predictor_pooled.cpu().numpy().tolist()
                result["predictor_shape"] = list(predictor_pooled.shape)

            return result

        except Exception as e:
            # Endpoint contract: surface failures as JSON, not a 500.
            return {"error": str(e), "error_type": type(e).__name__}
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # V-JEPA 2 Inference Endpoint Requirements
2
+ # Note: transformers and torch are pre-installed in HF Inference containers
3
+
4
+ # For latest V-JEPA 2 support (may need bleeding edge)
5
+ transformers>=4.45.0
6
+ torch>=2.0.0
7
+
8
+ # Video decoding
9
+ torchcodec>=0.1.0
10
+
11
+ # Standard deps (usually pre-installed)
12
+ numpy>=1.24.0
13
+ einops>=0.7.0
14
+ timm>=0.9.0
15
+
16
+ # For efficient attention
17
+ accelerate>=0.25.0