Spaces:
Build error
Build error
| from typing import Dict, List, Tuple, Union | |
| import hydra | |
| import omegaconf | |
| import torch | |
| from torch import Tensor, nn | |
| import omegaconf | |
| from gigaam.model import GigaAM, GigaAMEmo | |
| from gigaam.preprocess import SAMPLE_RATE, load_audio | |
| from gigaam.utils import onnx_converter | |
| import torch.nn.functional as F | |
| from dataclasses import dataclass | |
@dataclass(eq=False)
class AudioBatch:
    """Container for one batch of audio passed to ``EmotionModel.forward``.

    The ``dataclass`` decorator generates an ``__init__`` so a batch can be
    built as ``AudioBatch(wavs, wav_lengths, emotions)``. Every field defaults
    to ``None`` so that creating an empty instance and assigning the
    attributes afterwards keeps working, and so that batches without labels
    can simply omit ``emotions``. ``eq=False`` keeps identity-based equality
    and hashing instead of generating an element-wise tensor comparison.
    """

    # Batched waveforms (presumably [B, WAV_LEN] -- TODO confirm with the data loader).
    wavs: Union[torch.Tensor, None] = None
    # Per-item lengths matching ``wavs`` (presumably [B] -- TODO confirm).
    wav_lengths: Union[torch.Tensor, None] = None
    # Class labels; only read by the model when the head consumes labels.
    emotions: Union[torch.LongTensor, None] = None
class FeatureExtractorGigaAM(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``GigaAM`` model.

    The wrapped model is constructed directly from ``cfg``; no checkpoint is
    loaded by this class.
    """

    def __init__(self, cfg):
        """Build the wrapped ``GigaAM`` instance from ``cfg``."""
        super().__init__()
        self.fe = GigaAM(cfg)

    def forward(self, features, feature_lengths):
        """Call the wrapped model and return its result unchanged.

        NOTE(review): the original author's notes describe the inputs as raw
        waveforms with their lengths ([B, WAV_LEN]) and the output as
        ``([B, EMB_DIM, T], [B])`` -- confirm against the GigaAM API.
        """
        backbone = self.fe
        return backbone(features, feature_lengths)
class MaxPooling(nn.Module):
    """Max-pool a padded feature sequence over time, ignoring padded frames."""

    def __init__(self):
        super().__init__()

    def forward(self, features, feature_lengths):
        """Return the per-channel maximum over the valid frames of each item.

        Args:
            features: tensor laid out as [B, T, D].
            feature_lengths: number of valid frames per item, [B]. ``None``
                disables masking and pools over all ``T`` frames.

        Returns:
            Tensor of shape [B, D].
        """
        if feature_lengths is not None:
            lengths = feature_lengths.to(features.device)
            frame_idx = torch.arange(features.shape[1], device=features.device)
            valid = frame_idx.unsqueeze(0) < lengths.unsqueeze(1)  # [B, T]
            # An item without any valid frame would pool to -inf; keep all of
            # its frames instead so the output stays finite.
            valid = valid | (lengths <= 0).unsqueeze(1)
            # Padded frames must never win the max, whatever value they hold.
            features = features.masked_fill(~valid.unsqueeze(-1), float("-inf"))

        features = features.transpose(1, 2)  # [B, D, T]
        pooled = F.max_pool1d(features, kernel_size=features.shape[-1])  # [B, D, 1]
        return pooled.squeeze(-1)  # [B, D]
class AttentionStatsPooling(nn.Module):
    """Attention-weighted mean and standard-deviation pooling over time.

    A small MLP scores every frame; the scores are soft-maxed over the valid
    frames of each item and used as weights for the mean and the standard
    deviation of the features.
    """

    def __init__(self, input_dim, attn_dim=128):
        """Create the frame-scoring MLP (``input_dim`` -> ``attn_dim`` -> 1)."""
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(input_dim, attn_dim),
            nn.Tanh(),
            nn.Linear(attn_dim, 1),
        )

    def forward(self, x, lens):  # x: [B, T, D], lens: [B]
        """Pool ``x`` over time using only the first ``lens[b]`` frames.

        Args:
            x: features, [B, T, D].
            lens: number of valid frames per item, [B].

        Returns:
            Concatenated weighted mean and std, [B, 2*D]. An item with no
            valid frame gets all-zero weights, i.e. a zero mean and a std of
            ``sqrt(1e-9)``, instead of NaN.
        """
        _, num_frames, _ = x.size()
        device = x.device

        attn_scores = self.attn(x)  # [B, T, 1]

        # Validity mask: True for real frames, False for padding.
        positions = torch.arange(num_frames, device=device).unsqueeze(0)
        mask = positions < lens.to(device).unsqueeze(1)  # [B, T]
        mask = mask.unsqueeze(-1)  # [B, T, 1]

        # Mask out padding with the most negative finite value rather than
        # -inf: padded frames still get zero weight next to any valid frame,
        # while a fully padded row yields a uniform softmax rather than NaN.
        fill_value = torch.finfo(attn_scores.dtype).min
        attn_scores = attn_scores.masked_fill(~mask, fill_value)

        # Softmax over the time axis.
        attn_weights = F.softmax(attn_scores, dim=1)  # [B, T, 1]
        # Zero the padded weights explicitly; this is what clears the uniform
        # weights of a fully padded row.
        attn_weights = attn_weights * mask

        # Weighted mean and standard deviation.
        mean = torch.sum(attn_weights * x, dim=1)  # [B, D]
        variance = torch.sum(attn_weights * (x - mean.unsqueeze(1)) ** 2, dim=1)
        std = torch.sqrt(variance + 1e-9)
        return torch.cat([mean, std], dim=1)  # [B, 2*D]
class SelfAttentionWithStatsPooling(nn.Module):
    """Self-attention with a residual LayerNorm, then stats pooling and a projection."""

    def __init__(self, embed_dim=768, num_heads=4, attn_dim=128, out_dim=256):
        """Create the attention, normalisation, pooling and output layers."""
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.pool = AttentionStatsPooling(input_dim=embed_dim, attn_dim=attn_dim)
        # The pooling layer concatenates mean and std, hence 2 * embed_dim.
        self.out = nn.Linear(2 * embed_dim, out_dim)

    def forward(self, x, lens):  # x: [B, T, D], lens: [B]
        """Attend over the valid frames, pool over time and project.

        Returns a tensor of shape [B, out_dim].
        """
        _, seq_len, _ = x.size()

        # Boolean [B, T] mask passed as key_padding_mask: True marks padding.
        frame_idx = torch.arange(seq_len, device=x.device)
        is_padding = frame_idx.unsqueeze(0) >= lens.unsqueeze(1)

        attended, _ = self.mha(x, x, x, key_padding_mask=is_padding)  # [B, T, D]
        normed = self.norm(attended + x)  # residual connection
        stats = self.pool(normed, lens)  # [B, 2*D]
        return self.out(stats)
class EmotionModel(nn.Module):
    """Emotion classifier: GigaAM feature extractor -> pooling -> head.

    The pooling and head modules are instantiated by Hydra from
    ``config.pooling`` and ``config.head``.
    """

    def __init__(self, config):
        """Build the three stages from ``config``."""
        super().__init__()
        self.feature_extractor = FeatureExtractorGigaAM(config.feature_extractor.cfg)
        self.pooling = hydra.utils.instantiate(config.pooling)
        self.head = hydra.utils.instantiate(config.head)

    def forward(self, batch: AudioBatch):
        """Run the model on ``batch`` and return ``(logits, None)``.

        The labels in ``batch.emotions`` are forwarded to the head only when
        the head exposes ``use_labels_when_train`` set to exactly ``True``.
        NOTE(review): this check does not look at ``self.training``, so such a
        head receives the labels in eval mode too -- confirm that is intended.
        """
        # The extractor output is documented upstream as [B, EMB_DIM, T], [B].
        feats, lengths = self.feature_extractor(batch.wavs, batch.wav_lengths)
        # Swap the last two axes so time comes before the embedding dimension.
        pooled = self.pooling(feats.transpose(1, 2), lengths)

        wants_labels = getattr(self.head, "use_labels_when_train", False) is True
        head_inputs = (pooled, batch.emotions) if wants_labels else (pooled,)
        logit = self.head(*head_inputs)
        return logit, None