Đinh Trác Đức Anh committed
Commit 8679365 · Parent(s): 50f5d14
create batch bias matrix

Files changed:
- bias_utils.py  +44 -66
- model.py  +31 -32
bias_utils.py
CHANGED

Old version (single-sample):

@@ -3,73 +3,51 @@ import numpy as np

def create_bias_matrix(bmes_tags, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0):
    """
    Args:
        bmes_tags: list or tensor of BMES tags
            - if a list: ['B', 'M', 'E', 'S', ...]
            - if a tensor: [0, 1, 2, 3, ...] (B=0, M=1, E=2, S=3)
        alpha: weight for a pair of tokens in the same word
        beta: weight for a pair of tokens in different words
        gamma: weight when an 'S' token is involved
        delta: weight for the diagonal (a token with itself)

    Returns:
        bias_matrix: the bias matrix, shape (seq_len, seq_len)
    """
    # Convert tensor -> list of 'B'/'M'/'E'/'S'
    if isinstance(bmes_tags, torch.Tensor):
        BMES_MAP_INV = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
        bmes_tags = [BMES_MAP_INV[t] for t in bmes_tags.tolist()]

    seq_len = len(bmes_tags)
    bias_matrix = np.zeros((seq_len, seq_len))

    # Group tokens into words
    word_groups = []
    current_group = [0]
    for i in range(1, seq_len):
        prev_tag = bmes_tags[i - 1]
        if prev_tag in ['E', 'S']:
            word_groups.append(current_group)
            current_group = [i]
        else:
            current_group.append(i)

    # Add the last group
    if current_group:
        word_groups.append(current_group)

    # Fill in the matrix values
    for i in range(seq_len):
        for j in range(seq_len):
            if i == j:
                # Diagonal
                bias_matrix[i, j] = delta
            elif bmes_tags[i] == 'S' or bmes_tags[j] == 'S':
                # At least one token is 'S'
                bias_matrix[i, j] = gamma
            else:
                same_word = False
                for group in word_groups:
                    if i in group and j in group:
                        same_word = True
                        break
                bias_matrix[i, j] = alpha if same_word else beta

    return bias_matrix
New version (batch-aware):

@@ -3,73 +3,51 @@ import numpy as np

def create_bias_matrix(bmes_tags, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0):
    """
    Supports:
        - bmes_tags: shape [seq_len] (single sample) or [B, seq_len] (batch)
    Returns bias_matrix:
        - single sample: [seq_len, seq_len]
        - batch: [B, seq_len, seq_len]
    """
    def single_bias(seq_tags):
        # Convert tensor -> list of 'B'/'M'/'E'/'S'
        if isinstance(seq_tags, torch.Tensor):
            BMES_MAP_INV = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
            seq_tags = [BMES_MAP_INV[t] for t in seq_tags.tolist()]

        seq_len = len(seq_tags)
        bias_matrix = np.zeros((seq_len, seq_len))

        # Group tokens into words
        word_groups = []
        current_group = [0]
        for i in range(1, seq_len):
            prev_tag = seq_tags[i - 1]
            if prev_tag in ['E', 'S']:
                word_groups.append(current_group)
                current_group = [i]
            else:
                current_group.append(i)
        if current_group:
            word_groups.append(current_group)

        # Fill in the bias values
        for i in range(seq_len):
            for j in range(seq_len):
                if i == j:
                    bias_matrix[i, j] = delta
                elif seq_tags[i] == 'S' or seq_tags[j] == 'S':
                    bias_matrix[i, j] = gamma
                else:
                    same_word = any(i in g and j in g for g in word_groups)
                    bias_matrix[i, j] = alpha if same_word else beta
        return bias_matrix

    if isinstance(bmes_tags, torch.Tensor) and bmes_tags.dim() == 2:
        # Batch input
        batch_bias = [single_bias(bmes_tags[i]) for i in range(bmes_tags.size(0))]
        return np.stack(batch_bias, axis=0)  # [B, seq_len, seq_len]
    else:
        # Single sample
        return single_bias(bmes_tags)  # [seq_len, seq_len]
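As a quick sanity check of the new behavior, a minimal usage sketch (the tag ids and values below are illustrative; they follow the B=0, M=1, E=2, S=3 mapping used inside the function):

import torch
from bias_utils import create_bias_matrix

# Two toy sequences of 5 sub-tokens each: [B, seq_len] = [2, 5]
tags = torch.tensor([[0, 2, 3, 0, 2],    # B E | S | B E  -> words {0,1}, {2}, {3,4}
                     [0, 1, 2, 3, 3]])   # B M E | S | S  -> words {0,1,2}, {3}, {4}
bias = create_bias_matrix(tags, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0)
print(bias.shape)   # (2, 5, 5): one [seq_len, seq_len] matrix per sample
print(bias[0])      # 0.1 for pairs inside the same word, -0.05 across words,
                    # 0.0 on the diagonal and wherever an 'S' token is involved

A 1-D tensor or a plain list such as ['B', 'E', 'S'] still yields a single (seq_len, seq_len) matrix, so existing single-sample callers keep working.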
model.py
CHANGED

Old version:

@@ -7,64 +7,63 @@ class MorphemeAwareRobertaModel(RobertaModel):

    """
    PhoBERT extended with:
    - BoundaryAwareEmbeddings (BMES + gate)
    - BMES bias hook on attention heads
    """

    def __init__(self, config, target_heads=None, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0, **kwargs):
        """
        config: RobertaConfig or an HF pretrained path
        target_heads: dict[layer_idx] = list[head_idx] to apply the bias to
        alpha, beta, gamma, delta: BMES bias weights
        kwargs: kept so HF from_pretrained can call this
        """
        super().__init__(config, **kwargs)

        self.embeddings = BoundaryAwareEmbeddings(config, **kwargs)

        self.target_heads = target_heads or {}
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta

        # Tokenizer (set externally)
        self.tokenizer = None

        # Store hooks and the bias matrix
        self.bias_hooks = {}
        self.bias_matrix = None

    def set_tokenizer(self, tokenizer):
        assert tokenizer is not None
        self.tokenizer = tokenizer

    def set_bias_matrix(self, bmes_tags):
        num_heads = self.config.num_attention_heads
        self.bias_matrix =

    def _register_attention_hook(self, layer_idx, head_indices):
        def hook_fn(module, input, output):
                return output
            context_layer,
            if
                return output
            bias = self.bias_matrix
            bias = bias[:, :, :seq_len, :seq_len]
            for h in head_indices:
                if h <
            return (context_layer,

        attn = self.encoder.layer[layer_idx].attention.self
        hook = attn.register_forward_hook(hook_fn)

@@ -103,7 +102,7 @@ class MorphemeAwareRobertaModel(RobertaModel):

        if self.target_heads:
            self.prepare_bias_hooks()

        output_attentions = True

        outputs = super().forward(
            input_ids=input_ids,

@@ -118,4 +117,4 @@ class MorphemeAwareRobertaModel(RobertaModel):

        )

        self.remove_bias_hooks()
        return outputs
New version:

@@ -7,64 +7,63 @@ class MorphemeAwareRobertaModel(RobertaModel):

    """
    PhoBERT extended with:
    - BoundaryAwareEmbeddings (BMES + gate)
    - BMES bias hook on attention heads, with batch support
    """

    def __init__(self, config, target_heads=None, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0, **kwargs):
        super().__init__(config, **kwargs)

        # New embeddings
        self.embeddings = BoundaryAwareEmbeddings(config, **kwargs)

        # Bias parameters
        self.target_heads = target_heads or {}
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta

        self.tokenizer = None
        self.bias_hooks = {}
        self.bias_matrix = None  # shape: [B, num_heads, seq_len, seq_len]

    def set_tokenizer(self, tokenizer):
        assert tokenizer is not None
        self.tokenizer = tokenizer

    def set_bias_matrix(self, bmes_tags):
        """
        bmes_tags: tensor of shape [B, seq_len] or [seq_len]
        Builds and stores self.bias_matrix with shape [B, num_heads, seq_len, seq_len]
        """
        if isinstance(bmes_tags, torch.Tensor) and bmes_tags.dim() == 1:
            # Single sample -> add a batch dimension
            bmes_tags = bmes_tags.unsqueeze(0)

        batch_size, seq_len = bmes_tags.shape
        bias_np = create_bias_matrix(bmes_tags, alpha=self.alpha, beta=self.beta, gamma=self.gamma, delta=self.delta)
        bias_tensor = torch.tensor(bias_np, dtype=torch.float32, device=next(self.parameters()).device)

        # Repeat across attention heads
        num_heads = self.config.num_attention_heads
        bias_tensor = bias_tensor.unsqueeze(1).repeat(1, num_heads, 1, 1)  # [B, num_heads, seq_len, seq_len]
        self.bias_matrix = bias_tensor
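To make the head expansion concrete, a small standalone shape walk-through (the sizes, including 12 heads as in phobert-base, are illustrative):

import numpy as np
import torch

B, L, num_heads = 2, 6, 12                        # illustrative batch, length, heads
bias_np = np.zeros((B, L, L), dtype=np.float32)   # what create_bias_matrix returns for a batch
bias = torch.tensor(bias_np)                      # [B, L, L]
bias = bias.unsqueeze(1).repeat(1, num_heads, 1, 1)
print(bias.shape)                                 # torch.Size([2, 12, 6, 6])

Because every head receives the same matrix, expand would avoid materializing num_heads copies; the commit uses repeat, which allocates them explicitly and keeps the tensor writable.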
    def _register_attention_hook(self, layer_idx, head_indices):
        def hook_fn(module, input, output):
            # output: (context_layer, attention_probs)
            if not isinstance(output, tuple) or len(output) < 2:
                return output
            context_layer, attn_probs = output
            if attn_probs is None or self.bias_matrix is None:
                return output
            B, H, L, _ = attn_probs.shape
            bias = self.bias_matrix
            if bias.size(-1) != L:
                bias = bias[:, :, :L, :L]
            for h in head_indices:
                if h < H:
                    attn_probs[:, h, :, :] += bias[:, h, :, :]
            return (context_layer, attn_probs)

        attn = self.encoder.layer[layer_idx].attention.self
        hook = attn.register_forward_hook(hook_fn)
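The injection relies on a documented PyTorch behavior: when a forward hook returns a value, that value replaces the module's output. A minimal standalone sketch (toy module, not from this repo):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

def add_one_hook(module, inputs, output):
    # Returning a (modified) value from a forward hook overrides the module's output.
    return output + 1.0

handle = layer.register_forward_hook(add_one_hook)
x = torch.zeros(2, 4)
print(layer(x))    # hooked output: the layer's usual result shifted by +1.0
handle.remove()    # same cleanup idea as remove_bias_hooks()
print(layer(x))    # original output again

hook_fn above applies the same pattern to the self-attention module, whose output tuple is (context_layer, attention_probs) when output_attentions is enabled.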
@@ -103,7 +102,7 @@ class MorphemeAwareRobertaModel(RobertaModel):

        if self.target_heads:
            self.prepare_bias_hooks()

        output_attentions = True if output_attentions is None else output_attentions

        outputs = super().forward(
            input_ids=input_ids,

@@ -118,4 +117,4 @@ class MorphemeAwareRobertaModel(RobertaModel):

        )

        self.remove_bias_hooks()
        return outputs
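Putting the pieces together, a sketch of the intended calling sequence. The checkpoint name, the BMES ids, and the assumption that the overridden forward() accepts the usual input_ids/attention_mask arguments (which the super().forward call above suggests) are illustrative, not taken from this commit:

import torch
from transformers import AutoConfig, AutoTokenizer
from model import MorphemeAwareRobertaModel

config = AutoConfig.from_pretrained("vinai/phobert-base")
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

model = MorphemeAwareRobertaModel(config, target_heads={0: [0, 1]})  # bias heads 0 and 1 of layer 0
model.set_tokenizer(tokenizer)

enc = tokenizer("học sinh đi học", return_tensors="pt")
seq_len = enc["input_ids"].size(1)
bmes_tags = torch.full((1, seq_len), 3)   # all-'S' tags, just to exercise the shapes
model.set_bias_matrix(bmes_tags)          # stores a [1, num_heads, seq_len, seq_len] bias

outputs = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

The weights here are randomly initialized; in practice the model would be loaded from a pretrained checkpoint before running this flow.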