JustinDuc committed on
Commit 07100cc · verified · 1 Parent(s): 8a0239b

Upload model.py

Files changed (1): model.py (+147, -0)
model.py (new file, 147 lines):
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from transformers.modeling_outputs import MaskedLMOutput
from sources.saute_config import SAUTEConfig


class EDUSpeakerAwareMLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        # model_name = "sentence-transformers/all-MiniLM-L6-v2"
        model_name = "bert-base-uncased"

        # Frozen pretrained encoder used only to embed EDUs.
        self.edu_encoder = AutoModel.from_pretrained(model_name)
        for param in self.edu_encoder.parameters():
            param.requires_grad = False  # frozen encoder

        self.d_model = config.hidden_size
        # (Unused here; SAUTE defines its own projections.)
        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        # Trainable EDU-level Transformer applied after the SAUTE update.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)

        self.saute = SAUTE(config)

    def forward(self, input_ids, attention_mask, speaker_names):
        """
        input_ids: (B, T, L)
        attention_mask: (B, T, L)
        speaker_names: list of lists of strings, shape (B, T)
        """
        B, T, L = input_ids.shape

        # Encode EDUs with the frozen encoder; no gradients are needed here.
        with torch.no_grad():
            input_ids_flat = input_ids.view(B * T, L)
            attention_mask_flat = attention_mask.view(B * T, L)
            outputs = self.edu_encoder(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
            token_embeddings = outputs.last_hidden_state  # (B*T, L, D)

        token_embeddings = token_embeddings.view(B, T, L, self.d_model)
        edu_embeddings = token_embeddings.mean(dim=2)  # (B, T, D); unmasked mean over all L positions

        contextual_tokens = self.saute(input_ids, speaker_names, token_embeddings, edu_embeddings)

        # EDU-level Transformer over the speaker-contextualized tokens
        edu_tokens = contextual_tokens.view(B * T, L, self.d_model)  # (B*T, L, D)
        encoded_edu = self.transformer(edu_tokens)  # (B*T, L, D)
        encoded = encoded_edu.view(B, T, L, self.d_model)  # (B, T, L, D)

        # The second value is a FLOP-penalty placeholder; it is currently always 0.
        return encoded, 0

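# Sketch: the mean pooling in EDUSpeakerAwareMLM.forward averages over all L
# positions, padding included. A masked variant, assuming attention_mask is 1
# on real tokens and 0 on padding, would be:
#
#     mask = attention_mask_flat.unsqueeze(-1).to(token_embeddings.dtype)  # (B*T, L, 1)
#     summed = (outputs.last_hidden_state * mask).sum(dim=1)               # (B*T, D)
#     edu_embeddings = (summed / mask.sum(dim=1).clamp(min=1e-6)).view(B, T, -1)
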
class SAUTE(nn.Module):
    def __init__(self, config: SAUTEConfig):
        super().__init__()

        self.d_model = config.hidden_size

        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.val_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        speaker_names: list[list[str]],
        token_embeddings: torch.Tensor,
        edu_embeddings: torch.Tensor,
    ):
        # Speaker-aware memory: one running key-value outer-product sum per speaker.
        B, T, L = input_ids.shape

        speaker_memories = [{} for _ in range(B)]
        speaker_matrices = torch.zeros(B, T, self.d_model, self.d_model, device=edu_embeddings.device)
        query_embeddings = self.query_proj(token_embeddings)

        for b in range(B):
            for t in range(T):
                speaker = speaker_names[b][t]
                e_t = edu_embeddings[b, t]  # (D,)

                if speaker not in speaker_memories[b]:
                    speaker_memories[b][speaker] = {
                        'kv_sum': torch.zeros(self.d_model, self.d_model, device=e_t.device),
                        # 'k_sum': torch.zeros(self.d_model, device=e_t.device),
                    }

                mem = speaker_memories[b][speaker]
                k_t = self.key_proj(e_t)
                v_t = self.val_proj(e_t)
                kv_t = torch.outer(k_t, v_t)

                mem['kv_sum'] = mem['kv_sum'] + kv_t

                # Normalized variant (disabled):
                # mem['k_sum'] = mem['k_sum'] + k_t
                # z = torch.clamp(mem['k_sum'] @ k_t, min=1e-6)
                # speaker_matrices[b, t] = mem['kv_sum'] / z  # (D, D)
                speaker_matrices[b, t] = mem['kv_sum']

        # Apply each utterance's speaker matrix to its token queries (residual update).
        speaker_matrices_exp = speaker_matrices.unsqueeze(2)  # (B, T, 1, D, D)
        token_embeddings_exp = query_embeddings.unsqueeze(-1)  # (B, T, L, D, 1)
        contextual_tokens = token_embeddings + torch.matmul(speaker_matrices_exp, token_embeddings_exp).squeeze(-1)  # (B, T, L, D)

        return contextual_tokens

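# Sketch of the update SAUTE implements, written out per batch element: each
# speaker s accumulates an unnormalized linear-attention memory over their own
# EDU embeddings,
#
#     M_s(t) = M_s(t-1) + outer(W_k e_t, W_v e_t)   # (D, D), updated only when speaker(t) == s
#
# and every token x of utterance t gets a residual read from the current
# speaker's memory:
#
#     h = x + M_speaker(t) @ (W_q x)
#
# Unlike standard linear attention there is no normalizer here; the
# commented-out 'k_sum' / z lines above are exactly that normalized variant.
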
class UtteranceEmbedings(PreTrainedModel):
    config_class = SAUTEConfig

    def __init__(self, config: SAUTEConfig):
        super().__init__(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.saute_unit = EDUSpeakerAwareMLM(config)

        self.config: SAUTEConfig = config

        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        speaker_names: list[list[str]],
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
    ):
        X, flop_penalty = self.saute_unit(
            input_ids=input_ids,
            speaker_names=speaker_names,
            attention_mask=attention_mask,
        )

        logits = self.lm_head(X)  # (B, T, L, vocab_size)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # FLOP regularization (disabled):
            # loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + 1e-3 * flop_penalty
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return MaskedLMOutput(loss=loss, logits=logits)
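
For reference, a minimal smoke test of the forward pass could look like the sketch below. It assumes SAUTEConfig accepts these keyword arguments like a standard PretrainedConfig; hidden_size must be 768 and vocab_size 30522 to match the frozen bert-base-uncased encoder, and the batch shapes and speaker names are illustrative only.

import torch
from transformers import AutoTokenizer
from sources.saute_config import SAUTEConfig
from model import UtteranceEmbedings

# Illustrative config values; SAUTEConfig's actual signature may differ.
config = SAUTEConfig(
    hidden_size=768,           # must match bert-base-uncased's hidden size
    num_attention_heads=12,
    num_hidden_layers=2,
    vocab_size=30522,          # bert-base-uncased vocabulary size
)
model = UtteranceEmbedings(config)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dialogue = ["hello there", "hi, how are you?"]   # T=2 EDUs of one dialogue (B=1)
speakers = [["alice", "bob"]]                    # (B, T) speaker names
enc = tokenizer(dialogue, padding="max_length", max_length=16,
                truncation=True, return_tensors="pt")
input_ids = enc["input_ids"].view(1, 2, 16)            # (B, T, L)
attention_mask = enc["attention_mask"].view(1, 2, 16)  # (B, T, L)

out = model(input_ids=input_ids, speaker_names=speakers,
            attention_mask=attention_mask, labels=input_ids.clone())
print(out.logits.shape)  # (1, 2, 16, 30522), i.e. (B, T, L, vocab_size)
print(out.loss)          # scalar MLM-style cross-entropy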