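# NOTE: the following lines are web file-viewer residue captured together
# with the source; they are not part of the program.
# Ilia
# create spaces
# c485839
# Raw
# History Blame Contribute Delete
# 4.43 kB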
from typing import Dict, List, Tuple, Union
import hydra
import omegaconf
import torch
from torch import Tensor, nn
import omegaconf
from gigaam.model import GigaAM, GigaAMEmo
from gigaam.preprocess import SAMPLE_RATE, load_audio
from gigaam.utils import onnx_converter
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class AudioBatch:
    """Container for one batch passed to ``EmotionModel.forward``.

    ``EmotionModel.forward`` feeds ``wavs`` and ``wav_lengths`` to the feature
    extractor and, when the head asks for labels, passes ``emotions`` to the
    head.
    """

    # Batched audio input; presumably [B, WAV_LEN] raw waveforms -- TODO confirm.
    wavs: torch.Tensor
    # Per-item lengths matching ``wavs``; presumably shape [B] -- TODO confirm.
    wav_lengths: torch.Tensor
    # Integer emotion labels, one per item (assumed shape [B]).
    emotions: torch.LongTensor
class FeatureExtractorGigaAM(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``GigaAM`` model built from ``cfg``.

    The wrapped model is stored as ``self.fe`` and ``forward`` simply
    delegates to it.
    """

    def __init__(self, cfg):
        """Build the wrapped ``GigaAM`` model from ``cfg``.

        NOTE: no checkpoint is loaded here; the earlier ``torch.load`` /
        ``load_state_dict`` calls were left commented out by the author.
        """
        super().__init__()
        self.fe = GigaAM(cfg)

    def forward(self, features, feature_lengths):
        """Run the wrapped model and return its output unchanged.

        Per the original author's notes the inputs are raw waveforms with
        their lengths, and the result is a pair ``([B, EMB_DIM, T], [B])``
        -- TODO confirm against the GigaAM implementation.
        """
        encoded = self.fe(features, feature_lengths)
        return encoded
class MaxPooling(nn.Module):
    """Max-pool a padded batch of frame features over the time axis.

    Padded frames are excluded from the maximum, so the pooled vector of an
    item does not depend on how much padding the rest of the batch forced
    onto it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features, feature_lengths=None):
        """Return the per-dimension maximum over valid frames.

        Args:
            features: tensor of shape [B, T, D].
            feature_lengths: tensor of shape [B] with the number of valid
                frames per item. ``None`` pools over all T frames (the
                previous, unmasked behaviour).

        Returns:
            Tensor of shape [B, D]. An item whose length is 0 has no valid
            frame and therefore yields ``-inf`` in every dimension.
        """
        if feature_lengths is not None:
            num_frames = features.size(1)
            frame_index = torch.arange(num_frames, device=features.device)
            lengths = feature_lengths.to(features.device)
            # [B, T]: True where the frame lies inside the item's real length.
            valid = frame_index.unsqueeze(0) < lengths.unsqueeze(1)
            # -inf can never win the max, so padding is effectively ignored.
            features = features.masked_fill(~valid.unsqueeze(-1), float("-inf"))
        return features.max(dim=1).values  # [B, D]
class AttentionStatsPooling(nn.Module):
    """Attention-weighted mean and standard deviation over the time axis.

    A small MLP scores every frame; the scores are soft-maxed over the valid
    frames of each item and used as weights for a mean and a standard
    deviation, which are concatenated.
    """

    def __init__(self, input_dim, attn_dim=128):
        """Create the frame-scoring MLP (``input_dim -> attn_dim -> 1``)."""
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(input_dim, attn_dim),
            nn.Tanh(),
            nn.Linear(attn_dim, 1),
        )

    def forward(self, x, lens):
        """Pool ``x`` into per-item statistics.

        Args:
            x: frame features of shape [B, T, D].
            lens: number of valid frames per item, shape [B].

        Returns:
            Tensor of shape [B, 2 * D]: weighted mean followed by weighted
            standard deviation. An item with ``lens == 0`` gets all-zero
            weights, i.e. a zero mean and ``sqrt(1e-9)`` std, instead of NaN.
        """
        _, num_frames, _ = x.size()

        scores = self.attn(x)  # [B, T, 1]

        # Validity mask: True for real frames, False for padding.
        frame_index = torch.arange(num_frames, device=x.device)
        valid = frame_index.unsqueeze(0) < lens.to(x.device).unsqueeze(1)  # [B, T]
        valid = valid.unsqueeze(-1)  # [B, T, 1]

        # Mask padding out-of-place (the MLP output is left untouched) with
        # the dtype's finite minimum rather than -inf: padded frames still
        # get exactly zero weight, but a fully padded row produces a uniform
        # softmax instead of NaN values and NaN gradients.
        scores = scores.masked_fill(~valid, torch.finfo(scores.dtype).min)

        # Softmax over time, then force padded weights to exactly zero.
        weights = F.softmax(scores, dim=1)  # [B, T, 1]
        weights = weights.masked_fill(~valid, 0.0)

        # Weighted mean and standard deviation over time.
        mean = torch.sum(weights * x, dim=1)  # [B, D]
        variance = torch.sum(weights * (x - mean.unsqueeze(1)) ** 2, dim=1)
        std = torch.sqrt(variance + 1e-9)  # epsilon keeps sqrt differentiable at 0

        return torch.cat([mean, std], dim=1)  # [B, 2*D]
class SelfAttentionWithStatsPooling(nn.Module):
    """Self-attention block followed by attentive statistics pooling.

    Pipeline: multi-head self-attention over the frames, residual connection
    plus LayerNorm, ``AttentionStatsPooling`` over time, then a linear
    projection to ``out_dim``.
    """

    def __init__(self, embed_dim=768, num_heads=4, attn_dim=128, out_dim=256):
        """Create the attention, normalisation, pooling and projection layers."""
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.pool = AttentionStatsPooling(input_dim=embed_dim, attn_dim=attn_dim)
        # The pooling output is mean and std concatenated, hence 2 * embed_dim.
        self.out = nn.Linear(2 * embed_dim, out_dim)

    def forward(self, x, lens):
        """Map frame features ``x`` [B, T, D] with lengths ``lens`` [B] to [B, out_dim]."""
        _, num_frames, _ = x.size()

        # Key-padding mask for the attention layer: True marks padded frames.
        frame_index = torch.arange(num_frames, device=x.device)
        is_padding = frame_index.unsqueeze(0) >= lens.unsqueeze(1)  # [B, T]

        attended, _ = self.mha(x, x, x, key_padding_mask=is_padding)  # [B, T, D]
        normed = self.norm(attended + x)  # residual connection + LayerNorm

        stats = self.pool(normed, lens)  # [B, 2*D]
        return self.out(stats)  # [B, out_dim]
class EmotionModel(nn.Module):
    """Emotion classifier: GigaAM feature extractor -> pooling -> head.

    The pooling and head modules are created with ``hydra.utils.instantiate``
    from ``config.pooling`` and ``config.head``; the feature extractor is
    built directly from ``config.feature_extractor.cfg``.
    """

    def __init__(self, config):
        """Build the three sub-modules from ``config``."""
        super().__init__()
        extractor_cfg = config.feature_extractor.cfg
        self.feature_extractor = FeatureExtractorGigaAM(extractor_cfg)
        self.pooling = hydra.utils.instantiate(config.pooling)
        self.head = hydra.utils.instantiate(config.head)

    def forward(self, batch: AudioBatch):
        """Compute logits for ``batch``.

        Returns:
            A ``(logit, None)`` tuple; the second element is always ``None``.
            Per the original comments ``logit`` is [B, NUM_CLASSES].
        """
        # Extractor output is noted by the author as [B, EMB_DIM, T], [B].
        feats, lengths = self.feature_extractor(batch.wavs, batch.wav_lengths)
        # The pooling layers take time-major features: [B, T, EMB_DIM].
        pooled = self.pooling(feats.transpose(1, 2), lengths)

        # Heads that set ``use_labels_when_train = True`` also receive the
        # target labels; every other head gets the pooled embedding only.
        wants_labels = getattr(self.head, "use_labels_when_train", False) is True
        head_inputs = (pooled, batch.emotions) if wants_labels else (pooled,)
        return self.head(*head_inputs), None