AndreasXi committed on
Commit
0566826
·
verified ·
1 Parent(s): 9dace28

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. configuration_finelap.py +0 -1
  2. modeling_finelap.py +6 -6
configuration_finelap.py CHANGED
@@ -28,7 +28,6 @@ class FineLAPConfig(PretrainedConfig):
28
  self.unify_audio_proj = unify_audio_proj
29
  self.text_encoder_name = text_encoder_name
30
 
31
- # 👈 关键修改 2:如果读进来的是字典,把它重新包装成 EATConfig 对象
32
  if isinstance(audio_config, dict):
33
  self.audio_config = EATConfig(**audio_config)
34
  elif isinstance(audio_config, EATConfig):
 
28
  self.unify_audio_proj = unify_audio_proj
29
  self.text_encoder_name = text_encoder_name
30
 
 
31
  if isinstance(audio_config, dict):
32
  self.audio_config = EATConfig(**audio_config)
33
  elif isinstance(audio_config, EATConfig):
modeling_finelap.py CHANGED
@@ -121,12 +121,12 @@ class FineLAPModel(PreTrainedModel):
121
  global_text = self.get_global_text_embeds(text_labels, device)
122
 
123
  logits = torch.matmul(global_text, global_audio.transpose(-1, -2))
124
- return logits
125
- # if hasattr(self, "temp_global"):
126
- # logits = logits / self.temp_global
127
- # if hasattr(self, "b_global"):
128
- # logits = logits + self.b_global
129
- # return torch.sigmoid(logits).squeeze(-1)
130
 
131
  @torch.no_grad()
132
  def plot_frame_level_score(self, audio_path, text_labels, output_path="similarity_plot.png", device=None):
 
121
  global_text = self.get_global_text_embeds(text_labels, device)
122
 
123
  logits = torch.matmul(global_text, global_audio.transpose(-1, -2))
124
+ # return logits
125
+ if hasattr(self, "temp_global"):
126
+ logits = logits / self.temp_global
127
+ if hasattr(self, "b_global"):
128
+ logits = logits + self.b_global
129
+ return F.sigmoid(logits).squeeze(-1)
130
 
131
  @torch.no_grad()
132
  def plot_frame_level_score(self, audio_path, text_labels, output_path="similarity_plot.png", device=None):