pratik-250620 committed on
Commit
d7c075c
·
verified ·
1 Parent(s): c98d24c

Upload folder using huggingface_hub

Browse files
src/embeddings/audio_embedder.py CHANGED
@@ -25,21 +25,28 @@ class AudioEmbedder:
25
  self.model.to(self.device)
26
  self.model.eval()
27
 
28
- def _squeeze_features(self, feats: torch.Tensor, projection: str) -> torch.Tensor:
29
- """Ensure features are 1-D projected embeddings (512-d).
30
 
31
- Some transformers versions return raw hidden states (batch, seq, hidden)
32
- instead of projected features (batch, proj_dim). Detect and fix.
33
  """
34
- if feats.dim() == 3:
35
- pooled = feats[:, 0, :] # CLS token
 
36
  proj = getattr(self.model, projection, None)
37
  if proj is not None:
38
  pooled = proj(pooled)
39
- feats = pooled
40
- if feats.dim() == 2:
41
- feats = feats[0]
42
- return feats
 
 
 
 
 
 
43
 
44
  @torch.no_grad()
45
  def embed(self, audio_path: str) -> np.ndarray:
@@ -52,7 +59,7 @@ class AudioEmbedder:
52
  ).to(self.device)
53
 
54
  outputs = self.model.get_audio_features(**inputs)
55
- emb = self._squeeze_features(outputs, "audio_projection")
56
  return emb.cpu().numpy().astype("float32")
57
 
58
  @torch.no_grad()
@@ -64,5 +71,5 @@ class AudioEmbedder:
64
  padding=True,
65
  ).to(self.device)
66
  feats = self.model.get_text_features(**inputs)
67
- feats = self._squeeze_features(feats, "text_projection")
68
  return feats.cpu().numpy().astype("float32")
 
25
  self.model.to(self.device)
26
  self.model.eval()
27
 
28
+ def _extract_features(self, output, projection: str) -> torch.Tensor:
29
+ """Extract 1-D projected embedding (512-d) from model output.
30
 
31
+ Handles both raw tensors and BaseModelOutputWithPooling objects
32
+ across different transformers versions.
33
  """
34
+ if not isinstance(output, torch.Tensor):
35
+ # BaseModelOutputWithPooling extract pooled features and project
36
+ pooled = output.pooler_output
37
  proj = getattr(self.model, projection, None)
38
  if proj is not None:
39
  pooled = proj(pooled)
40
+ output = pooled
41
+ if output.dim() == 3:
42
+ pooled = output[:, 0, :]
43
+ proj = getattr(self.model, projection, None)
44
+ if proj is not None:
45
+ pooled = proj(pooled)
46
+ output = pooled
47
+ if output.dim() == 2:
48
+ output = output[0]
49
+ return output
50
 
51
  @torch.no_grad()
52
  def embed(self, audio_path: str) -> np.ndarray:
 
59
  ).to(self.device)
60
 
61
  outputs = self.model.get_audio_features(**inputs)
62
+ emb = self._extract_features(outputs, "audio_projection")
63
  return emb.cpu().numpy().astype("float32")
64
 
65
  @torch.no_grad()
 
71
  padding=True,
72
  ).to(self.device)
73
  feats = self.model.get_text_features(**inputs)
74
+ feats = self._extract_features(feats, "text_projection")
75
  return feats.cpu().numpy().astype("float32")
src/embeddings/image_embedder.py CHANGED
@@ -26,7 +26,14 @@ class ImageEmbedder:
26
  image = Image.open(image_path).convert("RGB")
27
  inputs = self.processor(images=image, return_tensors="pt").to(self.device)
28
  feats = self.model.get_image_features(**inputs)
29
- # Handle different transformers versions (some return 3-D hidden states)
 
 
 
 
 
 
 
30
  if feats.dim() == 3:
31
  pooled = feats[:, 0, :]
32
  proj = getattr(self.model, "visual_projection", None)
 
26
  image = Image.open(image_path).convert("RGB")
27
  inputs = self.processor(images=image, return_tensors="pt").to(self.device)
28
  feats = self.model.get_image_features(**inputs)
29
+ # Handle different transformers versions
30
+ if not isinstance(feats, torch.Tensor):
31
+ # BaseModelOutputWithPooling — extract and project
32
+ pooled = feats.pooler_output
33
+ proj = getattr(self.model, "visual_projection", None)
34
+ if proj is not None:
35
+ pooled = proj(pooled)
36
+ feats = pooled
37
  if feats.dim() == 3:
38
  pooled = feats[:, 0, :]
39
  proj = getattr(self.model, "visual_projection", None)
src/embeddings/text_embedder.py CHANGED
@@ -29,7 +29,14 @@ class TextEmbedder:
29
  truncation=True,
30
  ).to(self.device)
31
  feats = self.model.get_text_features(**inputs)
32
- # Handle different transformers versions (some return 3-D hidden states)
 
 
 
 
 
 
 
33
  if feats.dim() == 3:
34
  pooled = feats[:, 0, :]
35
  proj = getattr(self.model, "text_projection", None)
 
29
  truncation=True,
30
  ).to(self.device)
31
  feats = self.model.get_text_features(**inputs)
32
+ # Handle different transformers versions
33
+ if not isinstance(feats, torch.Tensor):
34
+ # BaseModelOutputWithPooling — extract and project
35
+ pooled = feats.pooler_output
36
+ proj = getattr(self.model, "text_projection", None)
37
+ if proj is not None:
38
+ pooled = proj(pooled)
39
+ feats = pooled
40
  if feats.dim() == 3:
41
  pooled = feats[:, 0, :]
42
  proj = getattr(self.model, "text_projection", None)