tkhangg0910 commited on
Commit
7a1efb6
·
verified ·
1 Parent(s): 9159113

Create modeling_viconbert.py

Browse files
Files changed (1) hide show
  1. modeling_viconbert.py +82 -0
modeling_viconbert.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, AutoModel
4
+ from .configuration_viconbert import ViConBERTConfig
5
+
6
+
7
+ class MLPBlock(nn.Module):
8
+ def __init__(self, input_dim, hidden_dim, output_dim,
9
+ num_layers=2, dropout=0.3, activation=nn.GELU, use_residual=True):
10
+ super().__init__()
11
+ self.use_residual = use_residual
12
+ self.activation_fn = activation()
13
+
14
+ self.input_layer = nn.Linear(input_dim, hidden_dim)
15
+ self.hidden_layers = nn.ModuleList()
16
+ self.norms = nn.ModuleList()
17
+ self.dropouts = nn.ModuleList()
18
+ for _ in range(num_layers):
19
+ self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))
20
+ self.norms.append(nn.LayerNorm(hidden_dim))
21
+ self.dropouts.append(nn.Dropout(dropout))
22
+ self.output_layer = nn.Linear(hidden_dim, output_dim)
23
+
24
+ def forward(self, x):
25
+ x = self.input_layer(x)
26
+ for layer, norm, dropout in zip(self.hidden_layers, self.norms, self.dropouts):
27
+ residual = x
28
+ x = layer(x)
29
+ x = norm(x)
30
+ x = dropout(x)
31
+ x = self.activation_fn(x)
32
+ if self.use_residual:
33
+ x = x + residual
34
+ x = self.output_layer(x)
35
+ return x
36
+
37
+
38
+ class ViConBERT(PreTrainedModel):
39
+ config_class = ViConBERTConfig
40
+
41
+ def __init__(self, config):
42
+ super().__init__(config)
43
+ self.context_encoder = AutoModel.from_pretrained(
44
+ config.base_model, cache_dir=config.base_model_cache_dir
45
+ )
46
+ self.context_projection = MLPBlock(
47
+ self.context_encoder.config.hidden_size,
48
+ config.hidden_dim,
49
+ config.out_dim,
50
+ dropout=config.dropout,
51
+ num_layers=config.num_layers
52
+ )
53
+ self.context_attention = nn.MultiheadAttention(
54
+ self.context_encoder.config.hidden_size,
55
+ num_heads=config.num_head,
56
+ dropout=config.dropout
57
+ )
58
+ self.context_window_size = config.context_window_size
59
+ self.context_layer_weights = nn.Parameter(
60
+ torch.zeros(self.context_encoder.config.num_hidden_layers)
61
+ )
62
+ self.post_init()
63
+
64
+ def _encode_context_attentive(self, text, target_span):
65
+ outputs = self.context_encoder(**text)
66
+ hidden_states = outputs[0]
67
+ start_pos, end_pos = target_span[:, 0], target_span[:, 1]
68
+
69
+ positions = torch.arange(hidden_states.size(1), device=hidden_states.device)
70
+ mask = (positions >= start_pos.unsqueeze(1)) & (positions <= end_pos.unsqueeze(1))
71
+ masked_states = hidden_states * mask.unsqueeze(-1)
72
+ span_lengths = mask.sum(dim=1, keepdim=True).clamp(min=1)
73
+ pooled_embeddings = masked_states.sum(dim=1) / span_lengths
74
+
75
+ Q_value = pooled_embeddings.unsqueeze(0)
76
+ KV_value = hidden_states.permute(1, 0, 2)
77
+ context_emb, _ = self.context_attention(Q_value, KV_value, KV_value)
78
+ return context_emb
79
+
80
+ def forward(self, context, target_span):
81
+ context_emb = self._encode_context_attentive(context, target_span)
82
+ return self.context_projection(context_emb.squeeze(0))