# OpenKWS / model_pinyin.py
# (HuggingFace upload metadata: ZhiqiAi, "Upload 16 files", commit 693e27f verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import math
from conformer.conformer.model_def import Conformer
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a [1, max_len, d_model] table of sin/cos values once and
    adds the first `T` rows to every input of shape [B, T, d_model].
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # Geometric progression of inverse frequencies: 10000^(-2i/d_model).
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)  # even dims -> sin
        table[:, 1::2] = torch.cos(positions * inv_freq)  # odd dims  -> cos
        # Buffer (not a parameter): saved in state_dict, moves with .to(device).
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Return `x` plus the positional encodings for its sequence length."""
        return x + self.pe[:, :x.size(1)]
class CrossAttention(nn.Module):
"""交叉注意力模块 - 一层交叉注意力加一层自注意力"""
def __init__(self, query_dim, key_dim, heads=4, dropout=0.1):
super(CrossAttention, self).__init__()
# 交叉注意力层
self.cross_attn = nn.MultiheadAttention(query_dim, heads, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.ffn1 = nn.Sequential(
nn.Linear(query_dim, query_dim * 4),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(query_dim * 4, query_dim),
nn.Dropout(dropout)
)
# 自注意力层
self.self_attn = nn.MultiheadAttention(query_dim, heads, dropout=dropout, batch_first=True)
self.norm3 = nn.LayerNorm(query_dim)
self.norm4 = nn.LayerNorm(query_dim)
self.ffn2 = nn.Sequential(
nn.Linear(query_dim, query_dim * 4),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(query_dim * 4, query_dim),
nn.Dropout(dropout)
)
# 投影层
self.proj_key = nn.Linear(key_dim, query_dim)
self.proj_value = nn.Linear(key_dim, query_dim)
def forward(self, query, key, value, key_padding_mask=None):
# 有音频输入时,先进行交叉注意力
# 投影key和value到query的维度
key_proj = self.proj_key(key)
value_proj = self.proj_value(value)
# 交叉注意力
query_norm = self.norm1(query)
cross_attn_output, _ = self.cross_attn(query_norm, key_proj, value_proj,
key_padding_mask=key_padding_mask)
query = query + cross_attn_output
query = query + self.ffn1(self.norm2(query))
# 然后进行自注意力
query_norm = self.norm3(query)
self_attn_output, _ = self.self_attn(query_norm, query_norm, query_norm)
query = query + self_attn_output
query = query + self.ffn2(self.norm4(query))
return query
class MMKWS2(nn.Module):
    """Multi-modal keyword spotting model.

    An anchor keyword is given as a pinyin token-id sequence plus a
    pre-computed audio embedding; the two are fused with cross-attention
    (`enrollment`). A compare utterance (fbank features) is encoded by a
    Conformer, concatenated with the fused anchor features, passed through
    a Transformer encoder and a BiGRU, and classified (`verification`).
    `forward` = enrollment + verification and additionally returns
    per-token sequence logits for the anchor positions.
    """

    def __init__(
        self,
        # anchor branch
        text_dim=64,            # kept for interface compatibility (not used directly)
        audio_dim=1024,         # dimension of the pre-computed anchor audio embedding
        hidden_dim=128,         # shared model dimension
        # compare branch
        dim=80,                 # fbank feature dimension fed to the Conformer
        encoder_dim=128,
        num_encoder_layers=6,
        num_attention_heads=4,
        dropout=0.1,
        num_transformer_layers=2
    ):
        super().__init__()
        # Project the anchor audio embedding down to the shared hidden size.
        self.audio_proj = nn.Linear(audio_dim, hidden_dim)
        # Pinyin token embedding: 401 tokens + 1 padding id.
        self.text_proj = nn.Embedding(num_embeddings=402, embedding_dim=hidden_dim)
        # Positional encoding (shared across all three feature streams).
        self.pos_enc = PositionalEncoding(hidden_dim)
        # Text/audio fusion for the anchor.
        self.cross_attn = CrossAttention(hidden_dim, hidden_dim,
                                         heads=num_attention_heads, dropout=dropout)
        # Conformer encoder for the compare utterance's fbank features.
        self.conformer = Conformer(
            input_dim=dim,
            encoder_dim=encoder_dim,
            num_encoder_layers=num_encoder_layers,
            num_attention_heads=num_attention_heads,
        )
        # Map Conformer output dimension to hidden_dim.
        self.feat_proj = nn.Linear(encoder_dim, hidden_dim)
        # Joint Transformer over the concatenated [anchor; compare] sequence.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_attention_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_transformer_layers
        )
        # BiGRU + MLP for the utterance-level keyword decision.
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )
        # Per-anchor-token sequence label prediction.
        self.seq_classifier = nn.Linear(hidden_dim, 1)

    def forward(self, anchor_wave_embedding, anchor_text_embedding, compare_wave, compare_lengths):
        """Full pass: returns (utterance logits [B], anchor-token logits [B, S])."""
        fused_feat = self.enrollment(anchor_wave_embedding, anchor_text_embedding)
        logits, combined_feat, text_len = self._compare(fused_feat, compare_wave, compare_lengths)
        # Sequence labels are predicted only on the anchor (first `text_len`) positions.
        seq_logits = self.seq_classifier(combined_feat[:, :text_len, :]).squeeze(-1)  # [B, S]
        return logits, seq_logits

    def enrollment(self, anchor_wave_embedding, anchor_text_embedding):
        """Fuse anchor text ids and anchor audio embedding -> [B, S, hidden_dim]."""
        text_feat = self.pos_enc(self.text_proj(anchor_text_embedding))     # [B, S, hidden_dim]
        audio_feat = self.pos_enc(self.audio_proj(anchor_wave_embedding))   # [B, S', hidden_dim]
        # Cross-attention: text queries attend over the audio features.
        return self.cross_attn(text_feat, audio_feat, audio_feat, key_padding_mask=None)

    def verification(self, fused_feat, compare_wave, compare_lengths):
        """Score a compare utterance against an enrolled anchor -> logits [B]."""
        logits, _, _ = self._compare(fused_feat, compare_wave, compare_lengths)
        return logits

    def _compare(self, fused_feat, compare_wave, compare_lengths):
        """Shared comparison path; returns (logits [B], combined_feat [B, S+T, H], S)."""
        # Encode the compare utterance's fbank features.
        compare_feat = self.conformer(compare_wave, compare_lengths)[0]     # [B, T, encoder_dim]
        compare_feat = self.pos_enc(self.feat_proj(compare_feat))           # [B, T, hidden_dim]
        # Concatenate anchor and compare along time and jointly encode.
        text_len = fused_feat.size(1)
        combined_feat = torch.cat([fused_feat, compare_feat], dim=1)        # [B, S+T, hidden_dim]
        combined_feat = self.transformer_encoder(combined_feat)
        gru_out, _ = self.gru(combined_feat)                                # [B, S+T, hidden_dim*2]
        # NOTE(review): last time step as global summary ignores any padding
        # at the tail of the compare sequence — presumably inputs are trimmed
        # per batch; confirm against the data pipeline.
        logits = self.classifier(gru_out[:, -1, :]).squeeze(-1)             # [B]
        return logits, combined_feat, text_len
def count_verification_params(model):
    """Total parameter count of the modules on the verification path only.

    Counts the compare-side encoder and classifier stack; the anchor-side
    projections (audio/text/cross-attention) are excluded.
    """
    verification_modules = (
        model.conformer,
        model.feat_proj,
        model.transformer_encoder,
        model.gru,
        model.classifier,
    )
    return sum(p.numel() for module in verification_modules for p in module.parameters())
# Guarded so that importing this module does not build a full model and print;
# run the file directly for the parameter-count sanity check.
if __name__ == "__main__":
    model = MMKWS2(
        text_dim=64,
        audio_dim=1024,
        hidden_dim=128,
        dim=80,
        encoder_dim=128,
        num_encoder_layers=6,
        num_attention_heads=4,
        dropout=0.1,
        num_transformer_layers=2
    )
    # ~3.5M parameters on the verification path with the defaults above.
    print(f"verification相关参数量: {count_verification_params(model):,}")
# if __name__ == "__main__":
# # 创建一个示例batch
# batch_size = 2
# # 创建模拟数据
#     anchor_embedding = torch.randint(0, 402, (batch_size, 8))  # pinyin token ids (text_proj is an nn.Embedding, not a Linear)
# anchor_wave = torch.randn(batch_size, 256, 1024) # 音频嵌入
# compare_wave = torch.randn(batch_size, 45, 80) # Fbank特征
# # 创建长度信息
# anchor_lengths = torch.LongTensor([8, 6]) # 两个样本的实际长度
# compare_lengths = torch.LongTensor([45, 40])
# # 创建模型
# model = MMKWS2(
# text_dim=64,
# audio_dim=1024,
# hidden_dim=128,
# dim=80,
# encoder_dim=128,
# num_encoder_layers=6,
# num_attention_heads=4,
# dropout=0.1,
# num_transformer_layers=2
# )
# # 计算模型参数量
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"模型总参数量: {total_params:,}")
# print(f"可训练参数量: {trainable_params:,}")
# # 打印模型结构
# print("\n模型结构:")
# print(model)
# # 模型推理
# print("\n开始推理...")
# model.eval()
# with torch.no_grad():
# print(anchor_embedding.shape)
# print(anchor_wave.shape)
# print(compare_wave.shape)
# print(anchor_lengths.shape)
# print(compare_lengths.shape)
# # 完整输入推理
#         logits, seq_logits = model(  # forward returns 2 values, keyword names per signature
#             anchor_wave_embedding=anchor_wave,
#             anchor_text_embedding=anchor_embedding,
#             compare_wave=compare_wave,
#             compare_lengths=compare_lengths
#         )
# print("\n推理结果:")
# print(f"分类logits形状: {logits.shape}")
# print(f"序列logits形状: {seq_logits.shape}")