from config_loconet import LoCoNetConfig
from transformers import PreTrainedModel
from loconet_encoder import locoencoder
from loss_multi import lossAV, lossA, lossV
| |
|
| |
|
class loconet(PreTrainedModel):
    """Active-speaker detection model wrapping a LoCoNet encoder.

    Combines an audio-visual encoder (`locoencoder`) with three loss heads:
    the main audio-visual loss (`lossAV`) and two auxiliary single-modality
    losses (`lossA`, `lossV`), weighted 1.0 / 0.4 / 0.4 in training.
    """

    config_class = LoCoNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = locoencoder(config)
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, masks, labels=None):
        """Run the encoder and, when labels are given, compute the loss.

        Args:
            audioFeature: audio input for the audio frontend.
            visualFeature: visual input; assumed leading dims are
                (batch, speakers, time, ...) — TODO confirm against caller.
            masks: per-frame validity masks with leading dims (batch, speakers, ...).
            labels: optional per-frame labels; when None (inference) only
                logits are returned.

        Returns:
            dict with "logits" (outsAV), plus "loss" when labels is not None.
        """
        # b = batch size, s = speakers per sample (t = frames, unused here).
        b, s, t = visualFeature.shape[:3]
        # Fold batch and speaker dims together so the encoder processes
        # each speaker track independently: (b, s, ...) -> (b*s, ...).
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # The single audio stream is shared by all s speaker tracks.
        audioEmbed = audioEmbed.repeat(s, 1, 1)

        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)

        if labels is not None:
            # BUGFIX: labels was previously reshaped with labels.view(...)
            # unconditionally before this guard, so inference calls with
            # labels=None crashed with AttributeError and the logits-only
            # branch was unreachable. Reshape labels only when present.
            labels = labels.view(b * s, *labels.shape[2:]).reshape(-1)
            flatMasks = masks.reshape(-1)
            nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, flatMasks)
            nlossA = self.lossA.forward(outsA, labels, flatMasks)
            nlossV = self.lossV.forward(outsV, labels, flatMasks)

            # Auxiliary single-modality losses are down-weighted by 0.4.
            nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
            return {"loss": nloss, "logits": outsAV}

        return {"logits": outsAV}
| |
|