|
|
from typing import Any |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import random |
|
|
from utils.audiotime_event_merge import replace_event_synonyms |
|
|
|
|
|
def decode_data(line_onset_str, latent_length, sample_rate=24000, downsample_ratio=480, max_events=4):
    """
    Extract a timestamp matrix (event onset indices) from a formatted onset string.

    Args:
        line_onset_str (str): String containing event names and onset intervals,
            formatted like "event1__start1-end1_start2-end2--event2__start1-end1".
        latent_length (int): Number of latent frames (columns of the output matrix).
        sample_rate (int): Audio sample rate in Hz (default 24000).
        downsample_ratio (int): Downsample ratio aligning audio samples with the
            VAE latent frames (default 480).
        max_events (int): Maximum number of events, i.e. rows of the output
            matrix (default 4). Strings with more events than this raise
            IndexError, matching the original fixed-size behavior.

    Returns:
        line_onset_index (torch.Tensor): Matrix of shape [max_events, latent_length];
            entry [i, t] is 1 where event i is active at frame t, else 0.
        line_event (list): Event names extracted from the onset string, in order.

    Notes:
        - Each "start-end" interval (seconds) maps to frame indices via
          time * sample_rate / downsample_ratio, truncated to int.
        - An interval whose end frame exceeds latent_length - 1 stops processing
          of the *remaining* intervals for that event (original `break`
          semantics, preserved deliberately — later in-range intervals of the
          same event are also skipped).
    """
    line_onset_index = torch.zeros((max_events, latent_length))
    line_event = []
    for event_idx, event_onset in enumerate(line_onset_str.split('--')):
        event, instance = event_onset.split('__')
        line_event.append(event)
        for start_end in instance.split('_'):
            start_str, end_str = start_end.split('-')
            start = int(float(start_str) * sample_rate / downsample_ratio)
            end = int(float(end_str) * sample_rate / downsample_ratio)
            # Stop at the first interval that runs past the matrix edge.
            if end > (latent_length - 1):
                break
            line_onset_index[event_idx, start:end] = 1
    return line_onset_index, line_event
|
|
|
|
|
|
|
|
class ContentEncoder(nn.Module): |
|
|
""" |
|
|
ContentEncoder encodes TCC and TDC information. |
|
|
""" |
|
|
def __init__( |
|
|
self, |
|
|
text_encoder: nn.Module= None, |
|
|
): |
|
|
super().__init__() |
|
|
self.text_encoder = text_encoder |
|
|
self.pool = nn.AdaptiveAvgPool1d(1) |
|
|
|
|
|
def encode_content( |
|
|
self, batch_content: list[Any], device: str | torch.device |
|
|
): |
|
|
batch_output = [] |
|
|
batch_mask = [] |
|
|
batch_onset = [] |
|
|
length_list = [] |
|
|
print(batch_content) |
|
|
for content in batch_content: |
|
|
|
|
|
caption = content["caption"] |
|
|
onset = content["onset"] |
|
|
length = int(float(content["length"]) *24000/480) |
|
|
|
|
|
print(onset) |
|
|
replace_label = content.get("replace_label", "False") |
|
|
if replace_label == "True": |
|
|
caption, onset = replace_event_synonyms(caption, onset) |
|
|
|
|
|
|
|
|
if content["onset"] == "random": |
|
|
length_list.append(length) |
|
|
""" |
|
|
fixed embedding. Actually it's a sick sentence, a error during training, kept to match the checkpoint. |
|
|
You can change it to sentence that difference to captions in datasets. |
|
|
The use of fixed text to obtain encoding is for numerical stability. |
|
|
We attempted to use learnable unified encoding during training, but the results were not satisfactory. |
|
|
""" |
|
|
event = "There is no event here" |
|
|
event_embed = self.text_encoder([event.replace("_", " ")])["output"] |
|
|
event_embed = self.pool(event_embed.permute(0, 2, 1)) |
|
|
event_embed = event_embed.flatten().unsqueeze(0) |
|
|
new_onset = event_embed.repeat(length, 1).T |
|
|
else: |
|
|
onset_matrix, events = decode_data(onset, length) |
|
|
length_list.append(length) |
|
|
new_onset = torch.zeros((1024, length), device=device) |
|
|
|
|
|
for (idx, event) in enumerate(events): |
|
|
with torch.no_grad(): |
|
|
event_embed = self.text_encoder([event.replace("_", " ")])["output"] |
|
|
event_embed = self.pool(event_embed.permute(0, 2, 1)) |
|
|
event_embed = event_embed.flatten().unsqueeze(0) |
|
|
mask = (onset_matrix[idx, :] == 0) |
|
|
cols = mask.nonzero(as_tuple=True)[0] |
|
|
new_onset[:, cols] += event_embed.T.float() |
|
|
|
|
|
output_dict = self.text_encoder([caption]) |
|
|
batch_output.append(output_dict["output"][0]) |
|
|
batch_mask.append(output_dict["mask"][0]) |
|
|
batch_onset.append(new_onset) |
|
|
|
|
|
|
|
|
batch_output = nn.utils.rnn.pad_sequence( |
|
|
batch_output, batch_first=True, padding_value=0 |
|
|
) |
|
|
batch_mask = nn.utils.rnn.pad_sequence( |
|
|
batch_mask, batch_first=True, padding_value=False |
|
|
) |
|
|
batch_onset = nn.utils.rnn.pad_sequence( |
|
|
batch_onset, batch_first=True, padding_value=0 |
|
|
) |
|
|
return batch_output, batch_mask, batch_onset, length_list |
|
|
|