M_AHuTieuBert / model.py
Đinh Trác Đức Anh
sync model
3adf4a4
import torch
from transformers.models.roberta.modeling_roberta import RobertaModel
from .embeddings import BoundaryAwareEmbeddings
from .bias_utils import create_bias_matrix
class MorphemeAwareRobertaModel(RobertaModel):
"""
PhoBERT mở rộng với:
- BoundaryAwareEmbeddings (BMES + gate)
- BMES bias hook trên attention head, hỗ trợ batch
"""
def __init__(self, config, target_heads=None, alpha=0.1, beta=-0.05, gamma=0.0, delta=0.0, **kwargs):
super().__init__(config, **kwargs)
# ✅ OVERRIDE embedding mới - đây là điểm quan trọng
self.embeddings = BoundaryAwareEmbeddings(config, **kwargs)
# Bias params
self.target_heads = target_heads or {}
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.delta = delta
self.tokenizer = None
self.patched_forwards = {}
self.bias_matrix = None
def set_tokenizer(self, tokenizer):
assert tokenizer is not None
self.tokenizer = tokenizer
def set_bias_matrix(self, bmes_tags):
"""
bmes_tags: tensor [B, seq_len] hoặc [seq_len]
Trả về tensor [B, num_heads, seq_len, seq_len]
"""
if isinstance(bmes_tags, torch.Tensor) and bmes_tags.dim() == 1:
bmes_tags = bmes_tags.unsqueeze(0)
batch_size, seq_len = bmes_tags.shape
bias_np = create_bias_matrix(bmes_tags, alpha=self.alpha, beta=self.beta, gamma=self.gamma, delta=self.delta)
bias_tensor = torch.tensor(bias_np, dtype=torch.float32, device=next(self.parameters()).device)
num_heads = self.config.num_attention_heads
bias_tensor = bias_tensor.unsqueeze(1).repeat(1, num_heads, 1, 1)
self.bias_matrix = bias_tensor
def _create_patched_forward(self, layer_idx, head_indices, original_forward, attn_module):
"""
Tạo forward function mới có cộng bias vào attention scores trước softmax
"""
def patched_forward(
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
**kwargs
):
batch_size, seq_length = hidden_states.shape[:2]
# Tính Q, K, V
query_layer = attn_module.query(hidden_states)
# Xử lý key và value
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = attn_module.key(encoder_hidden_states)
value_layer = attn_module.value(encoder_hidden_states)
elif past_key_value is not None:
key_layer = attn_module.key(hidden_states)
value_layer = attn_module.value(hidden_states)
key_layer = torch.cat([past_key_value[0], key_layer], dim=1)
value_layer = torch.cat([past_key_value[1], value_layer], dim=1)
else:
key_layer = attn_module.key(hidden_states)
value_layer = attn_module.value(hidden_states)
# Reshape để split heads
def split_heads(tensor, num_heads, head_dim):
new_shape = tensor.size()[:-1] + (num_heads, head_dim)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3)
num_heads = attn_module.num_attention_heads
head_dim = attn_module.attention_head_size
query_layer = split_heads(query_layer, num_heads, head_dim)
key_layer = split_heads(key_layer, num_heads, head_dim)
value_layer = split_heads(value_layer, num_heads, head_dim)
if hasattr(attn_module, 'is_decoder') and attn_module.is_decoder:
past_key_value = (key_layer, value_layer)
# Tính attention scores
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / torch.sqrt(
torch.tensor(head_dim, dtype=attention_scores.dtype, device=attention_scores.device)
)
# ✅ CỘNG BIAS VÀO ĐÂY - TRƯỚC SOFTMAX
if self.bias_matrix is not None:
B, H, L, _ = attention_scores.shape
bias = self.bias_matrix
if bias.size(0) != B:
bias = bias[:B]
if bias.size(-1) != L:
bias = bias[:, :, :L, :L]
for h in head_indices:
if h < H:
attention_scores[:, h, :, :] = attention_scores[:, h, :, :] + bias[:, h, :, :]
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
attention_probs = attn_module.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# Tính context layer
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (attn_module.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if hasattr(attn_module, 'is_decoder') and attn_module.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
return patched_forward
def _patch_attention_layer(self, layer_idx, head_indices):
"""
Monkey patch forward method của attention layer
"""
attn_module = self.encoder.layer[layer_idx].attention.self
if layer_idx not in self.patched_forwards:
original_forward = attn_module.forward
self.patched_forwards[layer_idx] = (attn_module, original_forward)
patched_forward = self._create_patched_forward(
layer_idx, head_indices, original_forward, attn_module
)
attn_module.forward = patched_forward
def prepare_bias_patches(self):
"""
Patch tất cả các layer có target heads
"""
self.remove_bias_patches()
for layer_idx, heads in self.target_heads.items():
self._patch_attention_layer(layer_idx, heads)
def remove_bias_patches(self):
"""
Khôi phục lại original forward methods
"""
for layer_idx, (attn_module, original_forward) in self.patched_forwards.items():
attn_module.forward = original_forward
self.patched_forwards = {}
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
bmes_ids=None,
bmes_tags=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
past_key_values_length=0, # ✅ Thêm param này cho embedding layer
):
# Xử lý bmes_ids/bmes_tags
if bmes_ids is None and bmes_tags is not None:
bmes_ids = bmes_tags
if inputs_embeds is None:
inputs_embeds = self.embeddings(
input_ids=input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
bmes_ids=bmes_ids, # ✅ Truyền bmes_ids vào embedding
past_key_values_length=past_key_values_length
)
# Set bias matrix nếu có bmes_ids
if bmes_ids is not None:
self.set_bias_matrix(bmes_ids)
# Patch attention layers nếu có target heads
if self.target_heads:
self.prepare_bias_patches()
output_attentions = True if output_attentions is None else output_attentions
# ✅ Gọi parent forward NHƯNG truyền inputs_embeds thay vì input_ids
outputs = super().forward(
input_ids=None, # ✅ Set None vì đã có inputs_embeds
attention_mask=attention_mask,
token_type_ids=None, # ✅ Set None vì đã được xử lý trong embedding
position_ids=None, # ✅ Set None vì đã được xử lý trong embedding
head_mask=head_mask,
inputs_embeds=inputs_embeds, # ✅ Dùng embedding đã tính
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Cleanup patches
self.remove_bias_patches()
return outputs