# saute/saute_model.py
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import MaskedLMOutput
from .saute_config import SAUTEConfig
activation_to_class = {
    "gelu": nn.GELU,
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
}
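# Example: activation_to_class["gelu"]() instantiates nn.GELU. This mapping is
# currently only referenced by the commented-out MLP projection inside
# EDUSpeakerAwareMLM below.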
class EDUSpeakerAwareMLM(nn.Module):
    """Speaker-aware EDU encoder: a frozen pretrained encoder embeds each EDU,
    a per-speaker key-value memory mixes context across turns, and a small
    Transformer stack re-encodes the tokens of each EDU."""

    def __init__(self, config):
        super().__init__()
        # model_name = "sentence-transformers/all-MiniLM-L6-v2"
        model_name = "bert-base-uncased"
        self.edu_encoder = AutoModel.from_pretrained(model_name)
        for param in self.edu_encoder.parameters():
            param.requires_grad = False  # frozen encoder; only memory, heads, and Transformer train
        self.d_model = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        # Per-head key/value/query projections for the speaker memory.
        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.val_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
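        # Token-level Transformer stack; in forward it re-encodes the tokens of
        # each EDU (on (B*T, L, D) inputs) after the speaker-memory mixing.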
encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads, batch_first=True)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)
# self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
# self.mlp_proj = nn.Sequential(
# nn.Linear(config.hidden_size, 2048),
# activation_to_class["gelu"](),
# # nn.Dropout(0.1),
# nn.Linear(2048, config.hidden_size),
# # nn.Dropout(0.1),
# )
        self.ln1 = nn.LayerNorm(config.hidden_size)  # currently unused; see the commented-out residual path in forward
# self.ln2 = nn.LayerNorm(config.hidden_size)
# self.speaker_memory = {} # Will be filled per batch
# self.lm_head = nn.Linear(config.hidden_size, self.edu_encoder.config.vocab_size)
    def forward(self, input_ids, attention_mask, speaker_names):
        """
        Args:
            input_ids: (B, T, L) token ids; B dialogues, T EDUs each, L tokens per EDU
            attention_mask: (B, T, L)
            speaker_names: list of lists of speaker strings, shape (B, T)

        Returns:
            encoded: (B, T, L, D) contextualized token embeddings
            flop_penalty: placeholder, always 0
        """
        B, T, L = input_ids.shape
        # Encode every EDU with the frozen encoder; no gradients are needed.
        with torch.no_grad():
            input_ids_flat = input_ids.view(B * T, L)
            attention_mask_flat = attention_mask.view(B * T, L)
            outputs = self.edu_encoder(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
            token_embeddings = outputs.last_hidden_state  # (B*T, L, D)
        token_embeddings = token_embeddings.view(B, T, L, self.d_model)
        # edu_embeddings = token_embeddings.mean(dim=2)  # (B, T, D) mean pooling (alternative)
        edu_embeddings = token_embeddings[:, :, 0]  # (B, T, D) [CLS] token per EDU
        # Speaker-aware memory: one running per-head key-value sum per speaker.
        speaker_memories = [{} for _ in range(B)]
        H = self.num_heads
        d = self.head_dim
        speaker_matrices = torch.zeros(B, T, H, d, d, device=edu_embeddings.device)
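        # Linear-attention-style memory: for each turn t by speaker s, the
        # per-head outer product k_t v_t^T is added to that speaker's running
        # sum, and speaker_matrices[b, t] snapshots the sum *after* adding
        # turn t. Each EDU therefore sees itself plus all earlier EDUs by the
        # same speaker; other speakers' turns never enter its memory.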
        for b in range(B):
            for t in range(T):
                speaker = speaker_names[b][t]
                e_t = edu_embeddings[b, t]  # (D,)
                if speaker not in speaker_memories[b]:
                    speaker_memories[b][speaker] = {
                        'kv_sum': torch.zeros(self.num_heads, self.head_dim, self.head_dim, device=e_t.device)
                    }
                mem = speaker_memories[b][speaker]
                k_t = self.key_proj(e_t).view(self.num_heads, self.head_dim)  # (H, d)
                v_t = self.val_proj(e_t).view(self.num_heads, self.head_dim)  # (H, d)
                kv_t = torch.einsum("hd,he->hde", k_t, v_t)  # per-head outer product, (H, d, d)
                mem['kv_sum'] = mem['kv_sum'] + kv_t
                # Normalized variant, kept for reference:
                # mem['k_sum'] = mem['k_sum'] + k_t
                # z = torch.clamp(mem['k_sum'] @ k_t, min=1e-6)
                # speaker_matrices[b, t] = mem['kv_sum'] / z
                speaker_matrices[b, t] = mem['kv_sum']
        # Apply the speaker matrix to each token.
        # Earlier single-matrix (full-D) variant, kept for reference:
        # speaker_matrices_exp = speaker_matrices.unsqueeze(2)  # (B, T, 1, D, D)
        # token_embeddings_exp = query_emb.unsqueeze(-1)  # (B, T, L, D, 1)
        # contextual_tokens = token_embeddings + torch.matmul(speaker_matrices_exp, token_embeddings_exp).squeeze(-1)  # (B, T, L, D)
        # contextual_tokens = self.ln1(contextual_tokens)
        # contextual_tokens = self.ln2(contextual_tokens + self.mlp_proj(contextual_tokens))
        # Project queries per head.
        query_emb = self.query_proj(token_embeddings)  # (B, T, L, D)
        query = query_emb.view(B, T, L, H, d)  # (B, T, L, H, d)
        # Contract each turn's queries with its speaker's memory matrix.
        contextual = []
        for b in range(B):
            turn_outputs = []
            for t in range(T):
                M = speaker_matrices[b, t]  # (H, d, d)
                q = query[b, t].transpose(0, 1)  # (L, H, d) -> (H, L, d)
                a = torch.matmul(q, M)  # (H, L, d)
                a = a.transpose(0, 1).contiguous().view(L, -1)  # (L, D), heads concatenated
                turn_outputs.append(token_embeddings[b, t] + a)  # residual connection
            contextual.append(torch.stack(turn_outputs))
        contextual_tokens = torch.stack(contextual)  # (B, T, L, D)
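        # A vectorized equivalent of the loop above (a sketch, not part of the
        # original code path):
        #   a = torch.einsum("btlhd,bthde->btlhe", query, speaker_matrices)
        #   contextual_tokens = token_embeddings + a.reshape(B, T, L, -1)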
        # contextual_tokens = self.out_proj(contextual_tokens)
        # === EDU-level Transformer: re-encode tokens within each EDU ===
        edu_tokens = contextual_tokens.view(B * T, L, self.d_model)  # (B*T, L, D)
        encoded_edu = self.transformer(edu_tokens)  # (B*T, L, D)
        encoded = encoded_edu.view(B, T, L, self.d_model)  # (B, T, L, D)
        return encoded, 0  # second value is a flop-penalty placeholder (unused)
class UtteranceEmbedings(PreTrainedModel):
    config_class = SAUTEConfig

    def __init__(self, config):
        super().__init__(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)  # MLM head
        self.saute_unit = EDUSpeakerAwareMLM(config)
        self.config = config
        self.init_weights()
    def forward(
        self,
        input_ids: torch.Tensor,
        speaker_names: list[list[str]],
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
    ):
        X, flop_penalty = self.saute_unit(
            input_ids=input_ids,
            speaker_names=speaker_names,
            attention_mask=attention_mask,
        )
        logits = self.lm_head(X)  # (B, T, L, vocab_size)
        loss = None
        if labels is not None:
            # CrossEntropyLoss ignores label positions set to -100 by default.
            loss_fct = nn.CrossEntropyLoss()
            # Variant with FLOP penalty, kept for reference:
            # loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + 1e-3 * flop_penalty
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        return MaskedLMOutput(loss=loss, logits=logits)
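

# --- Usage sketch (not part of the original file) ---
# A minimal smoke test, assuming SAUTEConfig can be constructed with defaults
# whose hidden_size matches the frozen encoder (768 for bert-base-uncased) and
# whose vocab_size matches its tokenizer (30522). Speaker names and shapes
# below are illustrative only. Because of the relative import above, run this
# as a module (e.g. python -m saute.saute_model), not as a script.
if __name__ == "__main__":
    config = SAUTEConfig()
    model = UtteranceEmbedings(config)
    model.eval()

    B, T, L = 2, 3, 16  # dialogues, EDUs per dialogue, tokens per EDU
    input_ids = torch.randint(0, config.vocab_size, (B, T, L))
    attention_mask = torch.ones(B, T, L, dtype=torch.long)
    speaker_names = [["alice", "bob", "alice"], ["carol", "dave", "carol"]]

    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            speaker_names=speaker_names,
            attention_mask=attention_mask,
            labels=input_ids.clone(),
        )
    print(out.logits.shape)  # expected: torch.Size([2, 3, 16, vocab_size])
    print(out.loss)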