# core/voice_encoder.py
import torch
import torch.nn as nn


class VoiceEncoder(nn.Module):
    """LSTM-based voice encoder.

    Runs an LSTM over a sequence of feature frames and projects the final
    hidden state of the top LSTM layer through a linear layer.

    Input:  (batch, frames, feat_dim)
    Output: (batch, embed_dim)

    The output is the raw linear projection; no normalization is applied.

    Args:
        feat_dim: Size of each input feature frame.
        embed_dim: LSTM hidden size and size of the output embedding.
        num_layers: Number of stacked LSTM layers. The default of 1 matches
            the original single-layer encoder.
    """

    def __init__(self, feat_dim: int = 80, embed_dim: int = 80, num_layers: int = 1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=feat_dim,
            hidden_size=embed_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor, lengths=None) -> torch.Tensor:
        """Encode a batch of frame sequences into fixed-size embeddings.

        Args:
            x: Frames of shape (batch, frames, feat_dim). When ``lengths`` is
                given, sequences shorter than ``frames`` may be right-padded.
            lengths: Optional true length of each sequence: a 1-D tensor or a
                sequence of ints, one per batch item, each >= 1. When given,
                the LSTM stops at each sequence's last real frame, so padding
                does not leak into the embedding. When ``None``, every frame
                in ``x`` is treated as real (the original behavior).

        Returns:
            Embeddings of shape (batch, embed_dim).
        """
        if lengths is not None:
            # pack_padded_sequence requires lengths as an int64 CPU tensor.
            lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu")
            # enforce_sorted=False lets callers pass batches in any order;
            # nn.LSTM restores the original batch order of the returned h.
            x = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
        _, (h, _) = self.lstm(x)
        # h[-1] is the final hidden state of the top (last) LSTM layer.
        return self.fc(h[-1])