# File size: 4,428 Bytes
# c485839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124


from typing import Dict, List, Tuple, Union

import hydra
import omegaconf
import torch
from torch import Tensor, nn


import omegaconf
from gigaam.model import GigaAM, GigaAMEmo
from gigaam.preprocess import SAMPLE_RATE, load_audio
from gigaam.utils import onnx_converter

import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class AudioBatch:
    """Container for one batch of audio inputs and their emotion labels.

    ``EmotionModel.forward`` passes ``wavs`` and ``wav_lengths`` to the
    feature extractor and, when the head asks for labels, passes ``emotions``
    to the head.
    """
    # Batched waveforms — presumably [B, WAV_LEN]; TODO confirm against the collate function.
    wavs: torch.Tensor
    # Per-sample lengths matching ``wavs`` — presumably [B]; TODO confirm the unit (samples vs. frames).
    wav_lengths: torch.Tensor
    # Integer emotion class ids — assumed one per sample ([B]); verify against the dataset.
    emotions: torch.LongTensor

class FeatureExtractorGigaAM(nn.Module):
    """Thin ``nn.Module`` wrapper around a GigaAM model built from a config.

    No checkpoint is loaded in this class: the wrapped model is constructed
    with ``GigaAM(cfg)`` and nothing else is done to its weights.
    """

    def __init__(self, cfg):
        """Construct the wrapped GigaAM model from ``cfg``."""
        super().__init__()
        self.fe = GigaAM(cfg)

    def forward(self, features, feature_lengths):
        """Delegate to the wrapped model and return its result unchanged.

        Args:
            features: Model input. The original author's note describes it as
                raw waveforms of shape ``[B, WAV_LEN]`` (not verifiable here).
            feature_lengths: Per-sample lengths of ``features`` (assumed
                ``[B]`` — TODO confirm).

        Returns:
            Whatever ``self.fe`` returns; the original note describes it as
            ``([B, EMB_DIM, T], [B])``.
        """
        encoded = self.fe(features, feature_lengths)
        return encoded
     
class MaxPooling(nn.Module):
    """Max-pool a padded feature sequence over time, ignoring padded frames."""

    def __init__(self):
        super().__init__()

    def forward(self, features, feature_lengths=None):
        """Return the per-channel maximum over the valid frames of each sample.

        Args:
            features: Tensor of shape ``[B, T, D]``.
            feature_lengths: Optional tensor of shape ``[B]`` with the number
                of valid (non-padded) frames per sample. When ``None`` every
                frame is treated as valid.

        Returns:
            Tensor of shape ``[B, D]``. A sample with no valid frame yields a
            zero vector rather than ``-inf``.
        """
        if feature_lengths is None:
            return features.max(dim=1).values

        num_frames = features.shape[1]
        frame_idx = torch.arange(num_frames, device=features.device)
        lengths = feature_lengths.to(features.device)
        valid = frame_idx.unsqueeze(0) < lengths.unsqueeze(1)  # [B, T]

        # Padded frames must never win the max, so push them to -inf first.
        masked = features.masked_fill(~valid.unsqueeze(-1), float("-inf"))
        pooled = masked.max(dim=1).values  # [B, D]

        # A fully padded sample would otherwise come out as -inf everywhere.
        has_frames = valid.any(dim=1, keepdim=True)  # [B, 1]
        return torch.where(has_frames, pooled, torch.zeros_like(pooled))


class AttentionStatsPooling(nn.Module):
    """Attention-weighted mean and standard-deviation pooling over time.

    A small MLP scores every frame; the scores are soft-maxed over the valid
    (non-padded) frames and used as weights for a mean and a standard
    deviation, which are concatenated along the feature axis.
    """

    def __init__(self, input_dim, attn_dim=128):
        """Build the frame scorer.

        Args:
            input_dim: Feature size ``D`` of the incoming frames.
            attn_dim: Hidden size of the scoring MLP.
        """
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(input_dim, attn_dim),
            nn.Tanh(),
            nn.Linear(attn_dim, 1),
        )

    def forward(self, x, lens):  # x: [B, T, D], lens: [B]
        """Pool ``x`` over time using only the first ``lens[b]`` frames of each sample.

        Args:
            x: Tensor of shape ``[B, T, D]``.
            lens: Tensor of shape ``[B]`` with the number of valid frames.

        Returns:
            Tensor of shape ``[B, 2 * D]``: weighted mean followed by weighted
            std. A sample with zero valid frames yields a zero mean and a std
            of ``sqrt(1e-9)`` instead of NaN.
        """
        _, num_frames, _ = x.size()

        # Validity mask: True where the frame index is below the sample length.
        positions = torch.arange(num_frames, device=x.device)
        valid = positions.unsqueeze(0) < lens.to(x.device).unsqueeze(1)  # [B, T]
        valid = valid.unsqueeze(-1)  # [B, T, 1]

        scores = self.attn(x)  # [B, T, 1]

        # Mask padding with the dtype's most negative finite value rather than
        # -inf: padded frames still get zero weight, but a fully padded row
        # no longer turns the softmax (and its gradient) into NaN.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(~valid, fill_value)

        # Softmax over time, then explicitly zero the padded weights; this is
        # what makes a fully padded row contribute nothing.
        weights = F.softmax(scores, dim=1)  # [B, T, 1]
        weights = weights * valid

        # Weighted mean and std over time.
        mean = torch.sum(weights * x, dim=1)  # [B, D]
        variance = torch.sum(weights * (x - mean.unsqueeze(1)) ** 2, dim=1)
        std = torch.sqrt(variance + 1e-9)  # epsilon keeps sqrt differentiable at 0

        return torch.cat([mean, std], dim=1)  # [B, 2*D]


class SelfAttentionWithStatsPooling(nn.Module):
    """Self-attention block followed by attentive statistics pooling.

    Pipeline: multi-head self-attention over the frames, a residual
    connection with layer norm, ``AttentionStatsPooling`` over time, and a
    final linear projection to ``out_dim``.
    """

    def __init__(self, embed_dim=768, num_heads=4, attn_dim=128, out_dim=256):
        """Create the attention, normalisation, pooling and projection layers.

        Args:
            embed_dim: Feature size ``D`` of the incoming frames.
            num_heads: Number of attention heads.
            attn_dim: Hidden size of the pooling layer's scoring MLP.
            out_dim: Size of the returned embedding.
        """
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.pool = AttentionStatsPooling(input_dim=embed_dim, attn_dim=attn_dim)
        # The pooling layer concatenates mean and std, hence 2 * embed_dim inputs.
        self.out = nn.Linear(2 * embed_dim, out_dim)

    def forward(self, x, lens):  # x: [B, T, D], lens: [B]
        """Encode a padded batch of frame sequences into fixed-size vectors.

        Args:
            x: Tensor of shape ``[B, T, D]``.
            lens: Tensor of shape ``[B]`` with the number of valid frames.

        Returns:
            Tensor of shape ``[B, out_dim]``.
        """
        _, num_frames, _ = x.size()

        # Key padding mask for MultiheadAttention: True marks a padded frame.
        frame_idx = torch.arange(num_frames, device=x.device)
        is_padding = frame_idx[None, :] >= lens[:, None]  # [B, T]

        attended = self.mha(x, x, x, key_padding_mask=is_padding)[0]  # [B, T, D]
        normed = self.norm(attended + x)  # residual connection, then layer norm

        stats = self.pool(normed, lens)  # [B, 2*D]
        return self.out(stats)  # [B, out_dim]




class EmotionModel(nn.Module):
    """Emotion classifier: GigaAM feature extractor -> pooling -> head."""

    def __init__(self, config):
        """Assemble the model from a hydra/omegaconf-style config.

        Args:
            config: Object exposing ``feature_extractor.cfg`` (passed to
                ``FeatureExtractorGigaAM``) plus ``pooling`` and ``head``
                nodes that ``hydra.utils.instantiate`` can build.
        """
        super().__init__()

        # The feature extractor is constructed directly, not through hydra.
        self.feature_extractor = FeatureExtractorGigaAM(config.feature_extractor.cfg)
        self.pooling = hydra.utils.instantiate(config.pooling)
        self.head = hydra.utils.instantiate(config.head)

    def forward(self, batch: AudioBatch):
        """Compute logits for a batch.

        Args:
            batch: Batch providing ``wavs``, ``wav_lengths`` and ``emotions``.

        Returns:
            A ``(logit, None)`` tuple; the original note gives the logits'
            shape as ``[B, NUM_CLASSES]``.
        """
        # Extractor output is described by the original note as ([B, EMB_DIM, T], [B]).
        encoded, lengths = self.feature_extractor(batch.wavs, batch.wav_lengths)
        frames = encoded.transpose(1, 2)  # [B, T, EMB_DIM]

        pooled = self.pooling(frames, lengths)  # [B, NEW_EMB_DIM]

        # Heads that opt in via ``use_labels_when_train`` also receive the labels.
        # NOTE(review): this does not check ``self.training``, so such heads get
        # labels at evaluation time as well — confirm that is intended.
        wants_labels = getattr(self.head, "use_labels_when_train", False) is True
        logit = self.head(pooled, batch.emotions) if wants_labels else self.head(pooled)
        return logit, None  # [B, NUM_CLASSES]