# saute/model.py — SAUTE speaker-aware MLM (author: JustinDuc, commit d8becc1)
import torch
import torch.nn as nn
from transformers import PreTrainedModel, BertModel, BertTokenizerFast
from transformers.modeling_outputs import MaskedLMOutput
from saute_config import SAUTEConfig
# Lookup table mapping activation-name strings to their torch.nn module classes.
activation_to_class = dict(
    gelu=nn.GELU,
    relu=nn.ReLU,
    sigmoid=nn.Sigmoid,
)
from transformers import AutoModel
class EDUSpeakerAwareMLM(nn.Module):
    """Speaker-aware EDU encoder.

    Each EDU (elementary discourse unit) is embedded with a frozen pretrained
    encoder; a per-speaker running sum of key/value outer products acts as a
    linear-attention-style memory that is mixed into every token embedding,
    and the result is passed through a token-level Transformer encoder.
    """

    def __init__(self, config):
        super().__init__()
        # model_name="sentence-transformers/all-MiniLM-L6-v2"
        model_name = "bert-base-uncased"
        # Frozen pretrained encoder: provides token embeddings, never trained here.
        self.edu_encoder = AutoModel.from_pretrained(model_name)
        for param in self.edu_encoder.parameters():
            param.requires_grad = False  # frozen encoder
        self.d_model = config.hidden_size
        # Projections feeding the per-speaker memory (linear-attention style).
        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.val_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)
        # NOTE(review): ln1 is registered but never applied in forward(); kept so
        # existing checkpoints / state dicts remain loadable.
        self.ln1 = nn.LayerNorm(config.hidden_size)

    def forward(self, input_ids, attention_mask, speaker_names):
        """
        input_ids: (B, T, L) token ids — T EDUs of length L per dialogue
        attention_mask: (B, T, L)
        speaker_names: list of list of strings, shape (B, T)

        Returns (encoded, flop_penalty): encoded is (B, T, L, D);
        flop_penalty is currently a constant 0 placeholder.
        """
        B, T, L = input_ids.shape

        # Encode EDUs using the frozen encoder (no gradients needed).
        with torch.no_grad():
            input_ids_flat = input_ids.view(B * T, L)
            attention_mask_flat = attention_mask.view(B * T, L)
            outputs = self.edu_encoder(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
            token_embeddings = outputs.last_hidden_state.view(B, T, L, self.d_model)
            # Mean-pool tokens to one embedding per EDU (padding included — the
            # mask is not applied here; unchanged from the original behavior).
            edu_embeddings = token_embeddings.mean(dim=2)  # (B, T, D)

        query_emb = self.query_proj(token_embeddings)  # (B, T, L, D)

        # Batch the key/value projections and outer products over all (b, t)
        # at once instead of one vector at a time inside the Python loop —
        # identical results (bias-free Linear is per-vector), far fewer calls.
        keys = self.key_proj(edu_embeddings)  # (B, T, D)
        vals = self.val_proj(edu_embeddings)  # (B, T, D)
        kv = torch.einsum("bti,btj->btij", keys, vals)  # (B, T, D, D) outer products

        # Per-speaker running sum of kv across time; the loop only accumulates.
        speaker_matrices = torch.zeros(
            B, T, self.d_model, self.d_model, device=edu_embeddings.device
        )
        for b in range(B):
            memory = {}  # speaker name -> cumulative (D, D) kv sum
            for t in range(T):
                speaker = speaker_names[b][t]
                if speaker in memory:
                    memory[speaker] = memory[speaker] + kv[b, t]
                else:
                    memory[speaker] = kv[b, t]
                speaker_matrices[b, t] = memory[speaker]

        # Mix every token's query through its utterance's speaker memory,
        # with a residual connection onto the raw token embeddings.
        speaker_matrices_exp = speaker_matrices.unsqueeze(2)  # (B, T, 1, D, D)
        queries_exp = query_emb.unsqueeze(-1)  # (B, T, L, D, 1)
        contextual_tokens = token_embeddings + torch.matmul(
            speaker_matrices_exp, queries_exp
        ).squeeze(-1)  # (B, T, L, D)

        # EDU-level Transformer over the tokens of each EDU independently.
        edu_tokens = contextual_tokens.view(B * T, L, self.d_model)
        encoded = self.transformer(edu_tokens).view(B, T, L, self.d_model)
        return encoded, 0
class UtteranceEmbedings(PreTrainedModel):
    """Masked-LM head on top of the speaker-aware SAUTE unit.

    Wraps EDUSpeakerAwareMLM and projects its (B, T, L, D) output to
    vocabulary logits; computes a cross-entropy MLM loss when labels are given.
    """

    config_class = SAUTEConfig

    def __init__(self, config: SAUTEConfig):
        super().__init__(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.saute_unit = EDUSpeakerAwareMLM(config)
        self.config: SAUTEConfig = config
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        speaker_names: list[list[str]],  # (B, T) speaker name per utterance
        attention_mask: torch.Tensor = None,  # NOTE(review): saute_unit requires it; None will fail there
        labels: torch.Tensor = None,
    ):
        """
        input_ids / attention_mask: (B, T, L)
        labels: (B, T, L) MLM targets (-100 positions are ignored by the loss)

        Returns a MaskedLMOutput with logits of shape (B, T, L, vocab_size).
        """
        # Call the module (not .forward directly) so nn.Module hooks still run.
        X, flop_penalty = self.saute_unit(
            input_ids=input_ids,
            speaker_names=speaker_names,
            attention_mask=attention_mask,
        )
        logits = self.lm_head(X)  # (B, T, L, vocab_size)
        loss = None
        if labels is not None:
            # CrossEntropyLoss ignores label index -100 by default, which
            # covers unmasked positions in standard MLM label tensors.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        return MaskedLMOutput(loss=loss, logits=logits)