Đinh Trác Đức Anh committed
Commit 8679365 · 1 Parent(s): 50f5d14

create batch bias matrix

Files changed (2):
  1. bias_utils.py +44 -66
  2. model.py +31 -32
bias_utils.py CHANGED
@@ -3,73 +3,51 @@ import numpy as np


 def create_bias_matrix(bmes_tags, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0):
     """
-    Create a bias matrix from BMES tags using the following rules:
-    - α (alpha): when both tokens belong to the same word (same BMES sequence)
-    - β (beta): when the two tokens belong to different words
-    - γ (gamma): when at least one token is 'S' (single syllable)
-    - δ (delta): when a token is paired with itself (the diagonal)
-
-    Args:
-        bmes_tags: list or tensor of BMES tags
-            - if a list: ['B', 'M', 'E', 'S', ...]
-            - if a tensor: [0, 1, 2, 3, ...] (B=0, M=1, E=2, S=3)
-        alpha: weight for token pairs in the same word
-        beta: weight for token pairs in different words
-        gamma: weight when an 'S' token is involved
-        delta: weight for the diagonal (a token with itself)
-
-    Returns:
-        bias_matrix: bias matrix of shape (seq_len, seq_len)
+    Supports:
+    - bmes_tags: shape [seq_len] (single sample) or [B, seq_len] (batch)
+    Returns bias_matrix:
+    - single sample: [seq_len, seq_len]
+    - batch: [B, seq_len, seq_len]
     """
-    # Convert to a list if given a tensor
-    if isinstance(bmes_tags, torch.Tensor):
-        BMES_MAP_INV = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
-        bmes_tags = [BMES_MAP_INV[t.item() if isinstance(t, torch.Tensor) else t]
-                     for t in bmes_tags.squeeze().tolist()]
-
-    seq_len = len(bmes_tags)
-    bias_matrix = np.zeros((seq_len, seq_len))
-
-    # Determine word boundaries: find groups of tokens that belong to the same word
-    word_groups = []
-    current_group = [0]
-
-    for i in range(1, seq_len):
-        prev_tag = bmes_tags[i-1]
-        curr_tag = bmes_tags[i]
-
-        # If the previous token is 'E' or 'S', a new word starts
-        if prev_tag in ['E', 'S']:
-            word_groups.append(current_group)
-            current_group = [i]
-        else:
-            current_group.append(i)
-
-    # Append the last group
-    if current_group:
-        word_groups.append(current_group)
-
-    # Fill in the matrix
-    for i in range(seq_len):
-        for j in range(seq_len):
-            if i == j:
-                # Diagonal
-                bias_matrix[i, j] = delta
-            elif bmes_tags[i] == 'S' or bmes_tags[j] == 'S':
-                # At least one token is 'S'
-                bias_matrix[i, j] = gamma
-            else:
-                # Check whether i and j belong to the same word
-                same_word = False
-                for group in word_groups:
-                    if i in group and j in group:
-                        same_word = True
-                        break
-
-                if same_word:
-                    bias_matrix[i, j] = alpha
-                else:
-                    bias_matrix[i, j] = beta
-
-    return bias_matrix
+    def single_bias(seq_tags):
+        # Convert tensor -> list of tags ['B', 'M', 'E', 'S']
+        if isinstance(seq_tags, torch.Tensor):
+            BMES_MAP_INV = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
+            seq_tags = [BMES_MAP_INV[t.item()] if isinstance(t, torch.Tensor) else BMES_MAP_INV[t]
+                        for t in seq_tags.tolist()]
+
+        seq_len = len(seq_tags)
+        bias_matrix = np.zeros((seq_len, seq_len))
+
+        # Group tokens by word: a new word starts right after an 'E' or 'S' tag
+        word_groups = []
+        current_group = [0]
+        for i in range(1, seq_len):
+            prev_tag = seq_tags[i-1]
+            if prev_tag in ['E', 'S']:
+                word_groups.append(current_group)
+                current_group = [i]
+            else:
+                current_group.append(i)
+        if current_group:
+            word_groups.append(current_group)
+
+        # Fill in the bias values
+        for i in range(seq_len):
+            for j in range(seq_len):
+                if i == j:
+                    bias_matrix[i, j] = delta
+                elif seq_tags[i] == 'S' or seq_tags[j] == 'S':
+                    bias_matrix[i, j] = gamma
+                else:
+                    same_word = any(i in g and j in g for g in word_groups)
+                    bias_matrix[i, j] = alpha if same_word else beta
+        return bias_matrix
+
+    if isinstance(bmes_tags, torch.Tensor) and bmes_tags.dim() == 2:
+        # batch
+        batch_bias = [single_bias(bmes_tags[i]) for i in range(bmes_tags.size(0))]
+        return np.stack(batch_bias, axis=0)  # [B, seq_len, seq_len]
+    else:
+        # single sample
+        return single_bias(bmes_tags)  # [seq_len, seq_len]
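For a concrete feel of the new batched behavior, here is a minimal usage sketch on a toy batch. The tag values and the printed output are illustrative only; it assumes `torch` is importable and that `create_bias_matrix` lives in `bias_utils` as above.

```python
import torch
from bias_utils import create_bias_matrix

# Two toy sequences, encoded with B=0, M=1, E=2, S=3:
#   sample 0: B E S  (one two-syllable word, then a single syllable)
#   sample 1: B M E  (one three-syllable word)
tags = torch.tensor([[0, 2, 3],
                     [0, 1, 2]])

bias = create_bias_matrix(tags, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0)
print(bias.shape)  # (2, 3, 3) -- a numpy array, one matrix per sample

# Sample 0: positions 0 and 1 share a word -> alpha; any pair touching
# the 'S' at position 2 -> gamma; the diagonal -> delta.
print(bias[0])
# [[0.   0.1  0. ]
#  [0.1  0.   0. ]
#  [0.   0.   0. ]]
```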
model.py CHANGED
@@ -7,64 +7,63 @@ class MorphemeAwareRobertaModel(RobertaModel):
     """
     PhoBERT extended with:
     - BoundaryAwareEmbeddings (BMES + gate)
-    - BMES bias hook on attention heads
+    - BMES bias hook on attention heads, with batch support
     """

     def __init__(self, config, target_heads=None, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0, **kwargs):
-        """
-        config: RobertaConfig or an HF pretrained path
-        target_heads: dict[layer_idx] = list[head_idx] to apply the bias to
-        alpha, beta, gamma, delta: weights of the BMES bias
-        kwargs: kept so HF from_pretrained can still call this
-        """
         super().__init__(config, **kwargs)

-        # Replace the original embeddings with the new ones
+        # New embeddings
         self.embeddings = BoundaryAwareEmbeddings(config, **kwargs)

-        # Bias parameters
+        # Bias params
         self.target_heads = target_heads or {}
         self.alpha = alpha
         self.beta = beta
         self.gamma = gamma
         self.delta = delta

-        # Tokenizer (set externally)
         self.tokenizer = None
-
-        # Store hooks and the bias matrix
         self.bias_hooks = {}
-        self.bias_matrix = None
+        self.bias_matrix = None  # shape: [B, num_heads, seq_len, seq_len]

     def set_tokenizer(self, tokenizer):
-        assert tokenizer is not None, "Tokenizer must not be None"
+        assert tokenizer is not None
         self.tokenizer = tokenizer

     def set_bias_matrix(self, bmes_tags):
-        bias = create_bias_matrix(
-            bmes_tags.squeeze(0) if isinstance(bmes_tags, torch.Tensor) else bmes_tags,
-            alpha=self.alpha, beta=self.beta, gamma=self.gamma, delta=self.delta
-        )
-        bias = torch.tensor(bias, dtype=torch.float32).unsqueeze(0)
+        """
+        bmes_tags: tensor [B, seq_len] or [seq_len]
+        Stores self.bias_matrix as a tensor [B, num_heads, seq_len, seq_len]
+        """
+        if isinstance(bmes_tags, torch.Tensor) and bmes_tags.dim() == 1:
+            # single sample -> add a batch dim
+            bmes_tags = bmes_tags.unsqueeze(0)
+
+        batch_size, seq_len = bmes_tags.shape
+        bias_np = create_bias_matrix(bmes_tags, alpha=self.alpha, beta=self.beta, gamma=self.gamma, delta=self.delta)
+        bias_tensor = torch.tensor(bias_np, dtype=torch.float32, device=next(self.parameters()).device)
+        # repeat across num_heads
         num_heads = self.config.num_attention_heads
-        bias = bias.unsqueeze(1).repeat(1, num_heads, 1, 1)
-        self.bias_matrix = bias.to(next(self.parameters()).device)
+        bias_tensor = bias_tensor.unsqueeze(1).repeat(1, num_heads, 1, 1)  # [B, num_heads, seq_len, seq_len]
+        self.bias_matrix = bias_tensor

     def _register_attention_hook(self, layer_idx, head_indices):
         def hook_fn(module, input, output):
-            if not isinstance(output, tuple) or len(output) < 2:
+            # output: (context_layer, attention_probs)
+            if not isinstance(output, tuple) or len(output) < 2:
                 return output
-            context_layer, attention_probs = output
-            if attention_probs is None or self.bias_matrix is None:
+            context_layer, attn_probs = output
+            if attn_probs is None or self.bias_matrix is None:
                 return output
+            B, H, L, _ = attn_probs.shape
             bias = self.bias_matrix
-            seq_len = attention_probs.size(-1)
-            if bias.size(-1) != seq_len:
-                bias = bias[:, :, :seq_len, :seq_len]
+            if bias.size(-1) != L:
+                bias = bias[:, :, :L, :L]
             for h in head_indices:
-                if h < attention_probs.size(1):
-                    attention_probs[:, h, :, :] += bias[:, h, :, :]
-            return (context_layer, attention_probs)
+                if h < H:
+                    attn_probs[:, h, :, :] += bias[:, h, :, :]
+            return (context_layer, attn_probs)

         attn = self.encoder.layer[layer_idx].attention.self
         hook = attn.register_forward_hook(hook_fn)
@@ -103,7 +102,7 @@ class MorphemeAwareRobertaModel(RobertaModel):
         if self.target_heads:
             self.prepare_bias_hooks()

-        output_attentions = True
+        output_attentions = True if output_attentions is None else output_attentions

         outputs = super().forward(
             input_ids=input_ids,
@@ -118,4 +117,4 @@ class MorphemeAwareRobertaModel(RobertaModel):
         )

         self.remove_bias_hooks()
-        return outputs
+        return outputs
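To see how the two changes fit together, here is a hedged end-to-end sketch of the batched flow after this commit. The checkpoint name, tag tensor, and choice of `target_heads` are illustrative; the class keeps `**kwargs` precisely so HF `from_pretrained` can construct it, but the exact forward signature depends on `BoundaryAwareEmbeddings`, which this diff does not show.

```python
import torch
from model import MorphemeAwareRobertaModel

# Bias heads 0 and 1 of layer 0 (illustrative choice of target_heads).
model = MorphemeAwareRobertaModel.from_pretrained(
    "vinai/phobert-base",              # assumed checkpoint
    target_heads={0: [0, 1]},
    alpha=0.1, beta=-0.05,
)

# Batched BMES tags (B=0, M=1, E=2, S=3), one row per sequence.
bmes_tags = torch.tensor([[0, 2, 3, 0, 1, 2],
                          [3, 3, 0, 2, 3, 3]])
model.set_bias_matrix(bmes_tags)       # stores [2, num_heads, 6, 6]

input_ids = torch.randint(5, model.config.vocab_size, (2, 6))
attention_mask = torch.ones_like(input_ids)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
```

One caveat worth noting: because this is a forward hook, `context_layer` has already been computed when the bias is added, so the modification shows up in the attention probabilities returned via `output_attentions` rather than in the hidden states themselves.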