Upload folder using huggingface_hub
Browse files
src/embeddings/audio_embedder.py
CHANGED
|
@@ -25,21 +25,28 @@ class AudioEmbedder:
|
|
| 25 |
self.model.to(self.device)
|
| 26 |
self.model.eval()
|
| 27 |
|
| 28 |
-
def
|
| 29 |
-
"""
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
"""
|
| 34 |
-
if
|
| 35 |
-
|
|
|
|
| 36 |
proj = getattr(self.model, projection, None)
|
| 37 |
if proj is not None:
|
| 38 |
pooled = proj(pooled)
|
| 39 |
-
|
| 40 |
-
if
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
@torch.no_grad()
|
| 45 |
def embed(self, audio_path: str) -> np.ndarray:
|
|
@@ -52,7 +59,7 @@ class AudioEmbedder:
|
|
| 52 |
).to(self.device)
|
| 53 |
|
| 54 |
outputs = self.model.get_audio_features(**inputs)
|
| 55 |
-
emb = self.
|
| 56 |
return emb.cpu().numpy().astype("float32")
|
| 57 |
|
| 58 |
@torch.no_grad()
|
|
@@ -64,5 +71,5 @@ class AudioEmbedder:
|
|
| 64 |
padding=True,
|
| 65 |
).to(self.device)
|
| 66 |
feats = self.model.get_text_features(**inputs)
|
| 67 |
-
feats = self.
|
| 68 |
return feats.cpu().numpy().astype("float32")
|
|
|
|
| 25 |
self.model.to(self.device)
|
| 26 |
self.model.eval()
|
| 27 |
|
| 28 |
+
def _extract_features(self, output, projection: str) -> torch.Tensor:
    """Reduce a model feature output to a single 1-D embedding tensor.

    Accepts either a raw tensor or an output object exposing
    ``pooler_output`` (e.g. ``BaseModelOutputWithPooling``), since the
    return type of the feature getters differs across transformers
    versions.

    Args:
        output: A ``torch.Tensor`` or an object with a ``pooler_output``
            attribute.
        projection: Name of the projection attribute to look up on
            ``self.model`` (for example ``"audio_projection"``). If the
            model has no such attribute, pooled features are used as-is.

    Returns:
        The embedding tensor. A 2-D result is reduced to its first row,
        so only the first item of a batch is returned.
    """

    def project(pooled):
        # Single place for the "apply projection head if the model has
        # one" step, shared by both pooling paths below.
        head = getattr(self.model, projection, None)
        return pooled if head is None else head(pooled)

    if not isinstance(output, torch.Tensor):
        # Output object: take the pooled features and project them.
        # NOTE(review): assumes pooler_output has not already been
        # projected by the model — confirm for the installed
        # transformers version, otherwise this projects twice.
        output = project(output.pooler_output)

    if output.dim() == 3:
        # 3-D tensor: keep index 0 along the second axis, then project.
        # Presumably (batch, seq, hidden) — TODO confirm.
        output = project(output[:, 0, :])

    if output.dim() == 2:
        # Drop the leading axis by keeping the first row only.
        output = output[0]

    return output
|
| 50 |
|
| 51 |
@torch.no_grad()
|
| 52 |
def embed(self, audio_path: str) -> np.ndarray:
|
|
|
|
| 59 |
).to(self.device)
|
| 60 |
|
| 61 |
outputs = self.model.get_audio_features(**inputs)
|
| 62 |
+
emb = self._extract_features(outputs, "audio_projection")
|
| 63 |
return emb.cpu().numpy().astype("float32")
|
| 64 |
|
| 65 |
@torch.no_grad()
|
|
|
|
| 71 |
padding=True,
|
| 72 |
).to(self.device)
|
| 73 |
feats = self.model.get_text_features(**inputs)
|
| 74 |
+
feats = self._extract_features(feats, "text_projection")
|
| 75 |
return feats.cpu().numpy().astype("float32")
|
src/embeddings/image_embedder.py
CHANGED
|
@@ -26,7 +26,14 @@ class ImageEmbedder:
|
|
| 26 |
image = Image.open(image_path).convert("RGB")
|
| 27 |
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
|
| 28 |
feats = self.model.get_image_features(**inputs)
|
| 29 |
-
# Handle different transformers versions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
if feats.dim() == 3:
|
| 31 |
pooled = feats[:, 0, :]
|
| 32 |
proj = getattr(self.model, "visual_projection", None)
|
|
|
|
| 26 |
image = Image.open(image_path).convert("RGB")
|
| 27 |
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
|
| 28 |
feats = self.model.get_image_features(**inputs)
|
| 29 |
+
# Handle different transformers versions
|
| 30 |
+
if not isinstance(feats, torch.Tensor):
|
| 31 |
+
# BaseModelOutputWithPooling — extract and project
|
| 32 |
+
pooled = feats.pooler_output
|
| 33 |
+
proj = getattr(self.model, "visual_projection", None)
|
| 34 |
+
if proj is not None:
|
| 35 |
+
pooled = proj(pooled)
|
| 36 |
+
feats = pooled
|
| 37 |
if feats.dim() == 3:
|
| 38 |
pooled = feats[:, 0, :]
|
| 39 |
proj = getattr(self.model, "visual_projection", None)
|
src/embeddings/text_embedder.py
CHANGED
|
@@ -29,7 +29,14 @@ class TextEmbedder:
|
|
| 29 |
truncation=True,
|
| 30 |
).to(self.device)
|
| 31 |
feats = self.model.get_text_features(**inputs)
|
| 32 |
-
# Handle different transformers versions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
if feats.dim() == 3:
|
| 34 |
pooled = feats[:, 0, :]
|
| 35 |
proj = getattr(self.model, "text_projection", None)
|
|
|
|
| 29 |
truncation=True,
|
| 30 |
).to(self.device)
|
| 31 |
feats = self.model.get_text_features(**inputs)
|
| 32 |
+
# Handle different transformers versions
|
| 33 |
+
if not isinstance(feats, torch.Tensor):
|
| 34 |
+
# BaseModelOutputWithPooling — extract and project
|
| 35 |
+
pooled = feats.pooler_output
|
| 36 |
+
proj = getattr(self.model, "text_projection", None)
|
| 37 |
+
if proj is not None:
|
| 38 |
+
pooled = proj(pooled)
|
| 39 |
+
feats = pooled
|
| 40 |
if feats.dim() == 3:
|
| 41 |
pooled = feats[:, 0, :]
|
| 42 |
proj = getattr(self.model, "text_projection", None)
|