"""Emotion model: GigaAM feature extractor followed by a pooling layer and a head."""

from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import hydra
import omegaconf
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from gigaam.model import GigaAM, GigaAMEmo
from gigaam.preprocess import SAMPLE_RATE, load_audio
from gigaam.utils import onnx_converter


def _valid_frame_mask(lengths: Tensor, num_frames: int, device: torch.device) -> Tensor:
    """Build a ``[B, T]`` boolean mask that is True on non-padded frames.

    Args:
        lengths: ``[B]`` number of valid frames per sample (any device).
        num_frames: padded time dimension ``T``.
        device: device on which the mask is created.

    Returns:
        Bool tensor ``[B, T]``. A row whose length is 0 is returned as fully
        valid, so that a later softmax / max never sees an all-masked row
        (which would produce NaN or -inf and poison the whole batch).
    """
    positions = torch.arange(num_frames, device=device).unsqueeze(0)  # [1, T]
    # ``lengths`` may live on another device (e.g. CPU); move it before comparing.
    valid = positions < lengths.to(device).unsqueeze(1)  # [B, T]
    empty_rows = ~valid.any(dim=1, keepdim=True)  # [B, 1]
    return valid | empty_rows


@dataclass
class AudioBatch:
    """Container for one batch consumed by ``EmotionModel.forward``."""

    # Passed to the feature extractor as its first argument.
    wavs: torch.Tensor
    # Passed to the feature extractor as its second argument.
    wav_lengths: torch.Tensor
    # Labels; forwarded to the head only when it sets ``use_labels_when_train``.
    emotions: torch.LongTensor


class FeatureExtractorGigaAM(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``GigaAM`` model built from a config."""

    def __init__(self, cfg):
        """Build the wrapped ``GigaAM`` model from ``cfg``.

        NOTE: no checkpoint is loaded here; the model is constructed from the
        config only.
        """
        super().__init__()
        self.fe = GigaAM(cfg)

    def forward(self, features, feature_lengths):
        """Run the wrapped model on ``features`` / ``feature_lengths``.

        Per the original annotations the inputs are raw waveforms
        ``[B, WAV_LEN]`` with their lengths, and the result is a pair
        ``([B, EMB_DIM, T], [B])`` -- TODO confirm against ``GigaAM.forward``.
        """
        return self.fe(features, feature_lengths)


class MaxPooling(nn.Module):
    """Max-pool frame features over time, ignoring padded frames."""

    def __init__(self):
        super().__init__()

    def forward(self, features, feature_lengths):
        """Pool ``features`` over the time axis.

        Args:
            features: ``[B, T, D]`` frame features.
            feature_lengths: ``[B]`` valid frame counts, or ``None`` to pool
                over every frame.

        Returns:
            ``[B, D]`` per-sample maxima.
        """
        if feature_lengths is not None:
            # Padded frames must not take part in the max: fill them with -inf
            # so any real frame wins.
            valid = _valid_frame_mask(feature_lengths, features.size(1), features.device)
            features = features.masked_fill(~valid.unsqueeze(-1), float("-inf"))

        channels_first = features.transpose(1, 2)  # [B, D, T]
        pooled = F.max_pool1d(channels_first, kernel_size=channels_first.shape[-1])  # [B, D, 1]
        return pooled.squeeze(-1)  # [B, D]


class AttentionStatsPooling(nn.Module):
    """Attention-weighted mean and standard deviation over time."""

    def __init__(self, input_dim, attn_dim=128):
        """Create the scoring MLP ``Linear -> Tanh -> Linear(1)``."""
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(input_dim, attn_dim),
            nn.Tanh(),
            nn.Linear(attn_dim, 1),
        )

    def forward(self, x, lens):
        """Pool ``x`` into concatenated weighted mean and std.

        Args:
            x: ``[B, T, D]`` frame features.
            lens: ``[B]`` valid frame counts.

        Returns:
            ``[B, 2 * D]`` tensor ``cat([mean, std], dim=1)``.
        """
        valid = _valid_frame_mask(lens, x.size(1), x.device).unsqueeze(-1)  # [B, T, 1]

        # Out-of-place masking: padded frames get -inf so softmax assigns them
        # exactly zero weight.
        scores = self.attn(x).masked_fill(~valid, float("-inf"))  # [B, T, 1]
        weights = F.softmax(scores, dim=1)  # [B, T, 1], sums to 1 over valid frames

        mean = torch.sum(weights * x, dim=1)  # [B, D]
        centered_sq = (x - mean.unsqueeze(1)) ** 2
        # Small epsilon keeps sqrt differentiable when the variance is zero.
        std = torch.sqrt(torch.sum(weights * centered_sq, dim=1) + 1e-9)  # [B, D]

        return torch.cat([mean, std], dim=1)  # [B, 2*D]


class SelfAttentionWithStatsPooling(nn.Module):
    """Self-attention block (residual + LayerNorm) followed by stats pooling."""

    def __init__(self, embed_dim=768, num_heads=4, attn_dim=128, out_dim=256):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, batch_first=True
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.pool = AttentionStatsPooling(input_dim=embed_dim, attn_dim=attn_dim)
        self.out = nn.Linear(2 * embed_dim, out_dim)

    def forward(self, x, lens):
        """Apply self-attention, pool over time and project.

        Args:
            x: ``[B, T, D]`` frame features.
            lens: ``[B]`` valid frame counts.

        Returns:
            ``[B, out_dim]`` pooled embedding.
        """
        # ``key_padding_mask`` expects True on positions to ignore (padding).
        padding = ~_valid_frame_mask(lens, x.size(1), x.device)  # [B, T]

        # The attention weights are never used, so do not compute them.
        attn_out, _ = self.mha(x, x, x, key_padding_mask=padding, need_weights=False)  # [B, T, D]
        attn_out = self.norm(attn_out + x)  # residual connection + LayerNorm

        pooled = self.pool(attn_out, lens)  # [B, 2*D]
        return self.out(pooled)  # [B, out_dim]


class EmotionModel(nn.Module):
    """Feature extractor -> pooling -> classification head."""

    def __init__(self, config):
        """Build sub-modules from ``config``.

        The feature extractor is constructed directly from
        ``config.feature_extractor.cfg``; pooling and head are instantiated
        through ``hydra.utils.instantiate``.
        """
        super().__init__()
        self.feature_extractor = FeatureExtractorGigaAM(config.feature_extractor.cfg)
        self.pooling = hydra.utils.instantiate(config.pooling)
        self.head = hydra.utils.instantiate(config.head)

    def forward(self, batch: AudioBatch):
        """Compute logits for a batch.

        Returns:
            ``(logit, None)`` where ``logit`` is the head output
            (``[B, NUM_CLASSES]`` per the original annotation).
        """
        feats, lengths = self.feature_extractor(batch.wavs, batch.wav_lengths)  # [B, EMB_DIM, T], [B]
        feats = feats.transpose(1, 2)  # [B, T, EMB_DIM]

        pooled = self.pooling(feats, lengths)  # [B, NEW_EMB_DIM]

        # NOTE(review): labels are passed whenever the flag is set, regardless
        # of ``self.training`` -- confirm the head handles eval mode itself.
        if getattr(self.head, "use_labels_when_train", False) is True:
            logit = self.head(pooled, batch.emotions)
        else:
            logit = self.head(pooled)

        return logit, None