import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import PreTrainedModel, RobertaModel, RobertaTokenizer

from .configuration_finelap import FineLAPConfig
from .modeling_eat import EATModel


class FineLAPModel(PreTrainedModel):
    """Audio-text model that scores text labels against audio clips.

    Audio is encoded with an ``EATModel`` and text with a RoBERTa encoder. Both are
    projected into a shared ``config.embed_size`` space. The model exposes
    clip-level scores (one per clip/label pair) and frame-level scores (one per
    pooled time step/label pair).
    """

    config_class = FineLAPConfig

    def __init__(self, config: FineLAPConfig):
        super().__init__(config)
        self.config = config

        self.audio_encoder = EATModel(config.audio_config)
        # Falls back to 768 when the audio config does not declare ``hidden_size``.
        self.audio_width = getattr(config.audio_config, "hidden_size", 768)

        self.text_encoder = RobertaModel.from_pretrained(
            config.text_encoder_name, add_pooling_layer=False
        )
        self.text_width = self.text_encoder.config.hidden_size
        self.tokenizer = RobertaTokenizer.from_pretrained(config.text_encoder_name)

        self.embed_size = config.embed_size

        # Optional learnable scalar temperature / bias terms. Each one is registered
        # only when the config provides a value. The scoring methods check for them
        # with ``hasattr``.
        for param in ["temp_global", "b_global", "temp_local", "b_local"]:
            val = getattr(config, param, None)
            if val is not None:
                self.register_parameter(param, nn.Parameter(torch.ones([]) * val))

        self.global_audio_proj = nn.Sequential(
            nn.Linear(self.audio_width, self.embed_size),
            nn.ReLU(),
            nn.Linear(self.embed_size, self.embed_size),
        )
        self.global_text_proj = nn.Sequential(
            nn.Linear(self.text_width, self.embed_size),
            nn.ReLU(),
            nn.Linear(self.embed_size, self.embed_size),
        )

        # NOTE(review): any value other than "rnn" / "transformer" / "linear" leaves
        # ``local_audio_proj`` unset. Dense (and unified global) audio embeddings
        # will then fail with AttributeError.
        self.local_audio_proj_type = config.local_audio_proj_type
        if self.local_audio_proj_type == "rnn":
            # Bidirectional GRU: the two directions of size embed_size // 2 are
            # concatenated, giving embed_size features per step (for even sizes).
            self.local_audio_proj = nn.GRU(
                input_size=self.audio_width,
                hidden_size=self.embed_size // 2,
                num_layers=2,
                batch_first=True,
                bidirectional=True,
            )
        elif self.local_audio_proj_type == "transformer":
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.embed_size,
                nhead=8,
                dim_feedforward=self.embed_size * 4,
                dropout=0.1,
                activation="relu",
                batch_first=True,
            )
            self.local_audio_proj = nn.Sequential(
                nn.Linear(self.audio_width, self.embed_size),
                nn.TransformerEncoder(encoder_layer, num_layers=2),
            )
        elif self.local_audio_proj_type == "linear":
            self.local_audio_proj = nn.Sequential(
                nn.Linear(self.audio_width, self.embed_size),
                nn.ReLU(),
                nn.Linear(self.embed_size, self.embed_size),
            )

        self.post_init()

    def load_audio(self, audio_path, device=None):
        """Load audio file(s) and turn them into normalized 128-bin kaldi fbanks.

        Args:
            audio_path: A single path (``str`` / ``os.PathLike``) or an iterable of paths.
            device: Target device for the returned batch; defaults to ``self.device``.

        Returns:
            Tensor of shape ``[B, 1, 1024, 128]``. Each file is down-mixed to mono,
            resampled to 16 kHz, and converted to fbank features. The features are
            zero-padded or truncated to 1024 frames.

        Raises:
            ValueError: If an empty collection of paths is given.
        """
        device = device or self.device
        if isinstance(audio_path, (str, os.PathLike)):
            paths = [str(audio_path)]
        else:
            paths = [str(p) for p in audio_path]  # list/tuple/iterable of paths
        if not paths:
            raise ValueError("load_audio() requires at least one audio path")

        target_len = 1024
        mels = []
        for path in paths:
            wav, sr = torchaudio.load(path)
            if wav.shape[0] > 1:  # down-mix multi-channel audio to mono
                wav = wav.mean(dim=0, keepdim=True)
            if sr != 16000:
                wav = torchaudio.functional.resample(wav, sr, 16000)
            wav = wav.squeeze(0)
            wav = wav - wav.mean()  # remove DC offset
            mel = torchaudio.compliance.kaldi.fbank(
                wav.unsqueeze(0),
                htk_compat=True,
                sample_frequency=16000,
                use_energy=False,
                window_type="hanning",
                num_mel_bins=128,
                dither=0.0,
                frame_shift=10,
            )  # [T, 128]
            if mel.shape[0] < target_len:
                mel = F.pad(mel, (0, 0, 0, target_len - mel.shape[0]))
            else:
                mel = mel[:target_len, :]
            # NOTE(review): -4.268 / 4.569 look like fbank mean / std statistics
            # matching the encoder's training setup; confirm before changing.
            mel = (mel - (-4.268)) / (4.569 * 2)  # [T, 128]
            mels.append(mel)

        return torch.stack(mels, dim=0).unsqueeze(1).to(device)  # [B, 1, 1024, 128]

    def encode_audio(self, audio_path):
        """Encode audio file(s) into a CLS token followed by pooled patch tokens.

        The first encoder token is kept as-is. Every 8 consecutive remaining tokens
        are averaged into one.

        Returns:
            Tensor of shape ``[B, 1 + T // 8, D]``, where ``T`` is the number of
            non-CLS tokens produced by the audio encoder.
        """
        audio_mel = self.load_audio(audio_path)
        outputs = self.audio_encoder.extract_features(audio_mel)
        raw = outputs["x"] if isinstance(outputs, dict) else outputs
        cls_token = raw[:, 0:1, :]
        patch_tokens = raw[:, 1:, :]
        B, T, D = patch_tokens.shape
        ds = 8  # number of consecutive patch tokens averaged per pooled frame
        # NOTE(review): the reshape requires T to be a multiple of ``ds``. That
        # presumably holds for the fixed 1024x128 input and the encoder's patch
        # grid; confirm it if either changes.
        pooled = patch_tokens.reshape(B, T // ds, ds, D).mean(dim=2)
        return torch.cat([cls_token, pooled], dim=1)

    def get_global_text_embeds(self, text_labels, device=None):
        """Return L2-normalized global text embeddings, shape ``[N, embed_size]``.

        The first token's hidden state from the text encoder is passed through
        ``global_text_proj``.
        """
        device = device or self.device
        t_in = self.tokenizer(
            text_labels, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        feat = self.text_encoder(
            input_ids=t_in["input_ids"], attention_mask=t_in["attention_mask"]
        ).last_hidden_state
        return F.normalize(self.global_text_proj(feat[:, 0, :]), dim=-1)

    def _global_audio_from_feats(self, audio_feats):
        # Project encode_audio() output to L2-normalized clip embeddings [B, embed_size].
        if self.config.unify_audio_proj:
            # Unified mode: run the local projection over the full sequence
            # (CLS included) and keep the CLS position.
            audio_embeds = self.local_audio_proj(audio_feats)
            if self.local_audio_proj_type == "rnn":
                audio_embeds = audio_embeds[0]  # nn.GRU returns (output, h_n)
            return F.normalize(audio_embeds[:, 0, :], dim=-1)
        return F.normalize(self.global_audio_proj(audio_feats[:, 0, :]), dim=-1)

    def _dense_audio_from_feats(self, audio_feats):
        # Project the non-CLS tokens of encode_audio() output to per-frame embeddings.
        out = self.local_audio_proj(audio_feats[:, 1:, :])
        embeds = out[0] if self.local_audio_proj_type == "rnn" else out
        if self.config.normalize_dense_audio_embeds:
            return F.normalize(embeds, dim=-1)
        return embeds

    def get_global_audio_embeds(self, audio_path):
        """Return L2-normalized clip-level audio embeddings, shape ``[B, embed_size]``."""
        return self._global_audio_from_feats(self.encode_audio(audio_path))

    def get_dense_audio_embeds(self, audio_path):
        """Return frame-level audio embeddings, shape ``[B, T', embed_size]``.

        The embeddings are L2-normalized only when
        ``config.normalize_dense_audio_embeds`` is set.
        """
        return self._dense_audio_from_feats(self.encode_audio(audio_path))

    @torch.no_grad()
    def get_frame_level_score(self, audio_path, text_labels, device=None):
        """Return per-frame label probabilities, shape ``[B, N, T']``.

        Side effect: moves the model to ``device`` and switches it to eval mode.
        """
        device = device or self.device
        self.to(device)
        self.eval()
        dense_audio = self.get_dense_audio_embeds(audio_path)  # (B, T, D)
        text_embeds = self.get_global_text_embeds(text_labels, device).unsqueeze(0)  # (1, N, D)
        sim = torch.matmul(text_embeds, dense_audio.transpose(-1, -2))  # (B, N, T)
        if hasattr(self, "temp_local"):
            sim = sim / self.temp_local
        if hasattr(self, "b_local"):
            sim = sim + self.b_local
        return torch.sigmoid(sim)

    @torch.no_grad()
    def get_clip_level_score(self, audio_path, text_labels, device=None):
        """Return clip-level label probabilities.

        The result has shape ``[N, B]``. When ``B == 1``, the trailing dimension
        is squeezed, giving ``[N]``.

        Side effect: moves the model to ``device`` and switches it to eval mode.
        """
        device = device or self.device
        self.to(device)
        self.eval()
        global_audio = self.get_global_audio_embeds(audio_path)  # (B, D)
        global_text = self.get_global_text_embeds(text_labels, device)  # (N, D)
        logits = torch.matmul(global_text, global_audio.transpose(-1, -2))  # (N, B)
        if hasattr(self, "temp_global"):
            logits = logits / self.temp_global
        if hasattr(self, "b_global"):
            logits = logits + self.b_global
        return torch.sigmoid(logits).squeeze(-1)

    @torch.no_grad()
    def plot_frame_level_score(
        self, audio_path, text_labels, output_path="similarity_plot.png", device=None
    ):
        """Save a heatmap of frame-level scores (labels x frames) to ``output_path``.

        Only the first audio clip is plotted when several paths are given. A single
        ``str`` label is treated as a one-element label list.
        """
        import matplotlib.pyplot as plt

        labels = [text_labels] if isinstance(text_labels, str) else list(text_labels)
        scores = self.get_frame_level_score(audio_path, labels, device)  # (B, N, T)
        # Index the first clip instead of squeeze(). This keeps the matrix 2-D
        # (N, T), as imshow requires, even when there is a single label.
        sim_matrix_np = scores[0].cpu().numpy()

        fig, ax = plt.subplots(figsize=(14, 8))
        try:
            im = ax.imshow(sim_matrix_np, aspect="auto", cmap="viridis", interpolation="nearest")
            ax.set_xlabel("Time Frames", fontsize=12)
            ax.set_ylabel("Labels", fontsize=12)
            ax.set_title("Frame-level Audio-Text Similarity", fontsize=14)
            ax.set_yticks(range(len(labels)))
            ax.set_yticklabels(labels)
            cbar = fig.colorbar(im, ax=ax)
            cbar.set_label("Similarity Score", rotation=270, labelpad=20)
            fig.tight_layout()
            fig.savefig(output_path, dpi=150, bbox_inches="tight")
            print(f"Figure saved to {output_path}")
        finally:
            plt.close(fig)  # release the figure even if drawing/saving fails

    def forward(self, audio_path=None, text_labels=None):
        """Compute the embeddings for whichever inputs are given.

        Returns:
            dict that may contain these keys:
            - ``global_audio_embeds``: clip embeddings, or ``None`` when
              ``config.unify_audio_proj`` is set.
            - ``dense_audio_embeds``: frame-level embeddings.
            - ``global_text_embeds``: text embeddings.
        """
        res = {}
        if audio_path is not None:
            # Encode once and share the features between both audio heads.
            # Previously each head reloaded the files and re-ran the encoder.
            audio_feats = self.encode_audio(audio_path)
            # NOTE(review): in unified mode global embeds are intentionally(?) None
            # here, although get_global_audio_embeds supports that mode; confirm.
            res["global_audio_embeds"] = (
                None
                if self.config.unify_audio_proj
                else self._global_audio_from_feats(audio_feats)
            )
            res["dense_audio_embeds"] = self._dense_audio_from_feats(audio_feats)
        if text_labels is not None:
            res["global_text_embeds"] = self.get_global_text_embeds(text_labels)
        return res