import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel

from .configuration_viconbert import ViConBERTConfig


class MLPBlock(nn.Module):
    """Stack of residual feed-forward layers used as a projection head."""

    def __init__(self, input_dim, hidden_dim, output_dim,
                 num_layers=2, dropout=0.3, activation=nn.GELU, use_residual=True):
        super().__init__()
        self.use_residual = use_residual
        self.activation_fn = activation()

        # Project the input to the hidden width, then apply `num_layers`
        # Linear -> LayerNorm -> Dropout -> activation blocks before the output layer.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(num_layers):
            self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.norms.append(nn.LayerNorm(hidden_dim))
            self.dropouts.append(nn.Dropout(dropout))
        self.output_layer = nn.Linear(hidden_dim, output_dim)
    def forward(self, x):
        x = self.input_layer(x)
        for layer, norm, dropout in zip(self.hidden_layers, self.norms, self.dropouts):
            # Optional residual connection around each hidden block.
            residual = x
            x = layer(x)
            x = norm(x)
            x = dropout(x)
            x = self.activation_fn(x)
            if self.use_residual:
                x = x + residual
        x = self.output_layer(x)
        return x
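
# A quick, illustrative shape check for MLPBlock (the dimensions below are
# arbitrary and not taken from the configuration):
#
#     block = MLPBlock(input_dim=768, hidden_dim=512, output_dim=256)
#     out = block(torch.randn(4, 768))   # -> torch.Size([4, 256])
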
class ViConBERT(PreTrainedModel):
    config_class = ViConBERTConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained transformer used to encode the input sentence.
        self.context_encoder = AutoModel.from_pretrained(
            config.base_model, cache_dir=config.base_model_cache_dir
        )
        # MLP head that maps the attended context vector to the output space.
        self.context_projection = MLPBlock(
            self.context_encoder.config.hidden_size,
            config.hidden_dim,
            config.out_dim,
            dropout=config.dropout,
            num_layers=config.num_layers,
        )
        # Attention over the full token sequence, queried by the pooled target span.
        self.context_attention = nn.MultiheadAttention(
            self.context_encoder.config.hidden_size,
            num_heads=config.num_head,
            dropout=config.dropout,
        )
        self.context_window_size = config.context_window_size
        # Learnable per-layer weights (not used in forward below).
        self.context_layer_weights = nn.Parameter(
            torch.zeros(self.context_encoder.config.num_hidden_layers)
        )
        self.post_init()

    def _encode_context_attentive(self, text, target_span):
        """Pool the target-span tokens and attend over the full sequence.

        `text` is the tokenizer output for the sentence; `target_span` holds the
        inclusive [start, end] token indices of the target word per example.
        """
        outputs = self.context_encoder(**text)
        hidden_states = outputs[0]  # (batch, seq_len, hidden)
        start_pos, end_pos = target_span[:, 0], target_span[:, 1]

        # Mean-pool the hidden states inside the target span.
        positions = torch.arange(hidden_states.size(1), device=hidden_states.device)
        mask = (positions >= start_pos.unsqueeze(1)) & (positions <= end_pos.unsqueeze(1))
        masked_states = hidden_states * mask.unsqueeze(-1)
        span_lengths = mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled_embeddings = masked_states.sum(dim=1) / span_lengths

        # Use the pooled span embedding as the query over all token states.
        # nn.MultiheadAttention defaults to the (seq_len, batch, hidden) layout.
        Q_value = pooled_embeddings.unsqueeze(0)   # (1, batch, hidden)
        KV_value = hidden_states.permute(1, 0, 2)  # (seq_len, batch, hidden)
        context_emb, _ = self.context_attention(Q_value, KV_value, KV_value)
        return context_emb

    def forward(self, context, target_span):
        # Returns one contextual embedding of size `out_dim` per example.
        context_emb = self._encode_context_attentive(context, target_span)
        return self.context_projection(context_emb.squeeze(0))
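

# --- Usage sketch (illustrative only) ---
# A minimal example of running the model, assuming the remaining ViConBERTConfig
# fields have sensible defaults. The checkpoint name, example sentence, and span
# indices below are assumptions for illustration, not values taken from this file.
# Because of the relative import above, run this module as part of its package
# (e.g. `python -m <package>.modeling_viconbert`).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = ViConBERTConfig(base_model="vinai/phobert-base")  # hypothetical checkpoint
    model = ViConBERT(config).eval()
    tokenizer = AutoTokenizer.from_pretrained(config.base_model)

    batch = tokenizer(["con chuột máy tính"], return_tensors="pt")
    span = torch.tensor([[1, 2]])  # inclusive [start, end] token indices of the target word
    with torch.no_grad():
        embedding = model(batch, span)  # shape: (1, config.out_dim)
    print(embedding.shape)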