mdokl committed on
Commit 2d36faf
1 Parent(s): 4f04c8d

Fix a bug where kv_cache was shared across users, improve the interaction logic, and add garbage collection

Files changed (5)
  1. Encoder.py +76 -76
  2. LazyCache.py +93 -0
  3. MultiHeadAttention.py +405 -396
  4. app.py +295 -221
  5. train_and_use.py +443 -443
Encoder.py CHANGED
@@ -1,76 +1,76 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from Affine import Affine
-
- # Borrowed and lightly adapted
- class Qwen2RMSNorm(nn.Module):
-     def __init__(self, embedding_dim, eps=1e-6):
-         """
-         Qwen2RMSNorm is equivalent to T5LayerNorm
-         """
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(embedding_dim))
-         self.variance_epsilon = eps
-
-     def forward(self, hidden_states):
-         # input_dtype = hidden_states.dtype
-         # hidden_states = hidden_states.to(torch.float32)
-         variance = hidden_states.pow(2).mean(-1, keepdim=True)
-         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-         return self.weight * hidden_states#.to(input_dtype)
-
-     def extra_repr(self):
-         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
- # Position-wise feed-forward network applied to each token embedding
- class PositionWiseFeedForward(nn.Module):
-     def __init__(self,embedding_dim,feed_forward_dim,enable_affine):
-         super(PositionWiseFeedForward, self).__init__()
-         self.w1 = nn.Linear(embedding_dim, feed_forward_dim, bias=False)
-         self.w2 = nn.Linear(feed_forward_dim, embedding_dim, bias=False)
-         self.enable_affine = enable_affine
-         if enable_affine:
-             self.a1 = Affine(1.0)
-             self.a2 = Affine(1.0)
-
-     def forward(self, x):
-         if self.enable_affine:
-             x = F.relu(self.w1(self.a1(x)))
-             return F.relu(self.w2(self.a2(x)))
-         else:
-             x = F.relu(self.w1(x))
-             return F.relu(self.w2(x))
-
- # Encoder layer
- class EncoderLayer(nn.Module):
-     def __init__(self,multi_head_attention,mask_future,position_wise_feed_forward,enable_layer_norm,dropout_rate):
-         super(EncoderLayer,self).__init__()
-         self.multi_head_attention = multi_head_attention
-         self.position_wise_feed_forward = position_wise_feed_forward
-         self.mask_future = mask_future
-         if enable_layer_norm == True:
-             self.layer_norm = Qwen2RMSNorm(multi_head_attention.embedding_dim)
-         else:
-             self.layer_norm = None
-
-         self.dropout_layer = nn.Dropout(p=dropout_rate)
-
-     def forward(self,query,q_mask):
-         # Never use += here: it is an in-place update, which breaks gradient computation
-         query = query + self.dropout_layer(self.multi_head_attention(query,q_mask,query,self.mask_future))
-         query = query + self.dropout_layer(self.position_wise_feed_forward(query))
-         if self.layer_norm is not None:
-             query = self.layer_norm(query)
-         return query
-
- # Encoder
- class Encoder(nn.Module):
-     def __init__(self, encoder_layers):
-         super(Encoder, self).__init__()
-         self.encoder_layers = encoder_layers
-
-     def forward(self, query, q_mask):
-         for encoder_layer in self.encoder_layers:
-             query = encoder_layer(query,q_mask)
-         return query
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from Affine import Affine
+
+ # Borrowed and lightly adapted
+ class Qwen2RMSNorm(nn.Module):
+     def __init__(self, embedding_dim, eps=1e-6):
+         """
+         Qwen2RMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(embedding_dim))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         # input_dtype = hidden_states.dtype
+         # hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states#.to(input_dtype)
+
+     def extra_repr(self):
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+ # Position-wise feed-forward network applied to each token embedding
+ class PositionWiseFeedForward(nn.Module):
+     def __init__(self,embedding_dim,feed_forward_dim,enable_affine):
+         super(PositionWiseFeedForward, self).__init__()
+         self.w1 = nn.Linear(embedding_dim, feed_forward_dim, bias=False)
+         self.w2 = nn.Linear(feed_forward_dim, embedding_dim, bias=False)
+         self.enable_affine = enable_affine
+         if enable_affine:
+             self.a1 = Affine(1.0)
+             self.a2 = Affine(1.0)
+
+     def forward(self, x):
+         if self.enable_affine:
+             x = F.relu(self.w1(self.a1(x)))
+             return F.relu(self.w2(self.a2(x)))
+         else:
+             x = F.relu(self.w1(x))
+             return F.relu(self.w2(x))
+
+ # Encoder layer
+ class EncoderLayer(nn.Module):
+     def __init__(self,multi_head_attention,mask_future,position_wise_feed_forward,enable_layer_norm,dropout_rate):
+         super(EncoderLayer,self).__init__()
+         self.multi_head_attention = multi_head_attention
+         self.position_wise_feed_forward = position_wise_feed_forward
+         self.mask_future = mask_future
+         if enable_layer_norm == True:
+             self.layer_norm = Qwen2RMSNorm(multi_head_attention.embedding_dim)
+         else:
+             self.layer_norm = None
+
+         self.dropout_layer = nn.Dropout(p=dropout_rate)
+
+     def forward(self,query,q_mask,session_id):
+         # Never use += here: it is an in-place update, which breaks gradient computation
+         query = query + self.dropout_layer(self.multi_head_attention(query,q_mask,query,self.mask_future,session_id))
+         query = query + self.dropout_layer(self.position_wise_feed_forward(query))
+         if self.layer_norm is not None:
+             query = self.layer_norm(query)
+         return query
+
+ # Encoder
+ class Encoder(nn.Module):
+     def __init__(self, encoder_layers):
+         super(Encoder, self).__init__()
+         self.encoder_layers = encoder_layers
+
+     def forward(self, query, q_mask,session_id):
+         for encoder_layer in self.encoder_layers:
+             query = encoder_layer(query,q_mask,session_id)
+         return query
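
Reviewer note: the "never use +=" comment in EncoderLayer.forward is worth making concrete. A minimal, self-contained repro (independent of this repo) of the failure mode it warns about: modifying a tensor in place after autograd has saved it for the backward pass.

    import torch

    x = torch.ones(3, requires_grad=True)
    y = x.exp()   # exp saves its output y for the backward pass
    y += 1        # in-place update of a tensor autograd still needs
    try:
        y.sum().backward()
    except RuntimeError as e:
        print("autograd error:", e)  # "... modified by an inplace operation"

The out-of-place `query = query + ...` form sidesteps this by allocating a new tensor each time.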
LazyCache.py ADDED
@@ -0,0 +1,93 @@
+ import time
+ import threading
+ from collections import defaultdict
+
+ class ExpiringDict(dict):
+     """A dict whose entries expire after a TTL"""
+     def __init__(self, ttl=600, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.ttl = ttl  # seconds
+         self._timestamps = {}
+         self._lock = threading.Lock()
+
+     def __setitem__(self, key, value):
+         with self._lock:
+             super().__setitem__(key, value)
+             self._timestamps[key] = time.time()
+
+     def __getitem__(self, key):
+         with self._lock:
+             if key in self._timestamps and (time.time() - self._timestamps[key] > self.ttl):
+                 super().__delitem__(key)
+                 del self._timestamps[key]
+                 raise KeyError(f"{key} has expired")
+             # Refresh the last-active time on access
+             self._timestamps[key] = time.time()
+             return super().__getitem__(key)
+
+     def get(self, key, default=None):
+         try:
+             return self.__getitem__(key)
+         except KeyError:
+             return default
+
+     def cleanup(self):
+         with self._lock:
+             now = time.time()
+             expired = [k for k, t in self._timestamps.items() if now - t > self.ttl]
+             for k in expired:
+                 super().__delitem__(k)
+                 del self._timestamps[k]
+
+     def start_auto_cleanup(self, interval=1):
+         def loop():
+             while True:
+                 time.sleep(interval)
+                 self.cleanup()
+         threading.Thread(target=loop, daemon=True).start()
+
+
+ class ExpiringDefaultDict(defaultdict):
+     """A defaultdict whose entries expire after a TTL"""
+     def __init__(self, default_factory=None, ttl=600, *args, **kwargs):
+         super().__init__(default_factory, *args, **kwargs)
+         self.ttl = ttl
+         self._timestamps = {}
+         self._lock = threading.Lock()
+
+     def __setitem__(self, key, value):
+         with self._lock:
+             super().__setitem__(key, value)
+             self._timestamps[key] = time.time()
+
+     def __getitem__(self, key):
+         with self._lock:
+             if key in self._timestamps and (time.time() - self._timestamps[key] > self.ttl):
+                 super().__delitem__(key)
+                 del self._timestamps[key]
+                 raise KeyError(f"{key} has expired")
+             # If the key is missing, default_factory is invoked here
+             val = super().__getitem__(key)
+             self._timestamps[key] = time.time()
+             return val
+
+     def get(self, key, default=None):
+         try:
+             return self.__getitem__(key)
+         except KeyError:
+             return default
+
+     def cleanup(self):
+         with self._lock:
+             now = time.time()
+             expired = [k for k, t in self._timestamps.items() if now - t > self.ttl]
+             for k in expired:
+                 super().__delitem__(k)
+                 del self._timestamps[k]
+
+     def start_auto_cleanup(self, interval=1):
+         def loop():
+             while True:
+                 time.sleep(interval)
+                 self.cleanup()
+         threading.Thread(target=loop, daemon=True).start()
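
Reviewer note: a minimal usage sketch of the new cache (assuming LazyCache.py is on the import path). Reads refresh a key's timestamp, and the daemon cleanup thread evicts keys that stay untouched past the TTL:

    import time
    from LazyCache import ExpiringDict

    cache = ExpiringDict(ttl=2)           # entries expire 2 s after the last access
    cache.start_auto_cleanup(interval=1)  # daemon thread evicts stale entries
    cache["session-a"] = "kv tensors"
    print(cache.get("session-a"))         # "kv tensors"; the read refreshes the TTL
    time.sleep(3)                         # no access, so the entry goes stale
    print(cache.get("session-a"))         # None: expired and evicted

ExpiringDefaultDict behaves the same way, except that a missing key is materialized by default_factory instead of falling back to the `get` default.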
MultiHeadAttention.py CHANGED
@@ -1,397 +1,406 @@
- import math
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from Affine import Affine
-
- # Build the relative-position matrix
- def get_relative_mat(height,width,k=0):
-     posi_i = np.arange(k,height+k) # row index range
-     posi_j = np.arange(0,width) # column index range
-     posi_grid = np.meshgrid(posi_i, posi_j, indexing='ij')
-     return abs(posi_grid[0]-posi_grid[1])
-
- # Relative-distance matrix used to add relative position information
- def get_relative_dist(i,j,block_size,i_end,j_end):
-     if block_size == 0:
-         assert i==0 and j==0 ,"i!=0 or j!=0"
-         return get_relative_mat(i_end,j_end,k=0)
-     # i, j: start offsets of the current block
-     # block_size: block size
-     # i_end, j_end: sequence lengths
-     height = block_size # span taken from the first sequence, equal to the block size
-     width = block_size * 3 # span taken from the second sequence; for longer context it also covers the previous and next blocks
-     # Build the standard mask that hides future information
-     # Larger i exposes more; for j it is the opposite; +block_size because the previous block is visible.
-     rela_dist = get_relative_mat(height,width,k=block_size+i-j)
-     # Handle out-of-bounds edges
-     # bottom overflow
-     down_out = max(0,i+height-i_end)
-     # left overflow
-     left_out = max(0,block_size-j)
-     # right overflow
-     right_out = max(0,j+block_size*2-j_end)
-     # Crop to within bounds
-     rela_dist = rela_dist[:height-down_out,left_out:width-right_out]
-     return rela_dist.astype(np.float32)
-
- # Mask used to add absolute position information
- def get_absolute_mask(i,j,block_size,i_end,j_end):
-     if block_size == 0:
-         assert i==0 and j==0 ,"i!=0 or j!=0"
-         return np.triu(np.ones((i_end,j_end),dtype='bool'), k=0)
-     # i, j: start offsets of the current block
-     # block_size: block size
-     # i_end, j_end: sequence lengths
-     height = block_size # span taken from the first sequence, equal to the block size
-     width = block_size * 3 # span taken from the second sequence; for longer context it also covers the previous and next blocks
-     # Build the standard mask that hides future information
-     # Larger i exposes more; for j it is the opposite; +block_size because the previous block is visible.
-     abs_mask = np.triu(np.ones((height,width),dtype='bool'), k=block_size+i-j)
-     # Handle out-of-bounds edges
-     # bottom overflow
-     down_out = max(0,i+height-i_end)
-     # left overflow
-     left_out = max(0,block_size-j)
-     # right overflow
-     right_out = max(0,j+block_size*2-j_end)
-     # Crop to within bounds
-     abs_mask = abs_mask[:height-down_out,left_out:width-right_out]
-     return abs_mask
-
- # Standard mask for hiding future information
- def get_std_mask(i,j,block_size,i_end,j_end):
-     if block_size == 0:
-         assert i==0 and j==0 ,"i!=0 or j!=0"
-         return np.triu(np.ones((i_end,j_end),dtype='bool'), k=1) == False
-     # i, j: start offsets of the current block
-     # block_size: block size
-     # i_end, j_end: sequence lengths
-     height = block_size # span taken from the first sequence, equal to the block size
-     width = block_size * 3 # span taken from the second sequence; for longer context it also covers the previous and next blocks
-     # Build the standard mask that hides future information
-     # Larger i exposes more; for j it is the opposite; +block_size because the previous block is visible.
-     std_mask = np.triu(np.ones((height,width),dtype='bool'), k=1+block_size+i-j)
-     # Handle out-of-bounds edges
-     # bottom overflow
-     down_out = max(0,i+height-i_end)
-     # left overflow
-     left_out = max(0,block_size-j)
-     # right overflow
-     right_out = max(0,j+block_size*2-j_end)
-     # Crop to within bounds
-     std_mask = std_mask[:height-down_out,left_out:width-right_out]
-     return std_mask == False
-
- # Build a key identifying a tensor that will be reused
- def ident(p_list):
-     i,j,block_size,i_end,j_end = p_list[1:]
-     ret = [p_list[0]]
-     if p_list[0]=='r' or p_list[0]=='a':
-         if block_size == 0:
-             ret += [i_end,j_end,0]
-         else:
-             height = block_size
-             width = block_size * 3
-             ret += [height,width,block_size+i-j]
-             down_out = max(0,i+height-i_end)
-             left_out = max(0,block_size-j)
-             right_out = max(0,j+block_size*2-j_end)
-             ret += [height-down_out,left_out,width-right_out]
-     else:
-         if block_size == 0:
-             ret += [i_end,j_end,1]
-         else:
-             height = block_size
-             width = block_size * 3
-             ret += [height,width,1+block_size+i-j]
-             down_out = max(0,i+height-i_end)
-             left_out = max(0,block_size-j)
-             right_out = max(0,j+block_size*2-j_end)
-             ret += [height-down_out,left_out,width-right_out]
-     return str(ret)
-
- # Cache dict and use counters
- reg_dict = dict()
- reg_timer = dict()
-
- # Check whether a key has not been registered yet
- def un_reg(p):
-     return not p in reg_dict
-
- # Register a tensor that will be reused
- def reg(p,v):
-     # Find the least-used entry in the cache
-     keys = [k for k in reg_dict]
-     time_min = 0
-     if len(keys) != 0:
-         key_min = keys[0]
-         time_min = reg_timer[key_min]
-         for k in keys:
-             if reg_timer[k]<time_min:
-                 key_min = k
-                 time_min = reg_timer[key_min]
-     # Count this use
-     if not p in reg_timer:
-         reg_timer[p] = 1
-     else:
-         reg_timer[p] += 1
-     # When the cache is full, evict the least-used entry
-     if len(keys) > 12:
-         del reg_dict[key_min]
-     # Keep the entry if its count exceeds the minimum
-     if reg_timer[p] > time_min or len(keys) < 12:
-         reg_dict[p] = v
-
- # Fetch a reusable tensor from the cache
- def get_reg(p):
-     reg_timer[p] += 1
-     return reg_dict[p]
-
-
- # Multi-head attention
- class MultiHeadAttention(nn.Module):
-     def __init__(self,embedding_dim,key_dim,head_number,position_information_type,enable_affine,enable_talking_head, \
-                  self_attention_block_size,dropout_rate,enable_el_cache):
-         super(MultiHeadAttention, self).__init__()
-         self.embedding_dim = embedding_dim
-         self.key_dim = key_dim
-         self.head_number = head_number
-         self.position_information_type = position_information_type
-         self.enable_talking_head = enable_talking_head
-         self.self_attention_block_size = self_attention_block_size
-         self.dropout_layer = nn.Dropout(p=dropout_rate)
-         self.enable_affine = enable_affine
-
-         self.query_w = nn.Linear(embedding_dim,key_dim*head_number,bias=False)
-         self.key_w = nn.Linear(embedding_dim,key_dim*head_number,bias=False)
-         self.value_w = nn.Linear(embedding_dim,key_dim*head_number,bias=False)
-         self.out_w = nn.Linear(key_dim*head_number,embedding_dim,bias=False)
-
-         self.enable_el_cache = enable_el_cache
-         self.kv_cache = None
-         self.temp = None
-         self.cnt = 0
-
-         if enable_affine == True:
-             self.query_a = Affine(1.0)
-             self.key_a = Affine(1.0)
-             self.value_a = Affine(1.0)
-             self.out_a = Affine(1.0)
-
-         if enable_talking_head == True:
-             self.talking_before_softmax = nn.Linear(head_number,head_number,bias=False)
-             self.talking_after_softmax = nn.Linear(head_number,head_number,bias=False)
-         else:
-             self.talking_before_softmax = None
-             self.talking_after_softmax = None
-
-         if position_information_type == "mask":
-             self.absolute_affine = Affine(1.0,grad_factor=1.0)
-             self.relative_affine = Affine(0.1,grad_factor=1.0)
-         else:
-             self.absolute_affine = None
-             self.relative_affine = None
-
-     # Attention computation
-     def attention(self, query, q_mask, key_value, mask_future):
-         # Parameter passing is reworked to support EL-Attention
-         absolute_affine = self.absolute_affine
-         relative_affine = self.relative_affine
-         talking_before_softmax = self.talking_before_softmax
-         talking_after_softmax = self.talking_after_softmax
-         block_size = self.self_attention_block_size
-         # Reshape q_mask up front for broadcasting
-         # query:[batch,head,query_len,emb_dim]
-         # q_mask:[batch,query_len]
-         # q_mask:[batch,query_len]->[batch,1,query_len]
-         # q_mask:[batch,1,query_len]->[batch,head,query_len]
-         q_mask = q_mask.unsqueeze(1).expand(*(query.size()[:-1]))
-         # Decide whether to compute in blocks
-         if block_size == 0:
-             # No blocking
-             # Compute the scores
-             scores = torch.matmul(query,key_value.transpose(-1,-2))
-             if self.enable_affine == True:
-                 scores = scores+self.temp
-             scores = scores/math.sqrt(self.key_dim)
-             # Optionally add relative position information
-             if relative_affine is not None:
-                 if self.enable_el_cache and query.size(-2) == 1:
-                     p = ident(['er',0,0,0,query.size(-2),key_value.size(-2)])
-                     if un_reg(p):
-                         rela_dist = np.arange(self.cnt,-1,-1).reshape(1,-1)
-                         rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
-                         reg(p,rela_dist)
-                     else:
-                         rela_dist = get_reg(p)
-                 else:
-                     p = ident(['r',0,0,0,query.size(-2),key_value.size(-2)])
-                     if un_reg(p):
-                         rela_dist = get_relative_dist(0,0,0,query.size(-2),key_value.size(-2))
-                         # Direct broadcasting is more efficient
-                         rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
-                         reg(p,rela_dist)
-                     else:
-                         rela_dist = get_reg(p)
-                 dist_decay= rela_dist.mul(relative_affine(1.0)).add(1.0).reciprocal()
-                 scores = scores.mul(dist_decay)
-             # Optionally add absolute position information
-             if absolute_affine is not None:
-                 if self.enable_el_cache and query.size(-2) == 1:
-                     p = ident(['ea',0,0,0,query.size(-2),key_value.size(-2)])
-                     if un_reg(p):
-                         abs_mask = np.array([[False]*(self.cnt)+[True]])
-                         abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
-                         reg(p,abs_mask)
-                     else:
-                         abs_mask = get_reg(p)
-                 else:
-                     p = ident(['a',0,0,0,query.size(-2),key_value.size(-2)])
-                     if un_reg(p):
-                         abs_mask = get_absolute_mask(0,0,0,query.size(-2),key_value.size(-2))
-                         # mask:[query_len,key_len]->[batch,head,query_len,key_len]
-                         abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
-                         reg(p,abs_mask)
-                     else:
-                         abs_mask = get_reg(p)
-                 abs_mask = abs_mask.expand(*(scores.size()))
-                 value_to_sub = absolute_affine(1.0)
-                 scores = torch.where(abs_mask == 0, scores - value_to_sub, scores)
-             # Talk before masking, for numerical stability
-             if talking_before_softmax is not None:
-                 scores = talking_before_softmax(scores.transpose(-1,-3)).transpose(-1,-3)
-             # Mask future information if required
-             if mask_future == True:
-                 p = ident(['f',0,0,0,query.size(-2),key_value.size(-2)])
-                 if un_reg(p):
-                     # Build the mask that hides future information
-                     # mask:[query_len,key_len]->[batch,head,query_len,key_len]
-                     std_mask = get_std_mask(0,0,0,query.size(-2),key_value.size(-2))
-                     std_mask = torch.from_numpy(std_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
-                     reg(p,std_mask)
-                 else:
-                     std_mask = get_reg(p)
-                 std_mask = std_mask.expand(*(scores.size()))
-                 # q_mask:[batch,head,query_len]->[batch,head,query_len,key_len]
-                 std_mask = q_mask.unsqueeze_(-1).expand(*(std_mask.size())) & std_mask
-                 scores.masked_fill_(std_mask == 0.0,-1e3)
-             # Compute the probability weights
-             p_attn = F.softmax(scores, dim = -1)
-             # Talk on the weights
-             if talking_after_softmax is not None:
-                 p_attn = talking_after_softmax(p_attn.transpose(-1,-3)).transpose(-1,-3)
-             if self.enable_affine:
-                 temp = p_attn.sum(dim=-1,keepdim=True)*self.value_a.bias*self.value_a.grad_factor
-             # Weighted sum
-             ret = torch.matmul(p_attn, key_value)
-         else:
-             # With blocking, allocate space for the final result
-             ret = torch.zeros_like(query)
-             temp = torch.zeros_like(query[...,:1])
-             # Blockwise loop
-             for i in range(0,query.size(-2),block_size):
-                 # Slice out the block
-                 query_block = query[...,i:i+block_size,:]
-                 q_mask_block = q_mask[...,i:i+block_size]
-                 key_value_block = key_value[...,max(0,i-block_size):i+block_size*2,:]
-                 # Compute the scores
-                 scores = torch.matmul(query_block,key_value_block.transpose(-1,-2))
-                 if self.enable_affine == True:
-                     scores = scores+self.temp[:,:,i:i+block_size]
-                 scores = scores/math.sqrt(self.key_dim)
-                 # Optionally add relative position information
-                 if relative_affine is not None:
-                     p = ident(['r',i,i,block_size,query.size(-2),key_value.size(-2)])
-                     if un_reg(p):
-                         rela_dist = get_relative_dist(i,i,block_size,query.size(-2),key_value.size(-2))
-                         rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
-                         reg(p,rela_dist)
-                     else:
-                         rela_dist = get_reg(p)
-                     # dist_decay= 1.0 / (1 + rela_dist*relative_affine(1.0))
-                     dist_decay= rela_dist.mul(relative_affine(1.0)).add(1.0).reciprocal()
-                     scores = scores.mul(dist_decay)
-
-                 # Optionally add absolute position information
-                 if absolute_affine is not None:
-                     p = ident(['a',i,i,block_size,query.size(-2),key_value.size(-2)])
-                     if un_reg(p):
-                         abs_mask = get_absolute_mask(i,i,block_size,query.size(-2),key_value.size(-2))
-                         abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
-                         reg(p,abs_mask)
-                     else:
-                         abs_mask = get_reg(p)
-                     abs_mask = abs_mask.expand(*(scores.size()))
-                     value_to_sub = absolute_affine(1.0)
-                     scores = torch.where(abs_mask == 0, scores - value_to_sub, scores)
-
-                 # Talk before masking, for numerical stability
-                 if talking_before_softmax is not None:
-                     scores = talking_before_softmax(scores.transpose(-1,-3)).transpose(-1,-3)
-
-                 # Mask future information if required
-                 if mask_future == True:
-                     p = ident(['f',i,i,block_size,query.size(-2),key_value.size(-2)])
-                     if un_reg(p):
-                         # Build the future mask; batched operation needs extra dims
-                         std_mask = get_std_mask(i,i,block_size,query.size(-2),key_value.size(-2))
-                         std_mask = torch.from_numpy(std_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
-                         reg(p,std_mask)
-                     else:
-                         std_mask = get_reg(p)
-                     std_mask = std_mask.expand(*(scores.size()))
-                     std_mask = q_mask_block.unsqueeze(-1).expand(*(std_mask.size())) & std_mask
-                     scores.masked_fill_(std_mask == 0.0,-1e3)
-
-                 # Compute the probability weights
-                 p_attn = F.softmax(scores, dim = -1)
-
-                 # Talk on the weights
-                 if talking_after_softmax is not None:
-                     p_attn = talking_after_softmax(p_attn.transpose(-1,-3)).transpose(-1,-3)
-                 if self.enable_affine:
-                     temp[...,i:i+block_size,:] = p_attn.sum(dim=-1,keepdim=True)*self.value_a.bias*self.value_a.grad_factor
-                 # Weighted sum
-                 ret[...,i:i+block_size,:] = torch.matmul(p_attn, key_value_block)
-         if self.enable_affine:
-             ret = ret * self.value_a.value * self.value_a.grad_factor
-         ret = torch.matmul(ret,self.value_w.weight.view(self.head_number,self.key_dim,self.embedding_dim).transpose(1,2)) + temp
-         return ret
-
-     def forward(self, query, q_mask, key_value, mask_future):
-         # EL-Attention path
-         if self.enable_el_cache:
-             if query.size(-2) > 1:
-                 self.cnt = query.size(-2) - 1
-                 self.kv_cache = key_value
-             else:
-                 self.cnt += 1
-                 self.kv_cache = torch.cat((self.kv_cache,key_value),1)
-             key_value = self.kv_cache
-             mask_future = False
-         # Linear projections produce the actual Q/K/V
-         query = self.query_w(query)
-         batch_size = query.size(0)
-         query = query.view(batch_size, -1, self.head_number, self.key_dim).transpose(1,2)
-         # Affine transform to speed up training
-         if self.enable_affine == True:
-             query = self.query_a(query)
-             self.temp = query.sum(dim=-1,keepdim=True)*self.key_a.bias*self.key_a.grad_factor
-             query = query*self.key_a.value*self.key_a.grad_factor
-         # Split into attention heads
-         query = torch.matmul(query,self.key_w.weight.view(self.head_number, self.key_dim, self.embedding_dim))
-         key_value = key_value.view(batch_size,-1,1,self.embedding_dim).transpose(1,2)
-         # query:[batch,head,seq_len,emd_dim]
-         # key_value:[batch,1,seq_len,emd_dim]
-         # Compute multi-head attention
-         out = self.attention(query, q_mask, key_value, mask_future)
-         self.temp = None
-         # Concatenate the attended heads back together
-         out = out.transpose(1,2).contiguous().view(batch_size, -1, self.head_number * self.key_dim)
-         if self.enable_affine:
-             return self.dropout_layer(self.out_a(self.out_w(out)))
-         else:
-             return self.dropout_layer(self.out_w(out))
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from LazyCache import ExpiringDict, ExpiringDefaultDict
+ from Affine import Affine
+
+ # Build the relative-position matrix
+ def get_relative_mat(height,width,k=0):
+     posi_i = np.arange(k,height+k) # row index range
+     posi_j = np.arange(0,width) # column index range
+     posi_grid = np.meshgrid(posi_i, posi_j, indexing='ij')
+     return abs(posi_grid[0]-posi_grid[1])
+
+ # Relative-distance matrix used to add relative position information
+ def get_relative_dist(i,j,block_size,i_end,j_end):
+     if block_size == 0:
+         assert i==0 and j==0 ,"i!=0 or j!=0"
+         return get_relative_mat(i_end,j_end,k=0)
+     # i, j: start offsets of the current block
+     # block_size: block size
+     # i_end, j_end: sequence lengths
+     height = block_size # span taken from the first sequence, equal to the block size
+     width = block_size * 3 # span taken from the second sequence; for longer context it also covers the previous and next blocks
+     # Build the standard mask that hides future information
+     # Larger i exposes more; for j it is the opposite; +block_size because the previous block is visible.
+     rela_dist = get_relative_mat(height,width,k=block_size+i-j)
+     # Handle out-of-bounds edges
+     # bottom overflow
+     down_out = max(0,i+height-i_end)
+     # left overflow
+     left_out = max(0,block_size-j)
+     # right overflow
+     right_out = max(0,j+block_size*2-j_end)
+     # Crop to within bounds
+     rela_dist = rela_dist[:height-down_out,left_out:width-right_out]
+     return rela_dist.astype(np.float32)
+
+ # Mask used to add absolute position information
+ def get_absolute_mask(i,j,block_size,i_end,j_end):
+     if block_size == 0:
+         assert i==0 and j==0 ,"i!=0 or j!=0"
+         return np.triu(np.ones((i_end,j_end),dtype='bool'), k=0)
+     # i, j: start offsets of the current block
+     # block_size: block size
+     # i_end, j_end: sequence lengths
+     height = block_size # span taken from the first sequence, equal to the block size
+     width = block_size * 3 # span taken from the second sequence; for longer context it also covers the previous and next blocks
+     # Build the standard mask that hides future information
+     # Larger i exposes more; for j it is the opposite; +block_size because the previous block is visible.
+     abs_mask = np.triu(np.ones((height,width),dtype='bool'), k=block_size+i-j)
+     # Handle out-of-bounds edges
+     # bottom overflow
+     down_out = max(0,i+height-i_end)
+     # left overflow
+     left_out = max(0,block_size-j)
+     # right overflow
+     right_out = max(0,j+block_size*2-j_end)
+     # Crop to within bounds
+     abs_mask = abs_mask[:height-down_out,left_out:width-right_out]
+     return abs_mask
+
+ # Standard mask for hiding future information
+ def get_std_mask(i,j,block_size,i_end,j_end):
+     if block_size == 0:
+         assert i==0 and j==0 ,"i!=0 or j!=0"
+         return np.triu(np.ones((i_end,j_end),dtype='bool'), k=1) == False
+     # i, j: start offsets of the current block
+     # block_size: block size
+     # i_end, j_end: sequence lengths
+     height = block_size # span taken from the first sequence, equal to the block size
+     width = block_size * 3 # span taken from the second sequence; for longer context it also covers the previous and next blocks
+     # Build the standard mask that hides future information
+     # Larger i exposes more; for j it is the opposite; +block_size because the previous block is visible.
+     std_mask = np.triu(np.ones((height,width),dtype='bool'), k=1+block_size+i-j)
+     # Handle out-of-bounds edges
+     # bottom overflow
+     down_out = max(0,i+height-i_end)
+     # left overflow
+     left_out = max(0,block_size-j)
+     # right overflow
+     right_out = max(0,j+block_size*2-j_end)
+     # Crop to within bounds
+     std_mask = std_mask[:height-down_out,left_out:width-right_out]
+     return std_mask == False
+
+ # Build a key identifying a tensor that will be reused
+ def ident(p_list):
+     i,j,block_size,i_end,j_end = p_list[1:]
+     ret = [p_list[0]]
+     if p_list[0]=='r' or p_list[0]=='a':
+         if block_size == 0:
+             ret += [i_end,j_end,0]
+         else:
+             height = block_size
+             width = block_size * 3
+             ret += [height,width,block_size+i-j]
+             down_out = max(0,i+height-i_end)
+             left_out = max(0,block_size-j)
+             right_out = max(0,j+block_size*2-j_end)
+             ret += [height-down_out,left_out,width-right_out]
+     else:
+         if block_size == 0:
+             ret += [i_end,j_end,1]
+         else:
+             height = block_size
+             width = block_size * 3
+             ret += [height,width,1+block_size+i-j]
+             down_out = max(0,i+height-i_end)
+             left_out = max(0,block_size-j)
+             right_out = max(0,j+block_size*2-j_end)
+             ret += [height-down_out,left_out,width-right_out]
+     return str(ret)
+
+ # Cache dict and use counters
+ reg_dict = dict()
+ reg_timer = dict()
+
+ # Check whether a key has not been registered yet
+ def un_reg(p):
+     return not p in reg_dict
+
+ # Register a tensor that will be reused
+ def reg(p,v):
+     # Find the least-used entry in the cache
+     keys = [k for k in reg_dict]
+     time_min = 0
+     if len(keys) != 0:
+         key_min = keys[0]
+         time_min = reg_timer[key_min]
+         for k in keys:
+             if reg_timer[k]<time_min:
+                 key_min = k
+                 time_min = reg_timer[key_min]
+     # Count this use
+     if not p in reg_timer:
+         reg_timer[p] = 1
+     else:
+         reg_timer[p] += 1
+     # When the cache is full, evict the least-used entry
+     if len(keys) > 12:
+         del reg_dict[key_min]
+     # Keep the entry if its count exceeds the minimum
+     if reg_timer[p] > time_min or len(keys) < 12:
+         reg_dict[p] = v
+
+ # Fetch a reusable tensor from the cache
+ def get_reg(p):
+     reg_timer[p] += 1
+     return reg_dict[p]
+
+
+ # Multi-head attention
+ class MultiHeadAttention(nn.Module):
+     def __init__(self,embedding_dim,key_dim,head_number,position_information_type,enable_affine,enable_talking_head, \
+                  self_attention_block_size,dropout_rate,enable_el_cache):
+         super(MultiHeadAttention, self).__init__()
+         self.embedding_dim = embedding_dim
+         self.key_dim = key_dim
+         self.head_number = head_number
+         self.position_information_type = position_information_type
+         self.enable_talking_head = enable_talking_head
+         self.self_attention_block_size = self_attention_block_size
+         self.dropout_layer = nn.Dropout(p=dropout_rate)
+         self.enable_affine = enable_affine
+
+         self.query_w = nn.Linear(embedding_dim,key_dim*head_number,bias=False)
+         self.key_w = nn.Linear(embedding_dim,key_dim*head_number,bias=False)
+         self.value_w = nn.Linear(embedding_dim,key_dim*head_number,bias=False)
+         self.out_w = nn.Linear(key_dim*head_number,embedding_dim,bias=False)
+
+         self.enable_el_cache = enable_el_cache
+         # Dicts with automatic garbage collection (created lazily in forward)
+         self.kv_cache = None
+         self.temp = None
+         self.cnt = None
+
+         if enable_affine == True:
+             self.query_a = Affine(1.0)
+             self.key_a = Affine(1.0)
+             self.value_a = Affine(1.0)
+             self.out_a = Affine(1.0)
+
+         if enable_talking_head == True:
+             self.talking_before_softmax = nn.Linear(head_number,head_number,bias=False)
+             self.talking_after_softmax = nn.Linear(head_number,head_number,bias=False)
+         else:
+             self.talking_before_softmax = None
+             self.talking_after_softmax = None
+
+         if position_information_type == "mask":
+             self.absolute_affine = Affine(1.0,grad_factor=1.0)
+             self.relative_affine = Affine(0.1,grad_factor=1.0)
+         else:
+             self.absolute_affine = None
+             self.relative_affine = None
+
+     # Attention computation
+     def attention(self, query, q_mask, key_value, mask_future, session_id):
+         # Parameter passing is reworked to support EL-Attention
+         absolute_affine = self.absolute_affine
+         relative_affine = self.relative_affine
+         talking_before_softmax = self.talking_before_softmax
+         talking_after_softmax = self.talking_after_softmax
+         block_size = self.self_attention_block_size
+         # Reshape q_mask up front for broadcasting
+         # query:[batch,head,query_len,emb_dim]
+         # q_mask:[batch,query_len]
+         # q_mask:[batch,query_len]->[batch,1,query_len]
+         # q_mask:[batch,1,query_len]->[batch,head,query_len]
+         q_mask = q_mask.unsqueeze(1).expand(*(query.size()[:-1]))
+         # Decide whether to compute in blocks
+         if block_size == 0:
+             # No blocking
+             # Compute the scores
+             scores = torch.matmul(query,key_value.transpose(-1,-2))
+             if self.enable_affine == True:
+                 scores = scores+self.temp[session_id]
+             scores = scores/math.sqrt(self.key_dim)
+             # Optionally add relative position information
+             if relative_affine is not None:
+                 if self.enable_el_cache and query.size(-2) == 1:
+                     p = ident(['er',0,0,0,query.size(-2),key_value.size(-2)])
+                     if un_reg(p):
+                         rela_dist = np.arange(self.cnt[session_id],-1,-1).reshape(1,-1)
+                         rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
+                         reg(p,rela_dist)
+                     else:
+                         rela_dist = get_reg(p)
+                 else:
+                     p = ident(['r',0,0,0,query.size(-2),key_value.size(-2)])
+                     if un_reg(p):
+                         rela_dist = get_relative_dist(0,0,0,query.size(-2),key_value.size(-2))
+                         # Direct broadcasting is more efficient
+                         rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
+                         reg(p,rela_dist)
+                     else:
+                         rela_dist = get_reg(p)
+                 dist_decay= rela_dist.mul(relative_affine(1.0)).add(1.0).reciprocal()
+                 scores = scores.mul(dist_decay)
+             # Optionally add absolute position information
+             if absolute_affine is not None:
+                 if self.enable_el_cache and query.size(-2) == 1:
+                     p = ident(['ea',0,0,0,query.size(-2),key_value.size(-2)])
+                     if un_reg(p):
+                         abs_mask = np.array([[False]*(self.cnt[session_id])+[True]])
+                         abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
+                         reg(p,abs_mask)
+                     else:
+                         abs_mask = get_reg(p)
+                 else:
+                     p = ident(['a',0,0,0,query.size(-2),key_value.size(-2)])
+                     if un_reg(p):
+                         abs_mask = get_absolute_mask(0,0,0,query.size(-2),key_value.size(-2))
+                         # mask:[query_len,key_len]->[batch,head,query_len,key_len]
+                         abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
+                         reg(p,abs_mask)
+                     else:
+                         abs_mask = get_reg(p)
+                 abs_mask = abs_mask.expand(*(scores.size()))
+                 value_to_sub = absolute_affine(1.0)
+                 scores = torch.where(abs_mask == 0, scores - value_to_sub, scores)
+             # Talk before masking, for numerical stability
+             if talking_before_softmax is not None:
+                 scores = talking_before_softmax(scores.transpose(-1,-3)).transpose(-1,-3)
+             # Mask future information if required
+             if mask_future == True:
+                 p = ident(['f',0,0,0,query.size(-2),key_value.size(-2)])
+                 if un_reg(p):
+                     # Build the mask that hides future information
+                     # mask:[query_len,key_len]->[batch,head,query_len,key_len]
+                     std_mask = get_std_mask(0,0,0,query.size(-2),key_value.size(-2))
+                     std_mask = torch.from_numpy(std_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
+                     reg(p,std_mask)
+                 else:
+                     std_mask = get_reg(p)
+                 std_mask = std_mask.expand(*(scores.size()))
+                 # q_mask:[batch,head,query_len]->[batch,head,query_len,key_len]
+                 std_mask = q_mask.unsqueeze_(-1).expand(*(std_mask.size())) & std_mask
+                 scores.masked_fill_(std_mask == 0.0,-1e3)
+             # Compute the probability weights
+             p_attn = F.softmax(scores, dim = -1)
+             # Talk on the weights
+             if talking_after_softmax is not None:
+                 p_attn = talking_after_softmax(p_attn.transpose(-1,-3)).transpose(-1,-3)
+             if self.enable_affine:
+                 temp = p_attn.sum(dim=-1,keepdim=True)*self.value_a.bias*self.value_a.grad_factor
+             # Weighted sum
+             ret = torch.matmul(p_attn, key_value)
+         else:
+             # With blocking, allocate space for the final result
+             ret = torch.zeros_like(query)
+             temp = torch.zeros_like(query[...,:1])
+             # Blockwise loop
+             for i in range(0,query.size(-2),block_size):
+                 # Slice out the block
+                 query_block = query[...,i:i+block_size,:]
+                 q_mask_block = q_mask[...,i:i+block_size]
+                 key_value_block = key_value[...,max(0,i-block_size):i+block_size*2,:]
+                 # Compute the scores
+                 scores = torch.matmul(query_block,key_value_block.transpose(-1,-2))
+                 if self.enable_affine == True:
+                     scores = scores+self.temp[session_id][:,:,i:i+block_size]
+                 scores = scores/math.sqrt(self.key_dim)
+                 # Optionally add relative position information
+                 if relative_affine is not None:
+                     p = ident(['r',i,i,block_size,query.size(-2),key_value.size(-2)])
+                     if un_reg(p):
+                         rela_dist = get_relative_dist(i,i,block_size,query.size(-2),key_value.size(-2))
+                         rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
+                         reg(p,rela_dist)
+                     else:
+                         rela_dist = get_reg(p)
+                     # dist_decay= 1.0 / (1 + rela_dist*relative_affine(1.0))
+                     dist_decay= rela_dist.mul(relative_affine(1.0)).add(1.0).reciprocal()
+                     scores = scores.mul(dist_decay)
+
+                 # Optionally add absolute position information
+                 if absolute_affine is not None:
+                     p = ident(['a',i,i,block_size,query.size(-2),key_value.size(-2)])
+                     if un_reg(p):
+                         abs_mask = get_absolute_mask(i,i,block_size,query.size(-2),key_value.size(-2))
+                         abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
+                         reg(p,abs_mask)
+                     else:
+                         abs_mask = get_reg(p)
+                     abs_mask = abs_mask.expand(*(scores.size()))
+                     value_to_sub = absolute_affine(1.0)
+                     scores = torch.where(abs_mask == 0, scores - value_to_sub, scores)
+
+                 # Talk before masking, for numerical stability
+                 if talking_before_softmax is not None:
+                     scores = talking_before_softmax(scores.transpose(-1,-3)).transpose(-1,-3)
+
+                 # Mask future information if required
+                 if mask_future == True:
+                     p = ident(['f',i,i,block_size,query.size(-2),key_value.size(-2)])
+                     if un_reg(p):
+                         # Build the future mask; batched operation needs extra dims
+                         std_mask = get_std_mask(i,i,block_size,query.size(-2),key_value.size(-2))
+                         std_mask = torch.from_numpy(std_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
+                         reg(p,std_mask)
+                     else:
+                         std_mask = get_reg(p)
+                     std_mask = std_mask.expand(*(scores.size()))
+                     std_mask = q_mask_block.unsqueeze(-1).expand(*(std_mask.size())) & std_mask
+                     scores.masked_fill_(std_mask == 0.0,-1e3)
+
+                 # Compute the probability weights
+                 p_attn = F.softmax(scores, dim = -1)
+
+                 # Talk on the weights
+                 if talking_after_softmax is not None:
+                     p_attn = talking_after_softmax(p_attn.transpose(-1,-3)).transpose(-1,-3)
+                 if self.enable_affine:
+                     temp[...,i:i+block_size,:] = p_attn.sum(dim=-1,keepdim=True)*self.value_a.bias*self.value_a.grad_factor
+                 # Weighted sum
+                 ret[...,i:i+block_size,:] = torch.matmul(p_attn, key_value_block)
+         if self.enable_affine:
+             ret = ret * self.value_a.value * self.value_a.grad_factor
+         ret = torch.matmul(ret,self.value_w.weight.view(self.head_number,self.key_dim,self.embedding_dim).transpose(1,2)) + temp
+         return ret
+
+     def forward(self, query, q_mask, key_value, mask_future, session_id):
+         # EL-Attention path
+         if self.enable_el_cache:
+             if self.kv_cache is None:
+                 self.kv_cache = ExpiringDict(ttl=600)
+                 self.kv_cache.start_auto_cleanup()
+                 self.temp = ExpiringDict(ttl=600)
+                 self.temp.start_auto_cleanup()
+                 self.cnt = ExpiringDefaultDict(int, ttl=600)
+                 self.cnt.start_auto_cleanup()
+             if query.size(-2) > 1:
+                 self.cnt[session_id] = query.size(-2) - 1
+                 self.kv_cache[session_id] = key_value
+             else:
+                 self.cnt[session_id] += 1
+                 self.kv_cache[session_id] = torch.cat((self.kv_cache[session_id],key_value),1)
+             key_value = self.kv_cache[session_id]
+             mask_future = False
+         # Linear projections produce the actual Q/K/V
+         query = self.query_w(query)
+         batch_size = query.size(0)
+         query = query.view(batch_size, -1, self.head_number, self.key_dim).transpose(1,2)
+         # Affine transform to speed up training
+         if self.enable_affine == True:
+             query = self.query_a(query)
+             self.temp[session_id] = query.sum(dim=-1,keepdim=True)*self.key_a.bias*self.key_a.grad_factor
+             query = query*self.key_a.value*self.key_a.grad_factor
+         # Split into attention heads
+         query = torch.matmul(query,self.key_w.weight.view(self.head_number, self.key_dim, self.embedding_dim))
+         key_value = key_value.view(batch_size,-1,1,self.embedding_dim).transpose(1,2)
+         # query:[batch,head,seq_len,emd_dim]
+         # key_value:[batch,1,seq_len,emd_dim]
+         # Compute multi-head attention
+         out = self.attention(query, q_mask, key_value, mask_future, session_id)
+         self.temp[session_id] = None
+         # Concatenate the attended heads back together
+         out = out.transpose(1,2).contiguous().view(batch_size, -1, self.head_number * self.key_dim)
+         if self.enable_affine:
+             return self.dropout_layer(self.out_a(self.out_w(out)))
+         else:
+             return self.dropout_layer(self.out_w(out))
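
Reviewer note: the heart of this commit is replacing the single self.kv_cache tensor (one slot shared by every request) with TTL dicts keyed by session_id. A toy sketch of the original failure mode and the fix, using plain lists in place of KV tensors:

    # Hypothetical toy classes, not the project's modules.
    class BuggyCache:
        def __init__(self):
            self.kv = None                 # one slot shared by every caller
        def step(self, session_id, token):
            self.kv = (self.kv or []) + [token]   # other users' tokens corrupt the stream
            return self.kv

    class SessionCache:
        def __init__(self):
            self.kv = {}                   # session_id -> that session's own cache
        def step(self, session_id, token):
            self.kv.setdefault(session_id, []).append(token)
            return self.kv[session_id]

    buggy, fixed = BuggyCache(), SessionCache()
    for sid, tok in [("A", "a1"), ("B", "b1"), ("A", "a2")]:
        buggy.step(sid, tok)
        fixed.step(sid, tok)
    print(buggy.kv)        # ['a1', 'b1', 'a2']: two sessions interleaved
    print(fixed.kv["A"])   # ['a1', 'a2']: isolated per session

The TTL on ExpiringDict is what makes this safe long term: abandoned sessions are garbage-collected instead of pinning their KV tensors forever.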
 
app.py CHANGED
@@ -1,221 +1,295 @@
- import time
- import uuid
- import html
- import threading
- import numpy as np
- import gradio as gr
- from queue import Queue
- from tokenizer import tokenizer,vocab_size,token2str
-
- import torch
- import torch.nn as nn
- from make_model import make_model
- from train_and_use import El_text_continue_stream
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- model = make_model(
-     # Tokens start at 1; 0 is padding; the remaining ids cover all byte values
-     vocab_size = vocab_size+1+255,
-     embedding_dim = 768,
-     key_dim = 128,
-     head_number = 12,
-     position_information_type = "mask",
-     enable_affine = True,
-     enable_talking_head = True,
-     use_diff = False,
-     self_attention_block_size = 0,
-     feed_forward_dim = 1536,
-     enable_layer_norm = True,
-     deep = 12,
-     dropout_rate = 0.1,
-     enable_el_cache = True
- ).to(device)
- model.load_state_dict(torch.load('large_model_instruct_27567.weight',map_location=device,weights_only=True))
- model = model.eval()
-
- # Global dicts holding per-session objects/state that cannot be deepcopied
- user_queues = {} # session_id -> Queue()
- user_stop_flags = {} # session_id -> bool (True means stopped)
- user_current_sessions = {} # session_id -> latest session (list), deepcopy-friendly
-
- # Token wrapper
- def token_wapper(token):
-     return f'<span style="background-color: #FFD580; padding: 2px 4px; border-radius: 6px; margin: 1px; display: inline-block;">{html.escape(token)}</span>'
-
- def token_split_wapper(token):
-     safe_token = html.escape(token)
-     text = f"({safe_token})[单字多token]"
-     return f'<span style="background-color: #FF0000; padding: 2px 4px; border-radius: 6px; margin: 1px; display: inline-block;">{text}</span>'
-
- # Background generation worker (locates per-user state in the global dicts via session_id)
- def generate_text(user_message, session_id, temperature, repeat_penalty, max_length, decay):
-     out = ""
-     q = user_queues.get(session_id)
-     if q is None:
-         return
-     # Tokenize the user message
-     user_tokens = tokenizer(user_message,5.0)
-     # Decode tokens back to text and wrap them for display
-     words = []
-     temp = []
-     for token in user_tokens:
-         if token > 0:
-             if len(temp):
-                 words += [token_split_wapper(token2str(temp))]
-                 temp = []
-             words += [token_wapper(token2str([token]))]
-         else:
-             temp += [token]
-     if len(temp):
-         words += [token_split_wapper(token2str(temp))]
-     user_tokens = ''.join(words)
-     # Prepare the model input
-     if len(tokenizer(user_message,5.0)) < 2:
-         user_message = f' {user_message}'
-     tokens_batch = [tokenizer('<|im_start|>user '+user_message+'<|im_end|><|im_start|>assistant ',5.0)]
-     tokens_batch = np.array(tokens_batch,dtype=np.int64)+255
-     inputs = torch.from_numpy(tokens_batch).to(device).data
-     last_len = -1
-     # Model output
-     with torch.no_grad():
-         for o in El_text_continue_stream(
-             model,inputs,out_length=max_length,
-             repeat_penalty_value=repeat_penalty,
-             temperature=temperature,
-             decay=decay
-         ):
-             split = ''
-             if o[0,-1] > 255: # only emit once the character is complete
-                 out += token2str(o[0][last_len:].cpu().numpy()-255,split=split)
-                 last_len = -1
-                 sess = [
-                     {"role": "user", "content": user_tokens},
-                     {"role": "assistant", "content": out},
-                 ]
-                 user_current_sessions[session_id] = sess
-                 try:
-                     q.put(sess, block=False)
-                 except:
-                     # Rare: enqueue failed (should not normally happen); ignore
-                     pass
-             else:
-                 last_len -= 1
-             if user_stop_flags.get(session_id, True):
-                 break
-             if '<|im_end|>' in out:
-                 out = out.split('<|im_end|>')[0]
-                 sess = [
-                     {"role": "user", "content": user_tokens},
-                     {"role": "assistant", "content": out},
-                 ]
-                 user_current_sessions[session_id] = sess
-                 try:
-                     q.put(sess, block=False)
-                 except:
-                     # Rare: enqueue failed (should not normally happen); ignore
-                     pass
-                 break
-
- # Button-click handling: start / stop / clear
- def click_process(sess, label, user_message, state, stop_flag_state, session_id, temperature, repeat_penalty, max_length, decay):
-     # Sanity checks
-     if session_id is None or session_id not in user_queues:
-         # The session is not initialized yet; return without changing the UI
-         return "", "发送消息", state or {"current_session": []}, stop_flag_state or {"stop": True}, session_id
-
-     # Currently "stopped" and there is user input -> start a generation thread
-     if stop_flag_state.get("stop", True) and user_message and sess == []:
-         user_stop_flags[session_id] = False
-         thread = threading.Thread(target=generate_text, args=(user_message, session_id, temperature, repeat_penalty, max_length, decay))
-         thread.daemon = True
-         thread.start()
-         # Hand the updated state/stop_flag back to the frontend (gradio stores them in session state)
-         return "", "终止输出", {"current_session": user_current_sessions.get(session_id, [])}, {"stop": False}, session_id
-
-     # Currently generating -> stop
-     elif not stop_flag_state.get("stop", True) and label != "清空会话":
-         user_stop_flags[session_id] = True
-         return user_message, "清空会话", {"current_session": user_current_sessions.get(session_id, [])}, {"stop": True}, session_id
-
-     # Otherwise clear the session
-     else:
-         user_stop_flags[session_id] = True
-         user_current_sessions[session_id] = []
-         q = user_queues.get(session_id)
-         if q:
-             while not q.empty():
-                 try:
-                     q.get_nowait()
-                 except:
-                     break
-         return user_message, "发送消息", {"current_session": []}, {"stop": True}, session_id
-
- # Streaming-output generator (triggered once, then runs forever)
- def stream_output(state, stop_flag_state):
-     global user_queues,user_stop_flags,user_current_sessions
-     # Initialize the session on page load (returns deepcopy-friendly state values and the session_id)
-     session_id = str(uuid.uuid4())
-     user_queues[session_id] = Queue()
-     user_stop_flags[session_id] = True
-     user_current_sessions[session_id] = [] # start with an empty session
-     # State values returned to gradio (all deepcopy-friendly)
-     yield gr.update(), gr.update(), {"current_session": []}, {"stop": True}, session_id
-     t0 = time.time()
-     while True:
-         q = user_queues[session_id]
-         stopped = user_stop_flags.get(session_id, True)
-         # Drain queued messages first (FIFO)
-         if (not stopped) and (not q.empty()):
-             t0 = time.time()
-             while q.qsize() > 5:
-                 sess = q.get()
-             sess = q.get()
-             # Update the chatbot (the returned sess is [{"role":...}, ...])
-             # Also return the state as a deepcopy-friendly dict (gr.State must be deepcopy-able)
-             yield sess, "终止输出", {"current_session": sess}, gr.update(), session_id
-         else:
-             last = user_current_sessions.get(session_id, [])
-             if last == []:
-                 yield last, gr.update(), {"current_session": last},gr.update(), session_id
-             else:
-                 if time.time() - t0 > 3:
-                     yield last, "清空会话", {"current_session": last},gr.update(), session_id
-         time.sleep(0.1) # avoid busy-waiting the CPU
-
- # ========== Gradio UI ==========
- with gr.Blocks() as demo:
-     gr.Markdown("# LLM 在线体验(指令微调版)")
-
-     chatbot = gr.Chatbot(type="messages", label="输入/输出", autoscroll=False, show_copy_button=False)
-     msg = gr.Textbox(placeholder="请输入问题。", label="用户问题输入", lines=4)
-
-     with gr.Row():
-         temperature = gr.Slider(0.0001, 3.0001, value=0.0001, step=0.1, label="Temperature")
-         repeat_penalty = gr.Slider(0.0, 5.0, value=2.5, step=0.1, label="Repeat Penalty")
-         max_length = gr.Slider(64, 8192, value=512, step=64, label="Max Length")
-         decay = gr.Slider(0.90, 1.0, value=0.98, step=0.01, label="Repeat Penalty Decay Rate")
-
-     btn = gr.Button("发送消息")
-
-     # gr.State holds deepcopy-friendly session values on the frontend
-     state = gr.State()
-     stop_flag_state = gr.State()
-     session_id = gr.State()
-
-     # Button-click handling: locate the user's resources via session_id
-     btn.click(
-         click_process,
-         inputs=[chatbot, btn, msg, state, stop_flag_state, session_id, temperature, repeat_penalty, max_length, decay],
-         outputs=[msg, btn, state, stop_flag_state, session_id],
-     )
-
-     # Trigger stream_output after page load (once triggered, the generator keeps running)
-     demo.load(
-         stream_output,
-         inputs=[state, stop_flag_state],
-         outputs=[chatbot, btn, state, stop_flag_state, session_id],
-     )
-
- if __name__ == "__main__":
-     demo.queue(max_size=128, default_concurrency_limit=128)
-     demo.launch(share=False)

1
+ # 公开库
2
+ import time
3
+ import html
4
+ import uuid
5
+ import torch
6
+ import threading
7
+ import numpy as np
8
+ import gradio as gr
9
+ # 私有库
10
+ from queue import Queue
11
+ from make_model import make_model
12
+ from LazyCache import ExpiringDict
13
+ from train_and_use import El_text_continue_stream
14
+ from tokenizer import tokenizer,vocab_size,token2str
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ # 加载模型
18
+ model = make_model(
19
+ #token是从1开始的,0填充,剩下的用来覆盖全部字节
20
+ vocab_size = vocab_size+1+255,
21
+ embedding_dim = 768,
22
+ key_dim = 128,
23
+ head_number = 12,
24
+ position_information_type = "mask",
25
+ enable_affine = True,
26
+ enable_talking_head = True,
27
+ use_diff = False,
28
+ self_attention_block_size = 0,
29
+ feed_forward_dim = 1536,
30
+ enable_layer_norm = True,
31
+ deep = 12,
32
+ dropout_rate = 0.1,
33
+ enable_el_cache = True
34
+ ).to(device)
35
+ model.load_state_dict(torch.load('large_model_instruct_09271556.weight',map_location=device,weights_only=True))
36
+ model = model.eval()
37
+
38
+ # token包装函数 - 使用HTML span标签确保每个token在独立矩形中
39
+ def token_wapper(token):
40
+ # 对特殊字符进行HTML转义处理
41
+ escaped_token = html.escape(token)
42
+ return f'<span class="token-box">{escaped_token}</span>'
43
+
44
+ # 多token包装函数 - 使用HTML span标签确保每个token在独立矩形中
45
+ def token_split_wapper(token):
46
+ # 对特殊字符进行HTML转义处理
47
+ escaped_token = html.escape(token)
48
+ return f'<span class="multi-token-box">({escaped_token})[多token]</span>'
49
+
50
+ # 处理用户输入的token返回安全的显示格式
51
+ def process_user_tokens(user_message):
52
+ # 通过分词器转化为token
53
+ user_tokens = tokenizer(user_message, 5.0)
54
+
55
+ # 将token还原并进行安全包装
56
+ words = [] # token列表
57
+ temp = [] # token是特殊字节,要合并
58
+ for token in user_tokens:
59
+ if token > 0:
60
+ # 将合并成功的加入列表
61
+ if len(temp):
62
+ words.append(token_split_wapper(token2str(temp)))
63
+ temp = []
64
+ # 将新的token加入列表
65
+ words.append(token_wapper(token2str([token])))
66
+ else:
67
+ # 将字节送去合并
68
+ temp.append(token)
69
+ # 结束的时候要进行收尾
70
+ if len(temp):
71
+ words.append(token_split_wapper(token2str(temp)))
72
+ # 返回包装好的token列表
73
+ return ''.join(words)
74
+
75
+ # 全局字典,存 per-session 的不可 deepcopy 对象 / 状态
76
+ user_queues = ExpiringDict(ttl=550) # session_id -> Queue(list([string,string])),用于流式输出
77
+ user_queues.start_auto_cleanup()
78
+ user_stop_flags = ExpiringDict(ttl=550) # session_id -> bool (True 表示停止)
79
+ user_stop_flags.start_auto_cleanup()
80
+ user_history_sessions_show = ExpiringDict(ttl=550) # session_id -> 用于显示的历史记录,list([string,string])
81
+ user_history_sessions_show.start_auto_cleanup()
82
+ user_history_sessions_text = ExpiringDict(ttl=550) # session_id -> 纯文本历史记录,string
83
+ user_history_sessions_text.start_auto_cleanup()
84
+
85
+ # 后台生成函数(只访问全局字典,通过 session_id 定位)
86
+ def generate_text(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay):
87
+ out = ""
88
+ q = user_queues.get(session_id)
89
+ # 立即刷出用户问题
90
+ q.put(out, block=False)
91
+ # 构建完整的对话历史输入
92
+ if len(sess) == 1:
93
+ user_history_sessions_text[session_id] = f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant "
94
+ else:
95
+ user_history_sessions_text[session_id] += f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant "
96
+ # 转换为模型输入格式
97
+ tokens_batch = [tokenizer(user_history_sessions_text[session_id], 5.0)]
98
+ tokens_batch = np.array(tokens_batch, dtype=np.int64) + 255
99
+ inputs = torch.from_numpy(tokens_batch).to(device).data
100
+ last_len = -1
101
+ # 模型输出
102
+ with torch.no_grad():
103
+ for o in El_text_continue_stream(
104
+ model, inputs, out_length=max_length,
105
+ repeat_penalty_value=repeat_penalty,
106
+ temperature=temperature,decay=decay,session_id=session_id):
107
+ # 如果当前位置可以完整解码
108
+ if o[0,-1] > 255:
109
+ # 将未解码的部分一起解码
110
+ temp = token2str(o[0][last_len:].cpu().numpy()-255)
111
+ out += temp
112
+ user_history_sessions_text[session_id] += temp
113
+ # 重置为解码光标
114
+ last_len = -1
115
+ q.put(out, block=False)
116
+ else:
117
+ # 无法解码,光标固定
118
+ last_len -= 1
119
+ # 如果用户主动断开连接,停止生成,去除潜在标记
120
+ if user_stop_flags.get(session_id, True):
121
+ if '<' + out.split('<')[-1] in '<|im_end|>':
122
+ # 显示的部分去除标记
123
+ out = '<'+'<'.join(out.split('<')[:-1])
124
+ # 历史的部分保留标记
125
+ user_history_sessions_text[session_id] = '<'+'<'.join(user_history_sessions_text[session_id].split('<')[:-1])+'<|im_end|>'
126
+ break
127
+
128
+ # 如果是输出终止标记
129
+ if '<|im_end|>' in out:
130
+ # 显示的部分,去除标记
131
+ out = out.split('<|im_end|>')[0]
132
+ q.put(out, block=False)
133
+ break
134
+ # 如果用户中断
135
+ if user_stop_flags[session_id] == True:
136
+ break
137
+ # 更新标记为暂停
138
+ user_stop_flags[session_id] = True
139
+
140
+ # 按钮处理逻辑:发送消息 / 停止生成 / 清空会话
141
+ def send_message(sess, btn_label, user_message, session_id, temperature, repeat_penalty, max_length, decay):
142
+ # 发送消息按钮 - 启动生成线程
143
+ if btn_label == "发送消息" and user_message:
144
+ # 设置当前用户正在生成的标志
145
+ user_stop_flags[session_id] = False
146
+ # 立即在UI中显示用户消息
147
+ user_tokens_display = process_user_tokens(user_message)
148
+ # 添加用户消息到当前会话
149
+ user_history_sessions_show[session_id] = sess
150
+ user_history_sessions_show[session_id] += [[user_tokens_display, ""]]
151
+ if session_id not in user_history_sessions_text:
152
+ return "", "会话过期!"
153
+ # 在这里开始流式输出
154
+ thread = threading.Thread(target=generate_text, args=(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay))
155
+ thread.daemon = True #主进程退出时退出
156
+ thread.start() #启动
157
+ user_stop_flags[session_id] = False
158
+ # 更新返回给前端的 state/stop_flag
159
+ return "", "停止生成"
160
+ else:
161
+ # 停止生成按钮 - 设置标志位
162
+ user_stop_flags[session_id] = True
163
+ # 更新返回给前端的 state/stop_flag
164
+ return user_message, "发送消息"
165
+
166
+ # 清空会话
167
+ def clear_session():
168
+ return []
169
+
170
+ # 流式输出,无限循环刷新页面
171
+ def stream_output(sess):
172
+ global user_queues, user_stop_flags, user_history_sessions_show, user_history_sessions_text
173
+ # 页面加载时初始化 session
174
+ session_id = str(uuid.uuid4())
175
+ user_queues[session_id] = Queue()
176
+ user_stop_flags[session_id] = True
177
+ user_history_sessions_show[session_id] = [] # 初始化历史会话记录,用于显示
178
+ user_history_sessions_text[session_id] = "" # 初始化历史会话记录,用于文本存储
179
+ # 返回初始状态
180
+ yield [], "发送消息", session_id
181
+ # 不断刷新
182
+ while True:
183
+ time.sleep(0.01) # 防止 busy-wait 占满 CPU
184
+ # 等待队列有数据
185
+ q = user_queues.get(session_id)
186
+ if q is None:
187
+ continue
188
+ # 处理队列中的消息
189
+ if not q.empty():
190
+ # 取到最后一个加入的数据
191
+ while q.qsize() > 1:
192
+ q.get()
193
+ out = q.get()
194
+ sess = user_history_sessions_show[session_id]
195
+ sess[-1][1] = out
196
+ # 更新UI状态
197
+ current_stopped = user_stop_flags.get(session_id, True)
198
+ button_label = "停止生成" if not current_stopped else "发送消息"
199
+ yield sess, button_label, session_id
200
+
201
+ # UI美化
202
+ css = """
203
+ /* 大标题居中 */
204
+ .title {
205
+ text-align: center;
206
+ }
207
+ /* 高级选项字体居中 */
208
+ #adv-param button {
209
+ justify-content: center;
210
+ }
211
+ /* 高级选项字体放大 */
212
+ #adv-param > button > span {
213
+ font-size: 16px !important;
214
+ font-weight: 600 !important;
215
+ }
216
+ /* 自定义token样式 */
217
+ .token-box {
218
+ display: inline-block;
219
+ background-color: #f0f0f0;
220
+ border: 1px solid #ddd;
221
+ border-radius: 4px;
222
+ padding: 2px 4px;
223
+ margin: 2px;
224
+ font-family: monospace;
225
+ }
226
+ .multi-token-box {
227
+ display: inline-block;
228
+ background-color: #e6f7ff;
229
+ border: 1px solid #91d5ff;
230
+ border-radius: 4px;
231
+ padding: 2px 4px;
232
+ margin: 2px;
233
+ font-family: monospace;
234
+ }
235
+ """
+ # ========== Gradio UI ==========
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_classes="container"):
+         gr.Markdown("# 0.18B中文大语言模型在线体验", elem_classes="title")
+         # chat panel
+         chatbot = gr.Chatbot(
+             label="对话",
+             autoscroll=False,
+             show_copy_button=True,
+             elem_classes="chatbox",
+             type="tuples",
+             height=400
+         )
+         # input area
+         with gr.Column(elem_classes="input-area"):
+             msg = gr.Textbox(
+                 placeholder="请输入你的问题...",
+                 label="",
+                 lines=3,
+                 show_label=False
+             )
+             # button row
+             with gr.Row(elem_classes="button-row"):
+                 send_btn = gr.Button("发送消息", elem_classes="send-btn")
+                 clear_btn = gr.Button("清空会话", elem_classes="clear-btn")
+         # advanced parameter settings (collapsible)
+         with gr.Accordion("高级参数设置", open=False, elem_classes="parameter-row", elem_id="adv-param"):
+             with gr.Row():
+                 temperature = gr.Slider(0.0001, 3.0001, value=0.0001, step=0.1, label="Temperature")
+                 repeat_penalty = gr.Slider(0.0, 5.0, value=2.5, step=0.1, label="Repeat Penalty")
+             with gr.Row():
+                 max_length = gr.Slider(64, 8192, value=512, step=64, label="Max Length")
+                 decay = gr.Slider(0.90, 1.0, value=0.98, step=0.01, label="Repeat Penalty Decay Rate")
+         # gr.State keeps a deep-copyable per-session value on the frontend
+         session_id = gr.State()
+         # send-button handler
+         send_btn.click(
+             send_message,
+             inputs=[chatbot, send_btn, msg, session_id, temperature, repeat_penalty, max_length, decay],
+             outputs=[msg, send_btn],
+         )
+ 
+         clear_btn.click(
+             clear_session,
+             inputs=[],
+             outputs=[chatbot],
+         )
+ 
+         # endless generator that keeps the chat view up to date
+         demo.load(
+             stream_output,
+             inputs=[chatbot],
+             outputs=[chatbot, send_btn, session_id],
+         )
+ if __name__ == "__main__":
+     """Main entry: launch the Gradio UI"""
+     # configure the queue for higher concurrency
+     demo.queue(max_size=128, default_concurrency_limit=128)
+     # launch the app without a public share link; the CSS above is applied
+     demo.launch(share=False)
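The demo.load(stream_output, ...) binding relies on a Gradio behavior worth noting: a generator attached to load keeps running for the lifetime of the page, which is why the queue needs a high default_concurrency_limit. A minimal, repo-independent sketch of the same pattern, assuming Gradio 4.x:

import time
import gradio as gr

def ticker():
    # a generator bound to Blocks.load() runs for the lifetime of the page,
    # continuously pushing UI updates, which is the same trick stream_output uses
    n = 0
    while True:
        time.sleep(1)
        n += 1
        yield f"page open for {n}s"

with gr.Blocks() as mini:
    box = gr.Textbox(label="ticker")
    mini.load(ticker, inputs=[], outputs=[box])

if __name__ == "__main__":
    # each open page pins one queue worker forever, hence the high limit
    mini.queue(default_concurrency_limit=32)
    mini.launch()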
train_and_use.py CHANGED
@@ -1,444 +1,444 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import threading
import copy

class Batch:
    def __init__(self,input_sequences):
        self.data_type = "generator"
        self.query = input_sequences[...,:-1]
        self.label = input_sequences[...,1:]
        self.q_mask = self.query != 0
        self.ntokens = float((self.label != 0).sum())

# Cross-entropy loss, with special handling for the "0" padding id
class CrossEntropyLoss(nn.Module):
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        # KL-divergence loss (it takes log-probabilities x and probabilities y, so this is not a plain KL computation)
        self.criterion = nn.KLDivLoss(reduction='sum')

    def forward(self, model_output_dist, target_sequence):
        # Compute cross-entropy between the model's output distribution and the label distribution.
        # Allocate the target distribution with the same shape/dtype as the model output;
        # gradient tracking is off by default, but spelling it out is clearer
        true_dist = torch.zeros_like(model_output_dist,requires_grad=False)
        # Put full confidence at each target token's position (true_dist is a vocab-sized distribution)
        # Unsqueeze the target sequence: target_sequence: [batch*len] -> [batch*len,1]
        # true_dist: [batch*len,vocab]
        # Along the vocab dimension, use the label values as indices and fill 1.0 there
        true_dist.scatter_(1, target_sequence.data.unsqueeze(1), 1.0)
        # Zero out the probability at the padding position
        true_dist[:,0] = 0.0
        # Cross-entropy between the model's output distribution and the smoothed target distribution.
        # model_output_dist is a log-probability distribution, nominally produced by F.log_softmax(self.project(x),dim=-1);
        # in practice Generator.Projector implements it specially, compressing the softmax range so outliers are dropped automatically
        return self.criterion(model_output_dist, true_dist)

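A small sanity check of the construction above: with a one-hot target built by scatter_ and the padding column zeroed, nn.KLDivLoss over log-probabilities reduces to plain cross-entropy over the non-padding positions.

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 10)                    # [batch*len, vocab]
targets = torch.tensor([3, 0, 7, 1])           # 0 is the padding id
log_probs = F.log_softmax(logits, dim=-1)

true_dist = torch.zeros_like(log_probs)
true_dist.scatter_(1, targets.unsqueeze(1), 1.0)
true_dist[:, 0] = 0.0                          # padding rows become all-zero

kl = nn.KLDivLoss(reduction='sum')(log_probs, true_dist)

# identical to summing -log p(target) over the non-padding positions
mask = targets != 0
nll = -log_probs[mask].gather(1, targets[mask].unsqueeze(1)).sum()
print(torch.allclose(kl, nll))                 # True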
class AdamOptimizerWithBase:
    "Adam (adaptive moment estimation) optimizer with a base the parameters decay toward"
    def __init__(self, params, base, half_life, betas, eps):
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.beta3 = (1/2)**(1/half_life)
        self.epsilon = eps
        self.t = 0
        self.param_groups = []
        for p,b in zip(params,base):
            self.param_groups.append({
                'params': p,
                'lr' : 0.0,
                'm' : torch.zeros_like(p).detach(),
                'v' : torch.zeros_like(p).detach(),
                'b' : b.clone().detach()
            })

    def step(self):
        self.t += 1
        for group in self.param_groups:
            # fetch the gradient
            grad = group['params'].grad
            if grad is None:
                continue
            with torch.no_grad():
                # decay the running moments
                group['m'].mul_(self.beta1).add_(grad, alpha = 1 - self.beta1)
                group['v'].mul_(self.beta2).addcmul_(grad, grad, value = 1 - self.beta2)
                # bias correction
                m_hat = group['m'] / (1 - self.beta1 ** self.t)
                v_hat = group['v'] / (1 - self.beta2 ** self.t)
                # parameter update, then pull the parameters back toward the base
                group['params'].sub_(group['lr'] / (v_hat.sqrt() + self.epsilon) * m_hat).mul_(self.beta3).add_(group['b'],alpha = 1 - self.beta3)

    def zero_grad(self):
        for group in self.param_groups:
            if group['params'].grad is not None:
                group['params'].grad.detach_()
                group['params'].grad.zero_()

    def refresh(self):
        for group in self.param_groups:
            group['m'] = torch.zeros_like(group['params']).detach()
            group['v'] = torch.zeros_like(group['params']).detach()
            group['b'] = group['params'].clone().detach()
        self.t = 0

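The distinctive part of this optimizer is the tail of the update: .mul_(self.beta3).add_(group['b'], alpha=1-self.beta3) blends the parameters back toward the frozen snapshot b, and beta3 = (1/2)**(1/half_life) gives that pull a half-life of half_life steps. Isolating just the pull:

import torch

half_life = 100
beta3 = 0.5 ** (1 / half_life)

base = torch.tensor(0.0)
p = torch.tensor(1.0)
for step in range(1, 301):
    # the post-step blend used above, with the Adam part removed
    p = p * beta3 + base * (1 - beta3)
    if step in (100, 200, 300):
        print(step, float(p))   # ~0.5, ~0.25, ~0.125: halves every half_life steps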
class SimpleAdamOptimizer:
    "A plain Adam (adaptive moment estimation) optimizer"
    def __init__(self, params, betas, eps):
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.epsilon = eps
        self.t = 0
        self.param_groups = []
        for p in params:
            self.param_groups.append({
                'params': p,
                'lr' : 0.0,
                'm' : torch.zeros_like(p).detach(),
                'v' : torch.zeros_like(p).detach()
            })

    def step(self):
        self.t += 1
        for group in self.param_groups:
            # fetch the gradient
            grad = group['params'].grad
            if grad is None:
                continue
            # clamp NaNs and extreme gradient values
            grad[grad!=grad] = 0.0
            grad[grad>100] = 100.0
            grad[grad<-100] = -100.0
            with torch.no_grad():
                # decay the running moments
                group['m'].mul_(self.beta1).add_(grad, alpha = 1 - self.beta1)
                group['v'].mul_(self.beta2).addcmul_(grad, grad, value = 1 - self.beta2)
                # bias correction
                m_hat = group['m'] / (1 - self.beta1 ** self.t)
                v_hat = group['v'] / (1 - self.beta2 ** self.t)
                # parameter update
                group['params'].sub_(group['lr'] / (v_hat.sqrt() + self.epsilon) * m_hat)

    def zero_grad(self):
        for group in self.param_groups:
            if group['params'].grad is not None:
                group['params'].grad.detach_()
                group['params'].grad.zero_()

def get_lrate(start_step,total_step,lr_from,lr_to,transition,enable_wave):
    assert transition > 0 and transition % 2 == 0, "Need transition gt 0 and transition mod 2 eq 0."
    mid_transition = transition // 2
    half_lr_gap = (lr_to - lr_from)/2
    if total_step >= start_step + transition:
        ret = lr_to
    elif total_step < start_step + mid_transition:
        ret = lr_from + half_lr_gap * (total_step - start_step)**2 / mid_transition**2
    else:
        ret = lr_to - half_lr_gap * (start_step + transition - total_step)**2 / mid_transition**2
    # Only oscillate once the target rate is reached; oscillating earlier is harmful
    if ret != lr_to or enable_wave == False or lr_to > 2e-4:
        return ret
    else:
        return ret + np.sin((total_step - start_step) * np.pi / mid_transition) * lr_to * 0.9

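get_lrate ramps quadratically (ease-in) over the first half of transition, eases out over the second half, and with enable_wave and a small enough lr_to superimposes a sine of amplitude 0.9*lr_to once the target is reached. Sampling the curve shows the shape; this assumes get_lrate is importable from this module:

from train_and_use import get_lrate

for step in (1, 250, 500, 750, 1000, 2000):
    lr = get_lrate(start_step=0, total_step=step,
                   lr_from=0.0, lr_to=1e-4,
                   transition=1000, enable_wave=False)
    print(step, lr)
# 1 -> ~0, 250 -> 1.25e-5 (quadratic ease-in), 500 -> 5e-5 (midpoint),
# 750 -> 8.75e-5 (ease-out), 1000 and beyond -> 1e-4 (target)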
record = {
    "loss_line" : [],
    "lr_line" : []
}

class OptimizerWrapper:
    def __init__(self, optimizer, warm_up, lr, enable_wave = False):
        self.lr_from = 0           # initial learning rate
        self.lr_to = lr            # target learning rate
        self.warm_up = warm_up     # warm-up steps
        self.start_step= 0         # starting step
        self.total_step= 0         # total steps taken
        self.optimizer = optimizer # the optimizer that performs gradient descent
        self.enable_wave = enable_wave # learning-rate oscillation

    def update(self):
        global record
        # Set the learning rate on every parameter group, then take an optimization step
        lrate = self.lrate()
        record["lr_line"] += [lrate]
        for parameters in self.optimizer.param_groups:
            parameters['lr'] = lrate
        self.optimizer.step()
        self.optimizer.zero_grad()

    def lrate(self):
        self.total_step += 1
        return get_lrate(
            self.start_step,
            self.total_step,
            self.lr_from,
            self.lr_to,
            self.warm_up,
            self.enable_wave)

    def set_lrate(self,lrate,transition):
        self.lr_from = self.lr_to
        self.lr_to = lrate
        self.warm_up = transition
        self.start_step = self.total_step

stop = False
pause = False

def run_epoch(model,data_iter,caculate_size,loss_f,optimizer,epoch,use_amp):
    global stop
    global pause
    global record
    for step, batch in enumerate(data_iter):
        if stop:
            break
        while pause:
            time.sleep(0.5)
        total_loss = 0
        t_start = time.time()
        for i in range(0,batch.query.size(0),caculate_size):
            if use_amp:
                with torch.amp.autocast("cuda"):
                    model_output = model(batch.query[i:i+caculate_size], batch.q_mask[i:i+caculate_size])
                    loss = loss_f(torch.log(F.softmax(model_output,dim=-1).mul(0.99).add(5e-3)).view(-1,model_output.size(-1)),
                                  batch.label[i:i+caculate_size].reshape(-1))/ batch.ntokens
                loss.backward()
                total_loss += float(loss) * batch.ntokens
            else:
                model_output = model(batch.query[i:i+caculate_size], batch.q_mask[i:i+caculate_size])
                loss = loss_f(torch.log(F.softmax(model_output,dim=-1).mul(0.99).add(5e-3)).view(-1,model_output.size(-1)),
                              batch.label[i:i+caculate_size].reshape(-1))/ batch.ntokens
                loss.backward()
                total_loss += float(loss) * batch.ntokens
        optimizer.update()
        mean_loss = total_loss/batch.ntokens
        record["loss_line"] += [mean_loss]
        t_end = time.time()
        print('\repoch:',epoch,'\tstep:',step,'\tloss:',str(mean_loss)[:5],'\tspeed:',str(batch.ntokens/(t_end - t_start))[:7],'tokens/s',end = ' '*20)


# The training function runs as a service and can be adjusted manually at any time
def train(model,data_generator,batch_size,caculate_size,loss_f,optimizer,use_amp):
    global stop
    epoch = 0
    while True:
        if stop:
            break
        run_epoch(model,data_generator(batch_size),caculate_size,loss_f,optimizer,epoch,use_amp)
        epoch += 1

# Start the training service
def train_server_start(model,generator_batch_pair,split_n,loss_f,optimizer,use_amp = False):
    assert generator_batch_pair[1] % split_n == 0, "Need batch_size mod split_n eq 0."
    data_generator,batch_size = generator_batch_pair
    thread = threading.Thread(target=train,args=(model,data_generator,batch_size,batch_size//split_n,loss_f,optimizer,use_amp))
    thread.start()

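A hedged wiring sketch for the training service; model and make_batches are placeholders, not objects defined in this repo:

import torch
from train_and_use import (CrossEntropyLoss, SimpleAdamOptimizer,
                           OptimizerWrapper, train_server_start, TOGGLE, STOP)

def make_batches(batch_size):
    # placeholder: should yield Batch objects built from token tensors
    ...

adam = SimpleAdamOptimizer(list(model.parameters()), betas=(0.9, 0.98), eps=1e-9)
optimizer = OptimizerWrapper(adam, warm_up=4000, lr=3e-4)
train_server_start(model, (make_batches, 64), split_n=4,
                   loss_f=CrossEntropyLoss(), optimizer=optimizer, use_amp=True)

# later, from the same interpreter:
# TOGGLE()   # pause / resume
# STOP()     # stop after the current step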
def TOGGLE():
    global pause
    pause = not pause
    print("pause:",pause)

def STOP():
    global stop
    stop = True

# Greedy decoding
def greedy_decode(model,inputs,out_length):
    if model.model_type == "generator":
        for _ in range(out_length):
            query = model.embedding(inputs)
            prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1:,:])
            next_token = torch.max(prob_dist, dim = -1)[1]
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        return inputs

def El_greedy_decode(model,inputs,out_length):
    if model.model_type == "generator":
        assert len(inputs[0]) > 1, "The initial sequence must be longer than 1 token, to distinguish it from incremental continuation"
        query = model.embedding(inputs)
        prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1:,:])
        next_token = torch.max(prob_dist, dim = -1)[1]
        inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        for _ in range(0,out_length-1,1):
            query = model.embedding(inputs[:,[-1]])
            prob_dist = model.projector(model.encoder(query,(inputs==inputs)[:,[-1]])[:,-1:,:])
            next_token = torch.max(prob_dist, dim = -1)[1]
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        return inputs

# Sampling-based decoding
def sampling_decode(model,inputs,out_length):
    if model.model_type == "generator":
        for _ in range(out_length):
            query = model.embedding(inputs)
            prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1,:])
            next_token = torch.multinomial(F.softmax(prob_dist, dim = -1), num_samples = 1)
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        return inputs

def El_sampling_decode(model,inputs,out_length):
    if model.model_type == "generator":
        assert len(inputs[0]) > 1, "The initial sequence must be longer than 1 token, to distinguish it from incremental continuation"
        query = model.embedding(inputs)
        prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1,:])
        next_token = torch.multinomial(F.softmax(prob_dist, dim = -1), num_samples = 1)
        inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        for _ in range(0,out_length-1,1):
            query = model.embedding(inputs[:,[-1]])
            prob_dist = model.projector(model.encoder(query,(inputs==inputs)[:,[-1]])[:,-1,:])
            next_token = torch.multinomial(F.softmax(prob_dist, dim = -1), num_samples = 1)
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        return inputs

# A more controllable text-continuation tool
def text_continue(model,inputs,out_length,repeat_penalty_value,temperature,decay=0.98):
    if model.model_type == "generator":
        repeat_penalty = None
        for _ in range(out_length):
            query = model.embedding(inputs)
            prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1,:])
            if repeat_penalty is None:
                repeat_penalty = torch.zeros_like(prob_dist, device=inputs.device)
                for index in range(inputs.size(1)):
                    for line in range(inputs.size(0)):
                        repeat_penalty[line][inputs[line][index]] -= repeat_penalty_value
                    repeat_penalty *= decay
            else:
                repeat_penalty *= decay
            prob_dist += repeat_penalty
            next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
            for i in range(next_token.size(0)):
                repeat_penalty[i][next_token[i]] -= repeat_penalty_value
        return inputs

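The repeat penalty above is a running logit offset: each emitted token loses repeat_penalty_value from its logit, and the whole tensor is multiplied by decay once per step, so older repetitions fade geometrically. In isolation:

import torch

vocab, penalty_value, decay = 8, 2.5, 0.9
penalty = torch.zeros(1, vocab)

for tok in (3, 3, 5):            # pretend these tokens were just generated
    penalty *= decay             # existing penalties fade each step
    penalty[0, tok] -= penalty_value
print(penalty[0, 3], penalty[0, 5])   # ~-4.275 and -2.5: recent repeats are suppressed hardest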
def El_text_continue(model,inputs,out_length,repeat_penalty_value,temperature,decay=0.98):
    if model.model_type == "generator":
        assert len(inputs[0]) > 1, "The initial sequence must be longer than 1 token, to distinguish it from incremental continuation"
        query = model.embedding(inputs)
        prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1,:])
        repeat_penalty = torch.zeros_like(prob_dist, device=inputs.device)
        for index in range(inputs.size(1)):
            for line in range(inputs.size(0)):
                repeat_penalty[line][inputs[line][index]] -= repeat_penalty_value
            repeat_penalty *= decay
        prob_dist += repeat_penalty
        next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
        inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        for i in range(next_token.size(0)):
            repeat_penalty[i][next_token[i]] -= repeat_penalty_value
        for _ in range(0,out_length-1,1):
            query = model.embedding(inputs[:,[-1]])
            prob_dist = model.projector(model.encoder(query,(inputs==inputs)[:,[-1]])[:,-1,:])
            repeat_penalty *= decay
            prob_dist += repeat_penalty
            next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
            for i in range(next_token.size(0)):
                repeat_penalty[i][next_token[i]] -= repeat_penalty_value
        return inputs

-def El_text_continue_stream(model,inputs,out_length,repeat_penalty_value,temperature,decay=0.98):
+def El_text_continue_stream(model,inputs,out_length,repeat_penalty_value,temperature,decay=0.98,session_id='0'):
    if model.model_type == "generator":
        assert len(inputs[0]) > 1, "The initial sequence must be longer than 1 token, to distinguish it from incremental continuation"
        query = model.embedding(inputs)
-        prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1,:])
+        prob_dist = model.projector(model.encoder(query,inputs==inputs,session_id)[:,-1,:])
        repeat_penalty = torch.zeros_like(prob_dist, device=inputs.device)
        for index in range(inputs.size(1)):
            for line in range(inputs.size(0)):
                repeat_penalty[line][inputs[line][index]] -= repeat_penalty_value
            repeat_penalty *= decay
        prob_dist += repeat_penalty
        next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
        inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)[:,-4:]
        yield inputs
        for i in range(next_token.size(0)):
            repeat_penalty[i][next_token[i]] -= repeat_penalty_value
        for _ in range(0,out_length-1,1):
            query = model.embedding(inputs[:,[-1]])
-            prob_dist = model.projector(model.encoder(query,(inputs==inputs)[:,[-1]])[:,-1,:])
+            prob_dist = model.projector(model.encoder(query,(inputs==inputs)[:,[-1]],session_id)[:,-1,:])
            repeat_penalty *= decay
            prob_dist += repeat_penalty
            next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)[:,-4:]  # keeping the last 4 tokens is enough (a UTF-8 character is at most 4 bytes)
            for i in range(next_token.size(0)):
                repeat_penalty[i][next_token[i]] -= repeat_penalty_value
            yield inputs

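The session_id threaded into model.encoder(...) is the per-user kv-cache key this commit introduces, so concurrent sessions no longer share incremental state. A hedged usage sketch; model, encode and decode stand in for this repo's model object and tokenizer helpers:

import torch
from train_and_use import El_text_continue_stream

session = "user-1234"
ids = torch.tensor([encode("今天天气")])   # shape [1, len], len must be > 1

with torch.no_grad():
    for chunk in El_text_continue_stream(model, ids, out_length=64,
                                         repeat_penalty_value=2.5,
                                         temperature=0.7, decay=0.98,
                                         session_id=session):
        # each yield keeps at most the last 4 token ids; the newest is chunk[0, -1]
        print(decode([int(chunk[0, -1])]), end="", flush=True)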
# Value function, used by the Monte-Carlo-tree continuation
def text_continue_value(model,inputs,out_length,repeat_penalty,repeat_penalty_value,temperature,decay):
    if model.model_type == "generator":
        ret = 0
        assert len(inputs[0]) > 1, "The initial sequence must be longer than 1 token, to distinguish it from incremental continuation"
        query = model.embedding(inputs)
        prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1,:])
        prob_dist += repeat_penalty
        repeat_penalty *= decay
        prob_dist = F.softmax(prob_dist/temperature, dim = -1)
        next_token = torch.multinomial(prob_dist, num_samples = 1)
        inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        for i in range(next_token.size(0)):
            repeat_penalty[i][next_token[i]] -= repeat_penalty_value
            ret += prob_dist[i,next_token[i]]
        for _ in range(0,out_length-1,1):
            query = model.embedding(inputs[:,[-1]])
            prob_dist = model.projector(model.encoder(query,(inputs==inputs)[:,[-1]])[:,-1,:])
            prob_dist += repeat_penalty
            repeat_penalty *= decay
            prob_dist = F.softmax(prob_dist/temperature, dim = -1)
            next_token = torch.multinomial(prob_dist, num_samples = 1)
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
            for i in range(next_token.size(0)):
                repeat_penalty[i][next_token[i]] -= repeat_penalty_value
                ret += prob_dist[i,next_token[i]]
        return ret

# Monte-Carlo-tree-based continuation
def MC_continue(model,inputs,out_length,repeat_penalty_value,temperature,try_n,acc_n,deep_n,decay=0.98):
    if model.model_type == "generator":
        repeat_penalty = None
        assert inputs.dim() == 1, "Parallel continuation is not supported! Need inputs.dim eq 1"
        # replicate the input several times for the tree search
        values = [0] * try_n
        inputs = inputs.repeat(try_n,1)
        query = model.embedding(inputs)
        prob_dist = model.projector(model.encoder(query,inputs==inputs)[:,-1,:])
        repeat_penalty = torch.zeros_like(prob_dist, device=inputs.device)
        for index in range(inputs.size(1)):
            for line in range(inputs.size(0)):
                repeat_penalty[line][inputs[line][index]] -= repeat_penalty_value
            repeat_penalty *= decay
        prob_dist += repeat_penalty
        prob_dist = F.softmax(prob_dist/temperature, dim = -1)
        next_token = torch.multinomial(prob_dist, num_samples = 1)
        inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
        for i in range(try_n):
            repeat_penalty[i][next_token[i]] -= repeat_penalty_value
            values[i] += prob_dist[i,next_token[i]]
        for cur in range(0,out_length-1,1):
            query = model.embedding(inputs[:,[-1]])
            prob_dist = model.projector(model.encoder(query,(inputs==inputs)[:,[-1]])[:,-1,:])
            repeat_penalty *= decay
            prob_dist += repeat_penalty
            prob_dist = F.softmax(prob_dist/temperature, dim = -1)
            next_token = torch.multinomial(prob_dist, num_samples = 1)
            inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)
            for i in range(try_n):
                repeat_penalty[i][next_token[i]] -= repeat_penalty_value
                values[i] += prob_dist[i,next_token[i]]
        max_v = 0.0
        max_i = 0
        cnt = 0
        for test_input,test_repeat_penalty,value in zip(inputs,repeat_penalty,values):
            test_input = test_input.repeat(acc_n,1)
            test_repeat_penalty = test_repeat_penalty.repeat(acc_n,1)
            value += float(text_continue_value(
                model,test_input,deep_n,test_repeat_penalty,repeat_penalty_value,temperature,decay
            ))/(acc_n*deep_n)
            if value > max_v:
                max_v = value
                max_i = cnt
            cnt += 1
        return inputs[max_i]
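MC_continue samples try_n candidate branches, scores each with acc_n rollouts of depth deep_n through text_continue_value, and returns the branch with the highest accumulated probability. A hedged call sketch; model and encode are placeholders:

import torch
from train_and_use import MC_continue

prompt = torch.tensor(encode("从前有座山"))   # must be 1-D: MC_continue rejects batched input
best = MC_continue(model, prompt, out_length=32,
                   repeat_penalty_value=2.5, temperature=0.9,
                   try_n=4, acc_n=2, deep_n=16, decay=0.98)
# best is the sampled branch whose accumulated token probability was highest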