Spaces:
Build error
Build error
File size: 4,428 Bytes
c485839 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
from typing import Dict, List, Tuple, Union
import hydra
import omegaconf
import torch
from torch import Tensor, nn
import omegaconf
from gigaam.model import GigaAM, GigaAMEmo
from gigaam.preprocess import SAMPLE_RATE, load_audio
from gigaam.utils import onnx_converter
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class AudioBatch:
    """Container for one collated batch of audio examples.

    Fields are stored exactly as given; no validation or conversion is done.
    """

    # Batched waveforms, as fed to the feature extractor.
    wavs: Tensor
    # Per-example waveform lengths, paired with ``wavs``.
    wav_lengths: Tensor
    # Integer emotion labels; optionally forwarded to the classification head.
    emotions: torch.LongTensor
class FeatureExtractorGigaAM(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``GigaAM`` model built from ``cfg``.

    The wrapped model lives in ``self.fe``. That attribute name determines the
    parameter names in ``state_dict`` (``fe.*``), so it must not be renamed.
    This wrapper does not load any checkpoint itself.
    """

    def __init__(self, cfg):
        super(FeatureExtractorGigaAM, self).__init__()
        self.fe = GigaAM(cfg)

    def forward(self, features, feature_lengths):
        """Call the wrapped GigaAM model and return its output unchanged.

        The caller passes raw waveforms and their per-example lengths. The
        author's notes say the output is ``([B, EMB_DIM, T], [B])``; that
        contract is defined by ``GigaAM``. NOTE(review): confirm against gigaam.
        """
        encoder_output = self.fe(features, feature_lengths)
        return encoder_output
class MaxPooling(nn.Module):
    """Length-aware temporal max pooling.

    Reduces a padded batch of frame embeddings ``[B, T, D]`` to one vector per
    example ``[B, D]`` by taking the per-channel maximum over the valid
    (non-padded) frames only.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features, feature_lengths=None):
        """Max-pool ``features`` over time, ignoring padded frames.

        Args:
            features: ``[B, T, D]`` frame embeddings.
            feature_lengths: ``[B]`` number of valid frames per example, or
                ``None`` to treat every frame as valid.

        Returns:
            ``[B, D]`` tensor. Rows whose length is 0 have no valid frames
            and are returned as zeros.
        """
        if feature_lengths is None:
            # No lengths supplied: pool over all T frames.
            return features.max(dim=1).values

        num_frames = features.size(1)
        lengths = feature_lengths.to(features.device)
        valid = (
            torch.arange(num_frames, device=features.device).unsqueeze(0)
            < lengths.unsqueeze(1)
        ).unsqueeze(-1)  # [B, T, 1]

        # Padded frames must never win the max, so push them to -inf first.
        masked = features.masked_fill(~valid, float("-inf"))
        pooled = masked.max(dim=1).values  # [B, D]

        # A row with no valid frames would come out as -inf; report zeros.
        has_frames = (lengths > 0).unsqueeze(-1)  # [B, 1]
        return pooled.masked_fill(~has_frames, 0.0)
class AttentionStatsPooling(nn.Module):
    """Attentive statistics pooling over the time axis.

    A small MLP scores every frame, the scores are softmax-normalised over
    the valid (non-padded) frames, and the resulting weights are used to
    compute a weighted mean and weighted standard deviation per channel.
    """

    def __init__(self, input_dim, attn_dim=128):
        super().__init__()
        # Per-frame scorer: input_dim -> attn_dim -> 1.
        self.attn = nn.Sequential(
            nn.Linear(input_dim, attn_dim),
            nn.Tanh(),
            nn.Linear(attn_dim, 1),
        )

    def forward(self, x, lens):
        """Pool ``x`` into concatenated weighted mean and std.

        Args:
            x: ``[B, T, D]`` frame embeddings.
            lens: ``[B]`` number of valid frames per example.

        Returns:
            ``[B, 2 * D]`` tensor ``cat([mean, std], dim=1)``. A row with
            ``lens == 0`` yields a zero mean and ``sqrt(1e-9)`` std instead
            of NaN.
        """
        _, num_frames, _ = x.size()

        attn_scores = self.attn(x)  # [B, T, 1]

        # Valid-frame mask: [B, T] -> [B, T, 1].
        mask = (
            torch.arange(num_frames, device=x.device).unsqueeze(0)
            < lens.to(x.device).unsqueeze(1)
        ).unsqueeze(-1)

        # Mask padding with the most negative *finite* value rather than -inf.
        # For rows with at least one valid frame the padded weights still
        # underflow to exactly 0. A row with no valid frames gets finite
        # (uniform) weights instead of NaN, which keeps backward NaN-free too.
        attn_scores = attn_scores.masked_fill(~mask, torch.finfo(attn_scores.dtype).min)
        attn_weights = F.softmax(attn_scores, dim=1)  # [B, T, 1]
        # Explicitly zero padded weights (this also clears fully-masked rows).
        attn_weights = attn_weights.masked_fill(~mask, 0.0)

        # Weighted mean and standard deviation over time.
        mean = torch.sum(attn_weights * x, dim=1)  # [B, D]
        var = torch.sum(attn_weights * (x - mean.unsqueeze(1)) ** 2, dim=1)
        std = torch.sqrt(var + 1e-9)
        return torch.cat([mean, std], dim=1)  # [B, 2*D]
class SelfAttentionWithStatsPooling(nn.Module):
    """Self-attention refinement followed by attentive statistics pooling.

    The module runs, in order:
    1. One multi-head self-attention layer over frames ``[B, T, embed_dim]``.
    2. A residual connection plus LayerNorm.
    3. Pooling into ``[B, 2 * embed_dim]`` mean/std statistics.
    4. A linear projection to ``[B, out_dim]``.
    """

    def __init__(self, embed_dim=768, num_heads=4, attn_dim=128, out_dim=256):
        super().__init__()
        # Submodules are registered in this order; state_dict / parameters()
        # ordering follows it.
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.pool = AttentionStatsPooling(input_dim=embed_dim, attn_dim=attn_dim)
        self.out = nn.Linear(2 * embed_dim, out_dim)

    def forward(self, x, lens):
        """Encode ``x`` (``[B, T, D]``) with valid lengths ``lens`` (``[B]``) into ``[B, out_dim]``."""
        _, num_frames, _ = x.size()
        frame_idx = torch.arange(num_frames, device=x.device)
        # Boolean [B, T] mask; True marks padded positions, which is what
        # MultiheadAttention's key_padding_mask treats as "ignore".
        padding_mask = frame_idx[None, :] >= lens[:, None]

        context, _ = self.mha(x, x, x, key_padding_mask=padding_mask)  # [B, T, D]
        hidden = self.norm(x + context)
        stats = self.pool(hidden, lens)  # [B, 2*D]
        return self.out(stats)
class EmotionModel(nn.Module):
    """Emotion classifier: GigaAM feature extractor, then pooling, then head.

    The pooling module and the head are built with
    ``hydra.utils.instantiate`` from ``config.pooling`` and ``config.head``.
    """

    def __init__(self, config):
        super().__init__()
        # The feature extractor is constructed directly from its cfg rather
        # than through hydra.
        self.feature_extractor = FeatureExtractorGigaAM(config.feature_extractor.cfg)
        self.pooling = hydra.utils.instantiate(config.pooling)
        self.head = hydra.utils.instantiate(config.head)

    def forward(self, batch: AudioBatch):
        """Return ``(logits, None)`` for ``batch``.

        Labels (``batch.emotions``) are passed to the head only when it
        exposes ``use_labels_when_train`` set to exactly ``True``.
        """
        # Author's shape notes: extractor gives [B, EMB_DIM, T] plus [B] lengths.
        encoded, encoded_lengths = self.feature_extractor(batch.wavs, batch.wav_lengths)
        frames = encoded.transpose(1, 2)  # [B, T, EMB_DIM]
        utterance_emb = self.pooling(frames, encoded_lengths)

        wants_labels = getattr(self.head, "use_labels_when_train", False) is True
        if wants_labels:
            logits = self.head(utterance_emb, batch.emotions)
        else:
            logits = self.head(utterance_emb)
        return logits, None
|