AndreasXi committed on
Commit
6571c49
·
verified ·
1 Parent(s): 6daf432

Upload folder using huggingface_hub

Browse files
__pycache__/configuration_eat.cpython-39.pyc ADDED
Binary file (1.45 kB). View file
 
__pycache__/configuration_finelap.cpython-39.pyc ADDED
Binary file (1.18 kB). View file
 
__pycache__/eat_model.cpython-39.pyc ADDED
Binary file (3.59 kB). View file
 
__pycache__/eat_model_core.cpython-39.pyc ADDED
Binary file (6.07 kB). View file
 
__pycache__/modeling_eat.cpython-39.pyc ADDED
Binary file (1.04 kB). View file
 
__pycache__/modeling_finelap.cpython-39.pyc ADDED
Binary file (4.52 kB). View file
 
modeling_finelap.py CHANGED
@@ -1,9 +1,8 @@
1
- # modeling_finelap.py
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
5
  from transformers import PreTrainedModel, RobertaModel, RobertaTokenizer
6
-
7
  from .configuration_finelap import FineLAPConfig
8
  from .modeling_eat import EATModel
9
 
@@ -13,163 +12,149 @@ class FineLAPModel(PreTrainedModel):
13
  def __init__(self, config: FineLAPConfig):
14
  super().__init__(config)
15
  self.config = config
16
-
17
  self.audio_encoder = EATModel(config.audio_config)
18
  self.audio_width = getattr(config.audio_config, 'hidden_size', 768)
19
-
20
- self.text_encoder = RobertaModel.from_pretrained(
21
- config.text_encoder_name,
22
- add_pooling_layer=False,
23
- )
24
 
 
25
  self.text_width = self.text_encoder.config.hidden_size
 
 
26
  self.embed_size = config.embed_size
27
 
28
- if config.temp_global != 0:
29
- self.temp_global = nn.Parameter(torch.ones([]) * config.temp_global)
30
- if config.b_global != 0:
31
- self.b_global = nn.Parameter(torch.ones([]) * config.b_global)
32
- if config.temp_local != 0:
33
- self.temp_local = nn.Parameter(torch.ones([]) * config.temp_local)
34
- if config.b_local != 0:
35
- self.b_local = nn.Parameter(torch.ones([]) * config.b_local)
36
-
37
- self.global_audio_proj = nn.Sequential(
38
- nn.Linear(self.audio_width, self.embed_size),
39
- nn.ReLU(),
40
- nn.Linear(self.embed_size, self.embed_size),
41
- )
42
- self.global_text_proj = nn.Sequential(
43
- nn.Linear(self.text_width, self.embed_size),
44
- nn.ReLU(),
45
- nn.Linear(self.embed_size, self.embed_size),
46
- )
47
 
48
- # 5. Local Audio Projection Layer
49
  self.local_audio_proj_type = config.local_audio_proj_type
50
  if self.local_audio_proj_type == "rnn":
51
- self.local_audio_proj = nn.GRU(
52
- input_size=self.audio_width,
53
- hidden_size=int(self.embed_size / 2),
54
- num_layers=2,
55
- batch_first=True,
56
- bidirectional=True
57
- )
58
- elif self.local_audio_proj_type == "linear":
59
- self.local_audio_proj = nn.Sequential(
60
- nn.Linear(self.audio_width, self.embed_size),
61
- nn.ReLU(),
62
- nn.Linear(self.embed_size, self.embed_size)
63
- )
64
  elif self.local_audio_proj_type == "transformer":
65
- encoder_layer = nn.TransformerEncoderLayer(
66
- d_model=self.embed_size,
67
- nhead=8,
68
- dim_feedforward=self.embed_size * 4,
69
- dropout=0.1,
70
- activation='relu',
71
- batch_first=True
72
- )
73
- transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
74
- self.local_audio_proj = nn.Sequential(
75
- nn.Linear(self.audio_width, self.embed_size),
76
- transformer_encoder
77
- )
78
- elif self.local_audio_proj_type == "transformer_linearlast":
79
- encoder_layer = nn.TransformerEncoderLayer(
80
- d_model=self.audio_width,
81
- nhead=8,
82
- dim_feedforward=self.audio_width * 4,
83
- dropout=0.1,
84
- activation='relu',
85
- batch_first=True
86
- )
87
- transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
88
- self.local_audio_proj = nn.Sequential(
89
- transformer_encoder,
90
- nn.Linear(self.audio_width, self.embed_size),
91
- )
92
- else:
93
- raise ValueError(f"Invalid local audio proj type: {self.local_audio_proj_type}")
94
-
95
  self.post_init()
96
 
97
-
98
- def encode_audio(self, audio_mel):
99
-
100
- outputs = self.audio_encoder.extract_features(audio_mel)
101
- audio_encoded_raw = outputs['x'] if isinstance(outputs, dict) else outputs
102
-
103
- audio_cls = audio_encoded_raw[:, 0:1, :]
104
- audio_patches = audio_encoded_raw[:, 1:, :]
 
105
 
106
- B, T, D = audio_patches.shape
107
- ds_factor = 8
108
- audio_patches_downsampled = audio_patches.reshape(
109
- B, T // ds_factor, ds_factor, D
110
- ).mean(dim=2)
111
-
112
- # [B, 1+T//8, D]
113
- audio_encoded = torch.cat([audio_cls, audio_patches_downsampled], dim=1)
114
- return audio_encoded
115
-
116
-
117
- def encode_text(self, input_ids, attention_mask):
118
- outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
119
- return outputs.last_hidden_state
120
-
121
-
122
- def get_global_text_embeds(self, input_ids, attention_mask):
123
- text_feats = self.encode_text(input_ids, attention_mask)
124
- text_embeds = F.normalize(self.global_text_proj(text_feats[:, 0, :]), dim=-1)
125
- return text_embeds
126
-
127
 
128
- def get_global_audio_embeds(self, audio_mel):
129
- audio_feats = self.encode_audio(audio_mel)
130
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  if self.config.unify_audio_proj:
132
  audio_embeds = self.local_audio_proj(audio_feats)
133
  if self.config.local_audio_proj_type == "rnn":
134
  audio_embeds = audio_embeds[0]
135
- global_audio_embeds = F.normalize(audio_embeds[:, 0, :], dim=-1)
136
- return global_audio_embeds
137
  else:
138
  audio_cls_feat = audio_feats[:, 0, :]
139
- audio_embeds = F.normalize(self.global_audio_proj(audio_cls_feat), dim=-1)
140
- return audio_embeds
141
-
142
-
143
- def get_dense_audio_embeds(self, audio_mel):
144
- audio_feats = self.encode_audio(audio_mel)
145
- audio_patches = audio_feats[:, 1:, :]
146
-
147
- audio_embeds = self.local_audio_proj(audio_patches)
148
- if self.config.local_audio_proj_type == "rnn":
149
- audio_embeds = audio_embeds[0]
150
-
151
- if self.config.normalize_dense_audio_embeds:
152
- audio_embeds = F.normalize(audio_embeds, dim=-1)
153
- return audio_embeds
154
-
155
-
156
- def forward(self, audio_mel=None, input_ids=None, attention_mask=None, return_dict=True):
157
- global_audio_embeds = None
158
- dense_audio_embeds = None
159
- global_text_embeds = None
160
-
161
- if audio_mel is not None:
162
- global_audio_embeds = self.get_global_audio_embeds(audio_mel)
163
- dense_audio_embeds = self.get_dense_audio_embeds(audio_mel)
164
-
165
- if input_ids is not None:
166
- global_text_embeds = self.get_global_text_embeds(input_ids, attention_mask)
167
-
168
- if not return_dict:
169
- return (global_audio_embeds, dense_audio_embeds, global_text_embeds)
170
-
171
- return {
172
- "global_audio_embeds": global_audio_embeds,
173
- "dense_audio_embeds": dense_audio_embeds,
174
- "global_text_embeds": global_text_embeds
175
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ import torchaudio
5
  from transformers import PreTrainedModel, RobertaModel, RobertaTokenizer
 
6
  from .configuration_finelap import FineLAPConfig
7
  from .modeling_eat import EATModel
8
 
 
12
  def __init__(self, config: FineLAPConfig):
13
  super().__init__(config)
14
  self.config = config
 
15
  self.audio_encoder = EATModel(config.audio_config)
16
  self.audio_width = getattr(config.audio_config, 'hidden_size', 768)
 
 
 
 
 
17
 
18
+ self.text_encoder = RobertaModel.from_pretrained(config.text_encoder_name, add_pooling_layer=False)
19
  self.text_width = self.text_encoder.config.hidden_size
20
+ self.tokenizer = RobertaTokenizer.from_pretrained(config.text_encoder_name)
21
+
22
  self.embed_size = config.embed_size
23
 
24
+ for param in ['temp_global', 'b_global', 'temp_local', 'b_local']:
25
+ val = getattr(config, param, None)
26
+ if val is not None:
27
+ self.register_parameter(param, nn.Parameter(torch.ones([]) * val))
28
+
29
+ self.global_audio_proj = nn.Sequential(nn.Linear(self.audio_width, self.embed_size), nn.ReLU(), nn.Linear(self.embed_size, self.embed_size))
30
+ self.global_text_proj = nn.Sequential(nn.Linear(self.text_width, self.embed_size), nn.ReLU(), nn.Linear(self.embed_size, self.embed_size))
 
 
 
 
 
 
 
 
 
 
 
 
31
 
 
32
  self.local_audio_proj_type = config.local_audio_proj_type
33
  if self.local_audio_proj_type == "rnn":
34
+ self.local_audio_proj = nn.GRU(input_size=self.audio_width, hidden_size=int(self.embed_size / 2), num_layers=2, batch_first=True, bidirectional=True)
 
 
 
 
 
 
 
 
 
 
 
 
35
  elif self.local_audio_proj_type == "transformer":
36
+ l = nn.TransformerEncoderLayer(d_model=self.embed_size, nhead=8, dim_feedforward=self.embed_size * 4, dropout=0.1, activation='relu', batch_first=True)
37
+ self.local_audio_proj = nn.Sequential(nn.Linear(self.audio_width, self.embed_size), nn.TransformerEncoder(l, num_layers=2))
38
+ elif self.local_audio_proj_type == "linear":
39
+ self.local_audio_proj = nn.Sequential(nn.Linear(self.audio_width, self.embed_size), nn.ReLU(), nn.Linear(self.embed_size, self.embed_size))
40
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  self.post_init()
42
 
43
+ def load_audio(self, audio_path, device=None):
44
+ device = device or self.device
45
+ wav, sr = torchaudio.load(audio_path)
46
+ if wav.shape[0] > 1:
47
+ wav = wav.mean(dim=0, keepdim=True)
48
+ if sr != 16000:
49
+ wav = torchaudio.functional.resample(wav, sr, 16000)
50
+ wav = wav.squeeze(0)
51
+ wav = wav - wav.mean()
52
 
53
+ mel = torchaudio.compliance.kaldi.fbank(
54
+ wav.unsqueeze(0), htk_compat=True, sample_frequency=16000,
55
+ use_energy=False, window_type='hanning', num_mel_bins=128,
56
+ dither=0.0, frame_shift=10
57
+ )
58
+ target_len = 1024
59
+ if mel.shape[0] < target_len:
60
+ mel = F.pad(mel, (0, 0, 0, target_len - mel.shape[0]))
61
+ else:
62
+ mel = mel[:target_len, :]
63
+ mel = ((mel - (-4.268)) / (4.569 * 2)).unsqueeze(0).unsqueeze(0).to(device)
64
+ return mel
 
 
 
 
 
 
 
 
 
65
 
66
+ def encode_audio(self, audio_path):
67
+ audio_mel = self.load_audio(audio_path)
68
+ outputs = self.audio_encoder.extract_features(audio_mel)
69
+ raw = outputs['x'] if isinstance(outputs, dict) else outputs
70
+ B, T, D = raw[:, 1:, :].shape
71
+ ds = 8
72
+ patches = raw[:, 1:, :].reshape(B, T // ds, ds, D).mean(dim=2)
73
+ return torch.cat([raw[:, 0:1, :], patches], dim=1)
74
+
75
+ def get_global_text_embeds(self, text_labels, device=None):
76
+ device = device or self.device
77
+ t_in = self.tokenizer(text_labels, padding=True, truncation=True, return_tensors="pt").to(device)
78
+ feat = self.text_encoder(input_ids=t_in["input_ids"], attention_mask=t_in["attention_mask"]).last_hidden_state
79
+ return F.normalize(self.global_text_proj(feat[:, 0, :]), dim=-1)
80
+
81
+ def get_global_audio_embeds(self, audio_path):
82
+ audio_feats = self.encode_audio(audio_path)
83
  if self.config.unify_audio_proj:
84
  audio_embeds = self.local_audio_proj(audio_feats)
85
  if self.config.local_audio_proj_type == "rnn":
86
  audio_embeds = audio_embeds[0]
87
+ return F.normalize(audio_embeds[:, 0, :], dim=-1)
 
88
  else:
89
  audio_cls_feat = audio_feats[:, 0, :]
90
+ return F.normalize(self.global_audio_proj(audio_cls_feat), dim=-1)
91
+
92
+ def get_dense_audio_embeds(self, audio_path):
93
+ patches = self.encode_audio(audio_path)[:, 1:, :]
94
+ out = self.local_audio_proj(patches)
95
+ embeds = out[0] if self.local_audio_proj_type == "rnn" else out
96
+ return F.normalize(embeds, dim=-1) if self.config.normalize_dense_audio_embeds else embeds
97
+
98
+ @torch.no_grad()
99
+ def get_frame_level_score(self, audio_path, text_labels, device=None):
100
+ device = device or self.device
101
+ self.to(device)
102
+ self.eval()
103
+
104
+ dense_audio = self.get_dense_audio_embeds(audio_path).squeeze(0)
105
+ text_embeds = self.get_global_text_embeds(text_labels, device)
106
+
107
+ sim = torch.matmul(text_embeds, dense_audio.transpose(-1, -2))
108
+ if hasattr(self, "temp_local"):
109
+ sim = sim / self.temp_local
110
+ if hasattr(self, "b_local"):
111
+ sim = sim + self.b_local
112
+ return F.sigmoid(sim)
113
+
114
+ @torch.no_grad()
115
+ def get_clip_level_score(self, audio_path, text_labels, device=None):
116
+ device = device or self.device
117
+ self.to(device)
118
+ self.eval()
119
+
120
+ global_audio = self.get_global_audio_embeds(audio_path)
121
+ global_text = self.get_global_text_embeds(text_labels, device)
122
+
123
+ logits = torch.matmul(global_text, global_audio.transpose(-1, -2))
124
+ if hasattr(self, "temp_global"):
125
+ logits = logits / self.temp_global
126
+ if hasattr(self, "b_global"):
127
+ logits = logits + self.b_global
128
+ return torch.sigmoid(logits).squeeze(-1)
129
+
130
+ @torch.no_grad()
131
+ def plot_frame_level_score(self, audio_path, text_labels, output_path="similarity_plot.png", device=None):
132
+ import matplotlib.pyplot as plt
133
+ import numpy as np
134
+
135
+ scores = self.get_frame_level_score(audio_path, text_labels, device)
136
+ sim_matrix_np = scores.cpu().numpy()
137
+
138
+ fig, ax = plt.subplots(figsize=(14, 8))
139
+ im = ax.imshow(sim_matrix_np, aspect='auto', cmap='viridis', interpolation='nearest')
140
+ ax.set_xlabel('Time Frames', fontsize=12)
141
+ ax.set_ylabel('Labels', fontsize=12)
142
+ ax.set_title('Frame-level Audio-Text Similarity', fontsize=14)
143
+ ax.set_yticks(range(len(text_labels)))
144
+ ax.set_yticklabels(text_labels)
145
+
146
+ cbar = plt.colorbar(im, ax=ax)
147
+ cbar.set_label('Similarity Score', rotation=270, labelpad=20)
148
+
149
+ plt.tight_layout()
150
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
151
+ plt.close()
152
+
153
+ def forward(self, audio_path=None, text_labels=None):
154
+ res = {}
155
+ if audio_path is not None:
156
+ res["global_audio_embeds"] = self.get_global_audio_embeds(audio_path) if not self.config.unify_audio_proj else None
157
+ res["dense_audio_embeds"] = self.get_dense_audio_embeds(audio_path)
158
+ if text_labels is not None:
159
+ res["global_text_embeds"] = self.get_global_text_embeds(text_labels)
160
+ return res