| from config_loconet import LoCoNetConfig |
| from transformers import PreTrainedModel |
| from loconet_encoder import locoencoder |
| from loss_multi import lossAV, lossA, lossV |
|
|
|
|
class loconet(PreTrainedModel):
    """LoCoNet active-speaker-detection model wrapped as a HF ``PreTrainedModel``.

    Combines the audio-visual ``locoencoder`` backbone with three loss heads:
    the fused audio-visual head (``lossAV``) plus auxiliary audio-only
    (``lossA``) and visual-only (``lossV``) heads.
    """

    config_class = LoCoNetConfig

    def __init__(self, config):
        """Build the encoder backbone and the three loss heads from *config*."""
        super().__init__(config)

        self.model = locoencoder(config)
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, masks, labels=None):
        """Run the audio-visual forward pass.

        Args:
            audioFeature: audio frontend input for the clip.
            visualFeature: visual input; assumed shape ``(b, s, t, ...)`` —
                b clips, s speaker tracks, t frames. TODO confirm against the
                data pipeline.
            masks: per-frame validity mask, same leading ``(b, s, ...)`` layout.
            labels: optional per-frame labels; when ``None`` the losses are
                skipped and only logits are returned (inference mode).

        Returns:
            dict with ``"logits"`` (outsAV), plus ``"loss"`` when *labels*
            is provided.
        """
        b, s, t = visualFeature.shape[:3]
        # Flatten batch and speaker dims so each track is a separate sample.
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # Tile audio so every speaker track is paired with clip audio.
        # NOTE(review): repeat(s, 1, 1) tiles s-major while the visual
        # flatten above is b-major — verify the ordering matches what
        # forward_cross_attention expects.
        audioEmbed = audioEmbed.repeat(s, 1, 1)

        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)

        if labels is not None:
            # Bug fix: labels were previously reshaped unconditionally before
            # this guard, so calling forward(...) without labels crashed with
            # AttributeError on None. Reshape only when labels exist.
            labels = labels.view(b * s, *labels.shape[2:])
            labels = labels.reshape((-1))
            masks = masks.reshape((-1))
            nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
            nlossA = self.lossA.forward(outsA, labels, masks)
            nlossV = self.lossV.forward(outsV, labels, masks)

            # Auxiliary single-modality losses are down-weighted (0.4 each).
            nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV

            return {"loss": nloss, "logits": outsAV}
        else:
            return {"logits": outsAV}
|
|