JustinDuc committed (verified)
Commit 621ae60 · 1 Parent(s): 3ed12c3

Upload UtteranceEmbedings

Files changed (4)
  1. README.md +8 -8
  2. config.json +2 -2
  3. model.safetensors +2 -2
  4. saute_model.py +57 -15
README.md CHANGED
@@ -1,16 +1,16 @@
 ---
 license: mit
 tags:
-- masked-language-modeling
-- dialogue
-- speaker-aware
-- transformer
-- saute
-- pytorch
+- masked-language-modeling
+- dialogue
+- speaker-aware
+- transformer
+- saute
+- pytorch
 datasets:
-- SODA
+- SODA
 language:
-- en
+- en
 pipeline_tag: fill-mask
 model_type: saute
 library_name: transformers
config.json CHANGED
@@ -15,9 +15,9 @@
   "max_position_embeddings": 512,
   "max_speakers": 200,
   "model_type": "saute",
-  "num_attention_heads": 1,
+  "num_attention_heads": 8,
   "num_edu_layers": 2,
-  "num_hidden_layers": 1,
+  "num_hidden_layers": 3,
   "num_speaker_embeddings": 512,
   "num_token_layers": 2,
   "speaker_embeddings_size": 768,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9406a034ce4cc90e25074e183198a7068a67ba1b3b465e94975252138ac19656
-size 560983656
+oid sha256:4e2ee7cabbb652ec8f13c95a48b0336362ec8b7d4698ca6fefb515229d39a898
+size 605098400
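
The checkpoint grows by 605098400 - 560983656 = 44,114,744 bytes, which lines up with the config change from 1 to 3 hidden layers: two extra nn.TransformerEncoderLayer blocks at d_model=768 with PyTorch's default dim_feedforward=2048, stored in fp32. A rough back-of-the-envelope (ignores safetensors header overhead):

    d, ff = 768, 2048                 # d_model and default feed-forward width
    attn = 4 * d * d + 4 * d          # QKV in-proj + out-proj, weights and biases
    ffn = 2 * d * ff + ff + d         # linear1 + linear2, weights and biases
    norms = 2 * 2 * d                 # two LayerNorms, weight and bias each
    per_layer = attn + ffn + norms    # 5,513,984 parameters
    print(2 * per_layer * 4)          # 44,111,872 bytes -- within a few KB of observed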
saute_model.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, BertModel, BertTokenizerFast
 from transformers.modeling_outputs import MaskedLMOutput
-from .saute_config import SAUTEConfig
+from sources.saute_config import SAUTEConfig
 
 activation_to_class = {
     "gelu" : nn.GELU,
@@ -23,12 +23,20 @@ class EDUSpeakerAwareMLM(nn.Module):
             param.requires_grad = False # frozen encoder
 
         self.d_model = config.hidden_size
-        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
-        self.val_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
+        # self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
+        # self.val_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // self.num_heads
+
+        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+        self.val_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+
         self.query_proj = nn.Linear(config.hidden_size, config.hidden_size, bias = False)
 
         encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads, batch_first=True)
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)
+        # self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
 
         # self.mlp_proj = nn.Sequential(
         #     nn.Linear(config.hidden_size, 2048),
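
This hunk swaps the single-head key/value projections for an explicit multi-head layout (note the new side ends up assigning self.query_proj twice; the second assignment simply overwrites the first). A minimal standalone sketch of the head split, assuming hidden_size=768 and the new num_attention_heads=8 from config.json:

    import torch
    import torch.nn as nn

    hidden_size, num_heads = 768, 8       # assumed, matching config.json
    head_dim = hidden_size // num_heads   # 96

    key_proj = nn.Linear(hidden_size, hidden_size, bias=False)
    e_t = torch.randn(hidden_size)                 # one EDU embedding
    k_t = key_proj(e_t).view(num_heads, head_dim)  # one key vector per head
    print(k_t.shape)                               # torch.Size([8, 96])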
@@ -59,12 +67,18 @@ class EDUSpeakerAwareMLM(nn.Module):
         token_embeddings = outputs.last_hidden_state # (B*T, L, D)
 
         token_embeddings = token_embeddings.view(B, T, L, self.d_model)
-        edu_embeddings = token_embeddings.mean(dim=2) # (B, T, D)
-        query_emb = self.query_proj(token_embeddings)
+        # edu_embeddings = token_embeddings.mean(dim=2) # (B, T, D)
+        edu_embeddings = token_embeddings[:,:,0] # CLS token
+        # query_emb = self.query_proj(token_embeddings)
 
         # Speaker-aware memory
         speaker_memories = [{} for _ in range(B)]
-        speaker_matrices = torch.zeros(B, T, self.d_model, self.d_model, device=edu_embeddings.device)
+        # speaker_matrices = torch.zeros(B, T, self.d_model, self.d_model, device=edu_embeddings.device)
+        H = self.num_heads
+        d = self.head_dim
+
+        speaker_matrices = torch.zeros(B, T, H, d, d, device=edu_embeddings.device)
+
 
         for b in range(B):
             for t in range(T):
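
Two changes above: the EDU embedding becomes the [CLS] position instead of a mean over tokens, and the per-speaker memory shrinks from one D×D matrix per step to H per-head d×d matrices. A shape sketch with illustrative sizes:

    import torch

    B, T, L, D, H = 2, 4, 16, 768, 8   # batch, turns, tokens, hidden, heads
    d = D // H
    token_embeddings = torch.randn(B, T, L, D)

    edu_mean = token_embeddings.mean(dim=2)  # old pooling, (B, T, D)
    edu_cls = token_embeddings[:, :, 0]      # new pooling: [CLS] slot, (B, T, D)

    old_mem = torch.zeros(B, T, D, D)
    new_mem = torch.zeros(B, T, H, d, d)
    print(old_mem.numel() // new_mem.numel())  # 8: block-diagonal memory is H times smaller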
@@ -72,15 +86,22 @@ class EDUSpeakerAwareMLM(nn.Module):
                 e_t = edu_embeddings[b, t] # (D)
 
                 if speaker not in speaker_memories[b]:
+                    # speaker_memories[b][speaker] = {
+                    #     'kv_sum': torch.zeros(self.d_model, self.d_model, device=e_t.device),
+                    #     # 'k_sum': torch.zeros(self.d_model, device=e_t.device),
+                    # }
                     speaker_memories[b][speaker] = {
-                        'kv_sum': torch.zeros(self.d_model, self.d_model, device=e_t.device),
-                        # 'k_sum': torch.zeros(self.d_model, device=e_t.device),
+                        'kv_sum': torch.zeros(self.num_heads, self.head_dim, self.head_dim, device=e_t.device)
                     }
 
                 mem = speaker_memories[b][speaker]
-                k_t = self.key_proj(e_t)
-                v_t = self.val_proj(e_t)
-                kv_t = torch.outer(k_t, v_t)
+                # k_t = self.key_proj(e_t)
+                # v_t = self.val_proj(e_t)
+                # kv_t = torch.outer(k_t, v_t)
+                k_t = self.key_proj(e_t).view(self.num_heads, self.head_dim) # (H, d_k)
+                v_t = self.val_proj(e_t).view(self.num_heads, self.head_dim) # (H, d_v)
+                kv_t = torch.einsum("hd,he->hde", k_t, v_t) # (H, d_k, d_v)
+
 
                 # with torch.no_grad():
                 mem['kv_sum'] = mem['kv_sum'] + kv_t
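
The einsum above is just the old torch.outer applied per head; the running 'kv_sum' accumulates a linear-attention-style memory for each speaker. A quick equivalence check under the same assumed sizes:

    import torch

    H, d = 8, 96                  # heads and head_dim, assumed as above
    k_t = torch.randn(H, d)
    v_t = torch.randn(H, d)

    kv_t = torch.einsum("hd,he->hde", k_t, v_t)   # (H, d, d)
    ref = torch.stack([torch.outer(k_t[h], v_t[h]) for h in range(H)])
    print(torch.allclose(kv_t, ref))              # True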
@@ -93,11 +114,32 @@ class EDUSpeakerAwareMLM(nn.Module):
                 speaker_matrices[b, t] = mem['kv_sum']
 
         # Apply speaker matrix to each token
-        speaker_matrices_exp = speaker_matrices.unsqueeze(2) # (B, T, 1, D, D)
-        token_embeddings_exp = query_emb.unsqueeze(-1) # (B, T, L, D, 1)
-        contextual_tokens = token_embeddings + torch.matmul(speaker_matrices_exp, token_embeddings_exp).squeeze(-1) # (B, T, L, D)
+        # speaker_matrices_exp = speaker_matrices.unsqueeze(2) # (B, T, 1, D, D)
+        # token_embeddings_exp = query_emb.unsqueeze(-1) # (B, T, L, D, 1)
+        # contextual_tokens = token_embeddings + torch.matmul(speaker_matrices_exp, token_embeddings_exp).squeeze(-1) # (B, T, L, D)
         # contextual_tokens = self.ln1(contextual_tokens)
         # contextual_tokens = self.ln2(contextual_tokens + self.mlp_proj(contextual_tokens))
+        # Project queries
+        query_emb = self.query_proj(token_embeddings) # (B, T, L, D)
+        query = query_emb.view(B, T, L, H, d) # (B, T, L, H, d)
+
+        # Apply memory matrices
+        contextual = []
+        for b in range(B):
+            head_outputs = []
+            for t in range(T):
+                speaker = speaker_names[b][t]
+                M = speaker_matrices[b, t] # (H, d, d)
+                q = query[b, t] # (L, H, d)
+                q = q.transpose(0, 1) # (H, L, d)
+                a = torch.matmul(q, M) # (H, L, d)
+                a = a.transpose(0, 1).contiguous().view(L, -1) # (L, D)
+                contextual_token = token_embeddings[b, t] + a
+                head_outputs.append(contextual_token)
+            contextual.append(torch.stack(head_outputs))
+        contextual_tokens = torch.stack(contextual)
+        # (B, T, L, D)
+        # contextual_tokens = self.out_proj(contextual_tokens)
 
         # === NEW: EDU-level Transformer ===
         edu_tokens = contextual_tokens.view(B * T, L, self.d_model) # (B*T, L, D)
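
The new b/t loop applies each stored memory matrix head-by-head (note its `speaker = speaker_names[b][t]` is assigned but never used). Since speaker_matrices already holds every M, the same result can be computed in one batched einsum; a hedged sketch, not the committed code:

    import torch

    B, T, L, H, d = 2, 4, 16, 8, 96   # illustrative sizes
    query = torch.randn(B, T, L, H, d)
    speaker_matrices = torch.randn(B, T, H, d, d)
    token_embeddings = torch.randn(B, T, L, H * d)

    # a[b,t,l,h,e] = sum_d query[b,t,l,h,d] * M[b,t,h,d,e], i.e. q @ M per head
    a = torch.einsum("btlhd,bthde->btlhe", query, speaker_matrices)
    contextual_tokens = token_embeddings + a.reshape(B, T, L, H * d)
    print(contextual_tokens.shape)    # torch.Size([2, 4, 16, 768])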
@@ -144,4 +186,4 @@ class UtteranceEmbedings(PreTrainedModel):
         # loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + 1e-3 * flop_penalty
         loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 
-        return MaskedLMOutput(loss=loss, logits=logits)
+        return MaskedLMOutput(loss=loss, logits=logits)
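
The return line itself only moved (likely a trailing-newline change). The loss above it is the standard masked-LM cross-entropy over flattened logits and labels; a minimal sketch of the convention, assuming the usual -100 fill for unmasked label positions:

    import torch
    import torch.nn as nn

    vocab_size = 30522                                    # bert-base vocab, assumed
    logits = torch.randn(2, 8, vocab_size)                # (batch, seq_len, vocab)
    labels = torch.full((2, 8), -100, dtype=torch.long)   # -100 = CrossEntropyLoss's default ignore_index
    labels[0, 3] = 42                                     # only masked positions get scored

    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))
    print(loss.item())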
 