Spaces:
Running
Running
| import torch.nn as nn | |
| from modules.FGA.atten import Atten | |
class FGAEmbedder(nn.Module):
    """Project embedding features to ``output_size`` and pass them through ``Atten``.

    The input goes through a two-layer MLP (``fc1`` -> GELU -> ``fc2``). The
    projected tensor is then handed to an ``Atten`` module wrapped in a
    one-element list, and the first element of ``Atten``'s output is returned.

    Args:
        input_size: Size of the last dimension of the incoming tensor. This is
            also the hidden width of the MLP. The default ``768 * 3`` suggests
            three concatenated 768-wide embeddings (assumption; confirm with
            the caller).
        output_size: Size of the projected features that are fed to ``Atten``.
    """

    def __init__(self, input_size=768 * 3, output_size=768):
        super().__init__()
        # Attribute names double as state_dict keys, and construction order
        # fixes how parameter init draws from the RNG. Keep both stable.
        self.fc1 = nn.Linear(input_size, input_size)
        self.fc2 = nn.Linear(input_size, output_size)
        self.gelu = nn.GELU()
        # NOTE(review): the semantics of util_e / pairwise_flag live in
        # modules.FGA.atten and are not visible here.
        self.fga = Atten(util_e=[output_size], pairwise_flag=False)

    def forward(self, audio_embs):
        """Return the first ``Atten`` output for the projected ``audio_embs``.

        ``audio_embs`` must have ``input_size`` as its last dimension, since
        ``nn.Linear`` acts on the last axis.
        """
        hidden = self.gelu(self.fc1(audio_embs))
        projected = self.fc2(hidden)
        # Atten is called with a list of inputs and its result is indexed.
        # Only a single input is supplied, so keep the first output.
        outputs = self.fga([projected])
        return outputs[0]