JustinDuc committed on
Commit
d8becc1
·
verified ·
1 Parent(s): 04e8356

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +36 -36
model.py CHANGED
@@ -1,8 +1,16 @@
1
  import torch
2
  import torch.nn as nn
3
- from transformers import PreTrainedModel, AutoModel
4
  from transformers.modeling_outputs import MaskedLMOutput
5
- from sources.saute_config import SAUTEConfig
 
 
 
 
 
 
 
 
6
 
7
  class EDUSpeakerAwareMLM(nn.Module):
8
  def __init__(self, config):
@@ -15,12 +23,25 @@ class EDUSpeakerAwareMLM(nn.Module):
15
  param.requires_grad = False # frozen encoder
16
 
17
  self.d_model = config.hidden_size
 
 
18
  self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
19
 
20
  encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads, batch_first=True)
21
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)
22
 
23
- self.saute = SAUTE(config)
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def forward(self, input_ids, attention_mask, speaker_names):
26
  """
@@ -39,41 +60,11 @@ class EDUSpeakerAwareMLM(nn.Module):
39
 
40
  token_embeddings = token_embeddings.view(B, T, L, self.d_model)
41
  edu_embeddings = token_embeddings.mean(dim=2) # (B, T, D)
42
-
43
- contextual_tokens = self.saute(input_ids, speaker_names, token_embeddings, edu_embeddings)
44
-
45
- # === NEW: EDU-level Transformer ===
46
- edu_tokens = contextual_tokens.view(B * T, L, self.d_model) # (B*T, L, D)
47
- encoded_edu = self.transformer(edu_tokens) # (B*T, L, D)
48
- encoded = encoded_edu.view(B, T, L, self.d_model) # (B, T, L, D)
49
-
50
- return encoded, 0
51
-
52
- class SAUTE(nn.Module):
53
- def __init__(self,
54
- config : SAUTEConfig
55
- ):
56
- super().__init__()
57
-
58
- self.d_model = config.hidden_size
59
-
60
- self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
61
- self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
62
- self.val_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
63
-
64
- def forward(self,
65
- input_ids : torch.Tensor,
66
- speaker_names : list[str],
67
- token_embeddings : torch.Tensor,
68
- edu_embeddings : torch.Tensor
69
- ):
70
 
71
  # Speaker-aware memory
72
- B, T, L = input_ids.shape
73
-
74
  speaker_memories = [{} for _ in range(B)]
75
  speaker_matrices = torch.zeros(B, T, self.d_model, self.d_model, device=edu_embeddings.device)
76
- query_embeddings = self.query_proj(token_embeddings)
77
 
78
  for b in range(B):
79
  for t in range(T):
@@ -103,10 +94,18 @@ class SAUTE(nn.Module):
103
 
104
  # Apply speaker matrix to each token
105
  speaker_matrices_exp = speaker_matrices.unsqueeze(2) # (B, T, 1, D, D)
106
- token_embeddings_exp = query_embeddings.unsqueeze(-1) # (B, T, L, D, 1)
107
  contextual_tokens = token_embeddings + torch.matmul(speaker_matrices_exp, token_embeddings_exp).squeeze(-1) # (B, T, L, D)
 
 
 
 
 
 
 
 
 
108
 
109
- return contextual_tokens
110
 
111
  class UtteranceEmbedings(PreTrainedModel):
112
  config_class = SAUTEConfig
@@ -135,6 +134,7 @@ class UtteranceEmbedings(PreTrainedModel):
135
  attention_mask = attention_mask,
136
  # hidden_state = None
137
  )
 
138
 
139
  logits = self.lm_head(X)
140
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import PreTrainedModel, BertModel, BertTokenizerFast
4
  from transformers.modeling_outputs import MaskedLMOutput
5
+ from saute_config import SAUTEConfig
6
+
7
+ activation_to_class = {
8
+ "gelu" : nn.GELU,
9
+ "relu" : nn.ReLU,
10
+ "sigmoid" : nn.Sigmoid
11
+ }
12
+
13
+ from transformers import AutoModel
14
 
15
  class EDUSpeakerAwareMLM(nn.Module):
16
  def __init__(self, config):
 
23
  param.requires_grad = False # frozen encoder
24
 
25
  self.d_model = config.hidden_size
26
+ self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
27
+ self.val_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
28
  self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
29
 
30
  encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads, batch_first=True)
31
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)
32
 
33
+ # self.mlp_proj = nn.Sequential(
34
+ # nn.Linear(config.hidden_size, 2048),
35
+ # activation_to_class["gelu"](),
36
+ # # nn.Dropout(0.1),
37
+ # nn.Linear(2048, config.hidden_size),
38
+ # # nn.Dropout(0.1),
39
+ # )
40
+ self.ln1 = nn.LayerNorm(config.hidden_size)
41
+ # self.ln2 = nn.LayerNorm(config.hidden_size)
42
+
43
+ # self.speaker_memory = {} # Will be filled per batch
44
+ # self.lm_head = nn.Linear(config.hidden_size, self.edu_encoder.config.vocab_size)
45
 
46
  def forward(self, input_ids, attention_mask, speaker_names):
47
  """
 
60
 
61
  token_embeddings = token_embeddings.view(B, T, L, self.d_model)
62
  edu_embeddings = token_embeddings.mean(dim=2) # (B, T, D)
63
+ query_emb = self.query_proj(token_embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # Speaker-aware memory
 
 
66
  speaker_memories = [{} for _ in range(B)]
67
  speaker_matrices = torch.zeros(B, T, self.d_model, self.d_model, device=edu_embeddings.device)
 
68
 
69
  for b in range(B):
70
  for t in range(T):
 
94
 
95
  # Apply speaker matrix to each token
96
  speaker_matrices_exp = speaker_matrices.unsqueeze(2) # (B, T, 1, D, D)
97
+ token_embeddings_exp = query_emb.unsqueeze(-1) # (B, T, L, D, 1)
98
  contextual_tokens = token_embeddings + torch.matmul(speaker_matrices_exp, token_embeddings_exp).squeeze(-1) # (B, T, L, D)
99
+ # contextual_tokens = self.ln1(contextual_tokens)
100
+ # contextual_tokens = self.ln2(contextual_tokens + self.mlp_proj(contextual_tokens))
101
+
102
+ # === NEW: EDU-level Transformer ===
103
+ edu_tokens = contextual_tokens.view(B * T, L, self.d_model) # (B*T, L, D)
104
+ encoded_edu = self.transformer(edu_tokens) # (B*T, L, D)
105
+ encoded = encoded_edu.view(B, T, L, self.d_model) # (B, T, L, D)
106
+
107
+ return encoded, 0
108
 
 
109
 
110
  class UtteranceEmbedings(PreTrainedModel):
111
  config_class = SAUTEConfig
 
134
  attention_mask = attention_mask,
135
  # hidden_state = None
136
  )
137
+ # print(X.shape)
138
 
139
  logits = self.lm_head(X)
140