Upload 10 files
Browse files- utils/IPA_lo_dict +0 -0
- utils/IPA_sim_statistic_analysis.py +96 -0
- utils/IPA_th_dict +0 -0
- utils/__pycache__/acc_and_f1.cpython-38.pyc +0 -0
- utils/__pycache__/acc_and_f1.cpython-39.pyc +0 -0
- utils/__pycache__/attn.cpython-38.pyc +0 -0
- utils/__pycache__/attn.cpython-39.pyc +0 -0
- utils/acc_and_f1.py +19 -0
- utils/attn.py +266 -0
- utils/same_list +0 -0
utils/IPA_lo_dict
ADDED
|
Binary file (78.4 kB). View file
|
|
|
utils/IPA_sim_statistic_analysis.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import epitran
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import pickle as pkl
|
| 4 |
+
|
| 5 |
+
''' 统计分析。分词,利用分词结果,做统计分析和构建音标词典 '''
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def analyse_by_IPA_statistic(file_lo, file_th, statistic_conclusion_exist=False):
    """Build or analyse IPA token-frequency dictionaries for a Lao/Thai parallel corpus.

    Two modes:
      * ``statistic_conclusion_exist=False`` — tokenize each sentence pair with the
        pretrained tokenizer, transliterate every sub-token to IPA with epitran,
        count IPA frequencies, and pickle them to ``IPA_lo_dict`` / ``IPA_th_dict``.
      * ``statistic_conclusion_exist=True`` — load those pickles, drop keys that
        contain digits, rank the remaining IPA strings by frequency, and pickle
        the list of IPA strings shared by both languages to ``same_list`` as
        ``[ipa, lo_rank, th_rank, lo_count, th_count]`` rows.

    Args:
        file_lo: path to the Lao side of the parallel corpus (one sentence per line).
        file_th: path to the Thai side of the parallel corpus.
        statistic_conclusion_exist: when True, analyse existing dictionaries
            instead of rebuilding them.

    Returns:
        None — all results are written to pickle files as a side effect.
    """
    from transformers import AutoTokenizer

    if statistic_conclusion_exist:
        # Use context managers so the pickle file handles are closed.
        with open('IPA_lo_dict', 'rb') as f:
            IPA_lo_dict = pkl.load(f)
        with open('IPA_th_dict', 'rb') as f:
            IPA_th_dict = pkl.load(f)

        # Drop any IPA key that contains a digit (tokenizer/transliteration
        # artefacts).  A comprehension replaces the original
        # copy-then-delete-while-iterating loops.
        IPA_lo_dict_cop = {k: v for k, v in IPA_lo_dict.items()
                           if not any(ch.isdigit() for ch in k)}
        IPA_th_dict_cop = {k: v for k, v in IPA_th_dict.items()
                           if not any(ch.isdigit() for ch in k)}

        # BUG FIX: the original built sorted_IPA_lo from IPA_th_dict_cop and
        # sorted_IPA_th from IPA_lo_dict_cop (lo/th swapped), so the ranks
        # stored in same_list were crossed.  Each list now comes from its
        # own language's dictionary, sorted by descending frequency.
        sorted_IPA_lo = [k for k, _ in sorted(IPA_lo_dict_cop.items(),
                                              key=lambda x: x[1], reverse=True)]
        sorted_IPA_th = [k for k, _ in sorted(IPA_th_dict_cop.items(),
                                              key=lambda x: x[1], reverse=True)]
        # O(1) membership/rank lookups instead of list.index inside the loop.
        th_rank = {k: r for r, k in enumerate(sorted_IPA_th)}

        same_list = []
        for lo_rank, ipa in enumerate(sorted_IPA_lo):
            if ipa in th_rank:
                # IPA string present in both languages: record its frequency
                # rank in each language plus its raw counts.
                same_list.append([ipa, lo_rank, th_rank[ipa],
                                  IPA_lo_dict[ipa], IPA_th_dict[ipa]])

        with open('same_list', 'wb') as f:
            pkl.dump(same_list, f)
        return
    else:
        plm_tokenizer = AutoTokenizer.from_pretrained(
            r'../foundation/E5')

        with open(file_lo, 'r', encoding='utf-8') as f:
            data_lo = f.readlines()
        with open(file_th, 'r', encoding='utf-8') as f:
            data_th = f.readlines()

        IPA_lo_dict = {}
        IPA_th_dict = {}
        print(len(data_lo))
        print(len(data_th))

        # Hoisted out of the loop: Epitran construction is expensive and was
        # previously redone for every sentence pair.
        epi_lo = epitran.Epitran("lao-Laoo")
        epi_th = epitran.Epitran("tha-Thai")

        for line_lo, line_th in tqdm(zip(data_lo, data_th), total=len(data_lo)):
            # [2:-1] skips the leading special tokens and the trailing one
            # produced by the tokenizer — presumably <s>-style markers;
            # TODO(review): confirm against the E5 tokenizer's special tokens.
            tked_lo = \
                plm_tokenizer(line_lo, max_length=512, padding=True, truncation=True,
                              return_tensors='pt').encodings[0].tokens[2:-1]
            tked_th = \
                plm_tokenizer(line_th, max_length=512, padding=True, truncation=True,
                              return_tensors='pt').encodings[0].tokens[2:-1]

            for tok in tked_lo:
                IPA_lo = epi_lo.transliterate(tok)
                # BUG FIX: the original used .get(key, 1) + 1, so the first
                # occurrence of a token was counted as 2.
                IPA_lo_dict[IPA_lo] = IPA_lo_dict.get(IPA_lo, 0) + 1
            for tok in tked_th:
                IPA_th = epi_th.transliterate(tok)
                IPA_th_dict[IPA_th] = IPA_th_dict.get(IPA_th, 0) + 1

        with open('IPA_lo_dict', 'wb') as f:
            pkl.dump(IPA_lo_dict, f)
        with open('IPA_th_dict', 'wb') as f:
            pkl.dump(IPA_th_dict, f)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def spliteKeyWord(in_str):
    """Return the set of distinct characters appearing in ``in_str``.

    Used as the character-shingle set for the Jaccard similarity in
    ``minhash``.  Returns an empty set for an empty string.
    """
    # BUG FIX: the original line was ``return set(list(in_str))4`` — the
    # stray trailing "4" is a SyntaxError that made this module
    # unimportable.  set() also accepts a string directly, so the
    # intermediate list() is unnecessary.
    return set(in_str)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def minhash(str_a, str_b):
    """Character-level Jaccard similarity between two strings, in [0, 1].

    Falls back to 0.0 (after printing a notice) when both strings are
    empty, which makes the set union empty and the ratio undefined.
    """
    score = 0.0
    chars_a = spliteKeyWord(str_a)
    chars_b = spliteKeyWord(str_b)
    try:
        score = len(chars_a & chars_b) / float(len(chars_a | chars_b))
    except ZeroDivisionError:
        print('ZeroDivisionError')

    return score
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
    # First pass: build the IPA frequency dictionaries from the parallel
    # Lao/Thai corpus (rerun with statistic_conclusion_exist=True to analyse).
    analyse_by_IPA_statistic(
        '../data/triple/data_lo.txt',
        '../data/triple/data_th.txt',
        statistic_conclusion_exist=False,
    )
|
utils/IPA_th_dict
ADDED
|
Binary file (157 kB). View file
|
|
|
utils/__pycache__/acc_and_f1.cpython-38.pyc
ADDED
|
Binary file (710 Bytes). View file
|
|
|
utils/__pycache__/acc_and_f1.cpython-39.pyc
ADDED
|
Binary file (701 Bytes). View file
|
|
|
utils/__pycache__/attn.cpython-38.pyc
ADDED
|
Binary file (5.81 kB). View file
|
|
|
utils/__pycache__/attn.cpython-39.pyc
ADDED
|
Binary file (5.76 kB). View file
|
|
|
utils/acc_and_f1.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.metrics import f1_score, precision_score, recall_score
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def cal_acc_and_f1(preds, labels):
    """Compute accuracy, precision, recall, F1 and the accuracy/F1 average.

    Args:
        preds: predicted labels (array-like, binary).
        labels: ground-truth labels (array-like, same shape as ``preds``).

    Returns:
        dict with keys ``acc``, ``precision``, ``recall``, ``f1`` and
        ``acc_and_f1`` (the mean of accuracy and F1).
    """
    accuracy = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds)
    precision = precision_score(y_true=labels, y_pred=preds)
    recall = recall_score(y_true=labels, y_pred=preds)
    return {
        "acc": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "acc_and_f1": (accuracy + f1) / 2,
    }
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def simple_accuracy(preds, labels):
    """Element-wise exact-match accuracy.

    Expects array-like inputs supporting broadcast ``==`` and ``.mean()``
    (e.g. numpy arrays).
    """
    correct = preds == labels
    return correct.mean()
|
utils/attn.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Multi-Head Attention module """
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from
    "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simulataneously
    to select relevant items.

    .. mermaid::

       graph BT
          A[key]
          B[value]
          C[query]
          O[output]
          subgraph Attn
            D[Attn 1]
            E[Attn 2]
            F[Attn N]
          end
          A --> D
          C --> D
          A --> E
          C --> E
          A --> F
          C --> F
          D --> O
          E --> O
          F --> O
          B --> O

    Also includes several additional tricks.

    Args:
       head_count (int): number of parallel heads
       model_dim (int): the dimension of keys/values/queries,
           must be divisible by head_count
       dropout (float): dropout parameter
       use_final_linear (bool): when True, apply an output projection
           to the concatenated head contexts
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        # Per-head dimension; assigned before super().__init__(), which is
        # legal for plain (non-module) attributes.
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim

        super(MultiHeadedAttention, self).__init__()
        self.head_count = head_count

        # Separate projections for keys, values and queries; each maps
        # model_dim -> head_count * dim_per_head (== model_dim).
        self.linear_keys = nn.Linear(model_dim,
                                     head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim,
                                       head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim,
                                      head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        # Dropout is applied to the attention weights, not the context.
        self.dropout = nn.Dropout(dropout)
        self.use_final_linear = use_final_linear
        if(self.use_final_linear):
            self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(self, key, value, query, mask=None,
                layer_cache=None, type=None):
        """
        Compute the multi-head attention context vectors.

        Args:
            key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
            value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
            query (`FloatTensor`): set of `query_len`
                query vectors `[batch, query_len, dim]`
            mask: binary mask indicating which keys have
                non-zero attention `[batch, query_len, key_len]`;
                positions where the mask is true are masked OUT.
            layer_cache (dict or None): incremental-decoding cache with
                keys "self_keys"/"self_values" (for type == "self") or
                "memory_keys"/"memory_values" (for type == "context");
                updated in place.
            type (str or None): "self" or "context" — selects which cache
                entries to use.  NOTE: shadows the builtin `type`; kept
                for caller compatibility.

        Returns:
            `FloatTensor`: the context vectors `[batch, query_len, dim]`
            when `use_final_linear` is set, otherwise the per-head context
            `[batch, heads, query_len, dim_per_head]`.  (Despite the class
            docstring's diagram, the attention weights are NOT returned.)
        """

        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_ == q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        # NOTE: these two are recomputed after the projections below; the
        # initial values are effectively unused.
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """ projection: [batch, len, dim] -> [batch, heads, len, dim_per_head] """
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """ compute context: [batch, heads, len, dim_per_head] -> [batch, len, dim] """
            return x.transpose(1, 2).contiguous() \
                .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == "self":
                # Self-attention during incremental decoding: all three
                # projections come from the current query step.
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)

                key = shape(key)
                value = shape(value)

                if layer_cache is not None:
                    device = key.device
                    # Append the new step's keys/values to the cached ones
                    # along the length axis (dim=2 after shape()).
                    if layer_cache["self_keys"] is not None:
                        key = torch.cat(
                            (layer_cache["self_keys"].to(device), key),
                            dim=2)
                    if layer_cache["self_values"] is not None:
                        value = torch.cat(
                            (layer_cache["self_values"].to(device), value),
                            dim=2)
                    layer_cache["self_keys"] = key
                    layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache is not None:
                    # Encoder memory is fixed: project it once, then reuse
                    # the cached projections on later steps.
                    if layer_cache["memory_keys"] is None:
                        key, value = self.linear_keys(key),\
                                     self.linear_values(value)
                        key = shape(key)
                        value = shape(value)
                    else:
                        key, value = layer_cache["memory_keys"],\
                                     layer_cache["memory_values"]
                    layer_cache["memory_keys"] = key
                    layer_cache["memory_values"] = value
                else:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
        else:
            # No cache: plain (training-time) attention.
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores (scaled dot-product attention).
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            # Broadcast the mask across heads and mask out with a large
            # negative value so softmax assigns ~0 probability.
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.

        attn = self.softmax(scores)


        drop_attn = self.dropout(attn)
        if(self.use_final_linear):
            context = unshape(torch.matmul(drop_attn, value))
            output = self.final_linear(context)
            return output
        else:
            # Per-head context, heads not merged.
            context = torch.matmul(drop_attn, value)
            return context

        # CHECK
        # batch_, q_len_, d_ = output.size()
        # aeq(q_len, q_len_)
        # aeq(batch, batch_)
        # aeq(d, d_)

        # Return one attn
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class MultiHeadedPooling(nn.Module):
    """Attention-style pooling over the sequence dimension.

    Each of the ``head_count`` heads produces one scalar score per position
    (via ``linear_keys``); the softmax of those scores weights the projected
    values (``linear_values``), which are then summed over the sequence axis.

    Args:
        head_count (int): number of pooling heads; must divide ``model_dim``.
        model_dim (int): dimensionality of the key/value inputs.
        dropout (float): dropout applied to the attention weights.
        use_final_linear (bool): when True, re-project the concatenated
            head outputs back to ``model_dim``.
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        super(MultiHeadedPooling, self).__init__()
        self.head_count = head_count
        # One scalar score per head per position.
        self.linear_keys = nn.Linear(model_dim, head_count)
        self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        if use_final_linear:
            self.final_linear = nn.Linear(model_dim, model_dim)
        self.use_final_linear = use_final_linear

    def forward(self, key, value, mask=None):
        """Pool ``value`` over its sequence axis, scored from ``key``.

        Args:
            key (`FloatTensor`): `[batch, len, model_dim]` — scoring input.
            value (`FloatTensor`): `[batch, len, model_dim]` — pooled input.
            mask: boolean mask broadcastable to `[batch, heads, len]`;
                true positions are excluded from the pooling.

        Returns:
            `FloatTensor`: `[batch, model_dim]` when ``use_final_linear``
            is set, otherwise the per-head result `[batch, heads, dim_per_head]`.
        """
        n_batch = key.size(0)
        per_head = self.dim_per_head
        n_heads = self.head_count

        def split_heads(x, dim=per_head):
            # [batch, len, heads * dim] -> [batch, heads, len, dim]
            return x.view(n_batch, -1, n_heads, dim).transpose(1, 2)

        def merge_heads(x, dim=per_head):
            # [batch, heads, len, dim] -> [batch, len, heads * dim]
            return x.transpose(1, 2).contiguous().view(n_batch, -1, n_heads * dim)

        # Scores: one scalar per head per position -> [batch, heads, len].
        raw_scores = split_heads(self.linear_keys(key), 1).squeeze(-1)
        value = split_heads(self.linear_values(value))

        if mask is not None:
            # Mask out excluded positions with a large negative value so
            # softmax assigns them ~0 weight.
            expanded = mask.unsqueeze(1).expand_as(raw_scores)
            raw_scores = raw_scores.masked_fill(expanded, -1e18)

        # Dropout is applied to the attention weights, not the result.
        weights = self.dropout(self.softmax(raw_scores))
        # Weighted sum over the sequence axis -> [batch, heads, dim_per_head].
        pooled = torch.sum(weights.unsqueeze(-1) * value, -2)

        if not self.use_final_linear:
            return pooled
        return self.final_linear(merge_heads(pooled).squeeze(1))
|
| 265 |
+
|
| 266 |
+
|
utils/same_list
ADDED
|
Binary file (41.7 kB). View file
|
|
|