pratik-250620 committed on
Commit
960dff6
·
verified ·
1 Parent(s): d7c075c

Upload folder using huggingface_hub

Browse files
src/embeddings/audio_embedder.py CHANGED
@@ -31,18 +31,22 @@ class AudioEmbedder:
31
  Handles both raw tensors and BaseModelOutputWithPooling objects
32
  across different transformers versions.
33
  """
 
34
  if not isinstance(output, torch.Tensor):
35
- # BaseModelOutputWithPooling — extract pooled features and project
36
  pooled = output.pooler_output
37
- proj = getattr(self.model, projection, None)
38
- if proj is not None:
39
- pooled = proj(pooled)
 
 
40
  output = pooled
41
  if output.dim() == 3:
42
  pooled = output[:, 0, :]
43
- proj = getattr(self.model, projection, None)
44
- if proj is not None:
45
- pooled = proj(pooled)
 
46
  output = pooled
47
  if output.dim() == 2:
48
  output = output[0]
 
31
  Handles both raw tensors and BaseModelOutputWithPooling objects
32
  across different transformers versions.
33
  """
34
+ target_dim = getattr(self.model.config, "projection_dim", 512)
35
  if not isinstance(output, torch.Tensor):
36
+ # BaseModelOutputWithPooling — extract pooled features
37
  pooled = output.pooler_output
38
+ # Only project if not already at target dim
39
+ if pooled.shape[-1] != target_dim:
40
+ proj = getattr(self.model, projection, None)
41
+ if proj is not None:
42
+ pooled = proj(pooled)
43
  output = pooled
44
  if output.dim() == 3:
45
  pooled = output[:, 0, :]
46
+ if pooled.shape[-1] != target_dim:
47
+ proj = getattr(self.model, projection, None)
48
+ if proj is not None:
49
+ pooled = proj(pooled)
50
  output = pooled
51
  if output.dim() == 2:
52
  output = output[0]
src/embeddings/image_embedder.py CHANGED
@@ -27,18 +27,20 @@ class ImageEmbedder:
27
  inputs = self.processor(images=image, return_tensors="pt").to(self.device)
28
  feats = self.model.get_image_features(**inputs)
29
  # Handle different transformers versions
 
30
  if not isinstance(feats, torch.Tensor):
31
- # BaseModelOutputWithPooling — extract and project
32
  pooled = feats.pooler_output
33
- proj = getattr(self.model, "visual_projection", None)
34
- if proj is not None:
35
- pooled = proj(pooled)
 
36
  feats = pooled
37
  if feats.dim() == 3:
38
  pooled = feats[:, 0, :]
39
- proj = getattr(self.model, "visual_projection", None)
40
- if proj is not None:
41
- pooled = proj(pooled)
 
42
  feats = pooled
43
  if feats.dim() == 2:
44
  feats = feats[0]
 
27
  inputs = self.processor(images=image, return_tensors="pt").to(self.device)
28
  feats = self.model.get_image_features(**inputs)
29
  # Handle different transformers versions
30
+ target_dim = getattr(self.model.config, "projection_dim", 512)
31
  if not isinstance(feats, torch.Tensor):
 
32
  pooled = feats.pooler_output
33
+ if pooled.shape[-1] != target_dim:
34
+ proj = getattr(self.model, "visual_projection", None)
35
+ if proj is not None:
36
+ pooled = proj(pooled)
37
  feats = pooled
38
  if feats.dim() == 3:
39
  pooled = feats[:, 0, :]
40
+ if pooled.shape[-1] != target_dim:
41
+ proj = getattr(self.model, "visual_projection", None)
42
+ if proj is not None:
43
+ pooled = proj(pooled)
44
  feats = pooled
45
  if feats.dim() == 2:
46
  feats = feats[0]
src/embeddings/text_embedder.py CHANGED
@@ -30,18 +30,20 @@ class TextEmbedder:
30
  ).to(self.device)
31
  feats = self.model.get_text_features(**inputs)
32
  # Handle different transformers versions
 
33
  if not isinstance(feats, torch.Tensor):
34
- # BaseModelOutputWithPooling — extract and project
35
  pooled = feats.pooler_output
36
- proj = getattr(self.model, "text_projection", None)
37
- if proj is not None:
38
- pooled = proj(pooled)
 
39
  feats = pooled
40
  if feats.dim() == 3:
41
  pooled = feats[:, 0, :]
42
- proj = getattr(self.model, "text_projection", None)
43
- if proj is not None:
44
- pooled = proj(pooled)
 
45
  feats = pooled
46
  if feats.dim() == 2:
47
  feats = feats[0]
 
30
  ).to(self.device)
31
  feats = self.model.get_text_features(**inputs)
32
  # Handle different transformers versions
33
+ target_dim = getattr(self.model.config, "projection_dim", 512)
34
  if not isinstance(feats, torch.Tensor):
 
35
  pooled = feats.pooler_output
36
+ if pooled.shape[-1] != target_dim:
37
+ proj = getattr(self.model, "text_projection", None)
38
+ if proj is not None:
39
+ pooled = proj(pooled)
40
  feats = pooled
41
  if feats.dim() == 3:
42
  pooled = feats[:, 0, :]
43
+ if pooled.shape[-1] != target_dim:
44
+ proj = getattr(self.model, "text_projection", None)
45
+ if proj is not None:
46
+ pooled = proj(pooled)
47
  feats = pooled
48
  if feats.dim() == 2:
49
  feats = feats[0]