import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from conformer.conformer.model_def import Conformer


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding module."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcasts over the batch dimension
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]
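

# Quick shape sanity check for PositionalEncoding (illustrative only; the batch size,
# sequence length, and d_model below are arbitrary, not values taken from this model).
_pe = PositionalEncoding(d_model=128)
_x = torch.zeros(2, 50, 128)              # (batch, seq_len, d_model)
assert _pe(_x).shape == (2, 50, 128)      # the (1, max_len, d_model) buffer broadcasts over the batch

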
class CrossAttention(nn.Module):
    """Cross-attention module: one cross-attention layer followed by one self-attention layer."""

    def __init__(self, query_dim, key_dim, heads=4, dropout=0.1):
        super(CrossAttention, self).__init__()

        # Cross-attention sub-block (pre-norm) with its feed-forward network
        self.cross_attn = nn.MultiheadAttention(query_dim, heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)
        self.ffn1 = nn.Sequential(
            nn.Linear(query_dim, query_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(query_dim * 4, query_dim),
            nn.Dropout(dropout)
        )

        # Self-attention sub-block (pre-norm) with its feed-forward network
        self.self_attn = nn.MultiheadAttention(query_dim, heads, dropout=dropout, batch_first=True)
        self.norm3 = nn.LayerNorm(query_dim)
        self.norm4 = nn.LayerNorm(query_dim)
        self.ffn2 = nn.Sequential(
            nn.Linear(query_dim, query_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(query_dim * 4, query_dim),
            nn.Dropout(dropout)
        )

        # Project keys and values from key_dim into the query dimension
        self.proj_key = nn.Linear(key_dim, query_dim)
        self.proj_value = nn.Linear(key_dim, query_dim)

    def forward(self, query, key, value, key_padding_mask=None):
        key_proj = self.proj_key(key)
        value_proj = self.proj_value(value)

        # Cross-attention with residual connection, then feed-forward with residual
        query_norm = self.norm1(query)
        cross_attn_output, _ = self.cross_attn(query_norm, key_proj, value_proj,
                                               key_padding_mask=key_padding_mask)
        query = query + cross_attn_output
        query = query + self.ffn1(self.norm2(query))

        # Self-attention with residual connection, then feed-forward with residual
        query_norm = self.norm3(query)
        self_attn_output, _ = self.self_attn(query_norm, query_norm, query_norm)
        query = query + self_attn_output
        query = query + self.ffn2(self.norm4(query))

        return query
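

# Quick shape sanity check for CrossAttention (illustrative only; the dimensions and
# sequence lengths are arbitrary). Keys/values may differ from the query in both length
# and feature size, since they are projected to query_dim internally.
_attn = CrossAttention(query_dim=128, key_dim=64, heads=4)
_q = torch.randn(2, 10, 128)              # e.g. text tokens as queries
_kv = torch.randn(2, 30, 64)              # e.g. audio frames as keys/values
assert _attn(_q, _kv, _kv).shape == (2, 10, 128)

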
class MMKWS2(nn.Module):
    def __init__(
        self,
        text_dim=64,          # currently unused; anchor text is embedded via nn.Embedding below
        audio_dim=1024,
        hidden_dim=128,
        dim=80,
        encoder_dim=128,
        num_encoder_layers=6,
        num_attention_heads=4,
        dropout=0.1,
        num_transformer_layers=2
    ):
        super(MMKWS2, self).__init__()

        # Project the anchor (enrollment) audio embedding into the shared hidden dimension
        self.audio_proj = nn.Linear(audio_dim, hidden_dim)

        # Embed the anchor text token ids (vocabulary size 402)
        self.text_proj = nn.Embedding(num_embeddings=402, embedding_dim=hidden_dim)

        self.pos_enc = PositionalEncoding(hidden_dim)

        # Fuse anchor text (queries) with anchor audio (keys/values)
        self.cross_attn = CrossAttention(hidden_dim, hidden_dim, heads=num_attention_heads, dropout=dropout)

        # Conformer encoder for the comparison utterance features
        self.conformer = Conformer(
            input_dim=dim,
            encoder_dim=encoder_dim,
            num_encoder_layers=num_encoder_layers,
            num_attention_heads=num_attention_heads,
        )

        self.feat_proj = nn.Linear(encoder_dim, hidden_dim)

        # Transformer encoder over the concatenated anchor/comparison sequence
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_attention_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_transformer_layers
        )

        # Utterance-level head: bidirectional GRU followed by an MLP classifier
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

        # Token-level head over the text positions
        self.seq_classifier = nn.Linear(hidden_dim, 1)

    def forward(self, anchor_wave_embedding, anchor_text_embedding, compare_wave, compare_lengths):
        batch_size = anchor_wave_embedding.size(0)

        # Anchor text branch
        text_feat = self.text_proj(anchor_text_embedding)
        text_feat = self.pos_enc(text_feat)

        # Anchor audio branch
        audio_feat = self.audio_proj(anchor_wave_embedding)
        audio_feat = self.pos_enc(audio_feat)

        # Cross-modal fusion: text queries attend over the anchor audio
        fused_feat = self.cross_attn(text_feat, audio_feat, audio_feat, key_padding_mask=None)

        # Encode the comparison utterance with the Conformer
        compare_feat = self.conformer(compare_wave, compare_lengths)[0]
        compare_feat = self.feat_proj(compare_feat)
        compare_feat = self.pos_enc(compare_feat)

        # Jointly encode the concatenated anchor and comparison sequences
        text_len = fused_feat.size(1)
        combined_feat = torch.cat([fused_feat, compare_feat], dim=1)
        combined_feat = self.transformer_encoder(combined_feat)

        gru_out, _ = self.gru(combined_feat)

        # Utterance-level logit from the last GRU time step
        global_feat = gru_out[:, -1, :]
        logits = self.classifier(global_feat).squeeze(-1)

        # Token-level logits over the text positions
        seq_logits = self.seq_classifier(combined_feat[:, :text_len, :]).squeeze(-1)
        return logits, seq_logits

    def enrollment(self, anchor_wave_embedding, anchor_text_embedding):
        """Offline stage: fuse anchor text and anchor audio into an enrollment representation."""
        batch_size = anchor_wave_embedding.size(0)

        text_feat = self.text_proj(anchor_text_embedding)
        text_feat = self.pos_enc(text_feat)

        audio_feat = self.audio_proj(anchor_wave_embedding)
        audio_feat = self.pos_enc(audio_feat)

        fused_feat = self.cross_attn(text_feat, audio_feat, audio_feat, key_padding_mask=None)
        return fused_feat

    def verification(self, fused_feat, compare_wave, compare_lengths):
        """Online stage: score a comparison utterance against a precomputed enrollment representation."""
        batch_size = fused_feat.size(0)

        compare_feat = self.conformer(compare_wave, compare_lengths)[0]
        compare_feat = self.feat_proj(compare_feat)
        compare_feat = self.pos_enc(compare_feat)

        text_len = fused_feat.size(1)
        combined_feat = torch.cat([fused_feat, compare_feat], dim=1)
        combined_feat = self.transformer_encoder(combined_feat)

        gru_out, _ = self.gru(combined_feat)

        global_feat = gru_out[:, -1, :]
        logits = self.classifier(global_feat).squeeze(-1)
        return logits


def count_verification_params(model):
    """Count the parameters of the modules used in the verification (online) stage."""
    modules = [
        model.conformer,
        model.feat_proj,
        model.transformer_encoder,
        model.gru,
        model.classifier
    ]
    total = 0
    for m in modules:
        total += sum(p.numel() for p in m.parameters())
    return total


model = MMKWS2(
    text_dim=64,
    audio_dim=1024,
    hidden_dim=128,
    dim=80,
    encoder_dim=128,
    num_encoder_layers=6,
    num_attention_heads=4,
    dropout=0.1,
    num_transformer_layers=2
)
print(f"Verification-stage parameter count: {count_verification_params(model):,}")
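

# A minimal smoke test of the two usage paths (illustrative only: the batch size, the
# sequence lengths, and the assumption that compare_wave is 80-dim frame-level features
# matching the Conformer's input_dim are not prescribed by the model itself).
if __name__ == "__main__":
    anchor_wave_embedding = torch.randn(2, 20, 1024)          # (batch, anchor_frames, audio_dim)
    anchor_text_embedding = torch.randint(0, 402, (2, 10))    # (batch, text_len) token ids
    compare_wave = torch.randn(2, 100, 80)                    # (batch, frames, dim)
    compare_lengths = torch.tensor([100, 80])                 # valid frame counts per utterance

    # Training-style pass: utterance-level and token-level logits in one call.
    logits, seq_logits = model(anchor_wave_embedding, anchor_text_embedding,
                               compare_wave, compare_lengths)
    print(logits.shape, seq_logits.shape)                     # expected: (2,) and (2, 10)

    # Deployment-style pass: enroll once, then verify new audio against the stored result.
    fused = model.enrollment(anchor_wave_embedding, anchor_text_embedding)
    print(model.verification(fused, compare_wave, compare_lengths).shape)   # expected: (2,)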