# Model upload by KhaldiAbderrhmane — commit 4975ae2 (verified)
from transformers import PreTrainedModel, AutoModel
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch_geometric.nn import GCNConv,GATConv
from .config import BERTMultiGATAttentionConfig
class MultiHeadGATAttention(nn.Module):
    """Multi-head scaled dot-product attention fused with a graph-attention pass.

    The standard attention output and a GATConv output (computed over the
    flattened value projections) are blended by a learnable scalar ``alpha``;
    the blend is projected, layer-normalized, dropped out, and added back to
    the input (residual connection).
    """

    def __init__(self, hidden_size, num_heads, dropout=0.03):
        """
        Args:
            hidden_size: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads (also used by the GATConv).
            dropout: dropout probability for attention weights and the output.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
        """
        super(MultiHeadGATAttention, self).__init__()
        # Fail fast with a clear message instead of a cryptic view() error
        # later in forward().
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        # concat=False averages the GAT heads so the output stays hidden_size wide.
        self.gat = GATConv(hidden_size, hidden_size, heads=num_heads, concat=False)
        # Learnable weight for combining the two attention outputs.
        self.alpha = nn.Parameter(torch.tensor(0.5))
        self.layer_norm_q = nn.LayerNorm(hidden_size)
        self.layer_norm_k = nn.LayerNorm(hidden_size)
        self.layer_norm_v = nn.LayerNorm(hidden_size)
        self.layer_norm_out = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, edge_index):
        """
        Args:
            query, key, value: (batch, seq_len, hidden_size) tensors.
            edge_index: (2, num_edges) connectivity for the GAT pass; node ids
                index the flattened (batch * seq_len) dimension.

        Returns:
            (batch, seq_len, hidden_size) tensor: input plus attended output.
        """
        batch_size = query.size(0)
        seq_length = query.size(1)
        query_orig = query  # kept for the residual connection
        query = self.layer_norm_q(self.query(query))
        key = self.layer_norm_k(self.key(key))
        value = self.layer_norm_v(self.value(value))
        # Split into heads: (batch, heads, seq, head_dim).
        query = query.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = key.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = value.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Standard scaled dot-product attention.
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attended_values_std = torch.matmul(attention_weights, value).permute(0, 2, 1, 3).contiguous()
        attended_values_std = attended_values_std.view(batch_size, seq_length, self.hidden_size)
        # GAT pass over the value projections, flattened to (batch*seq, hidden).
        # (The original also built a flattened query tensor here that was never
        # used — removed as dead computation.)
        value_gat = value.permute(0, 2, 1, 3).reshape(batch_size * seq_length, self.hidden_size)
        attended_values_gat = self.gat(value_gat, edge_index).view(batch_size, seq_length, self.hidden_size)
        # Weighted combination of the two attention outputs.
        attended_values = self.alpha * attended_values_std + (1 - self.alpha) * attended_values_gat
        attended_values = self.layer_norm_out(self.out(attended_values))
        attended_values = self.dropout(attended_values)
        return query_orig + attended_values  # residual connection
class GNNPreProcessor(nn.Module):
    """Encode node features with a GCN layer and a GAT layer in parallel,
    then blend the two encodings with a learnable scalar weight."""

    def __init__(self, input_dim, hidden_dim, gat_heads=8):
        """
        Args:
            input_dim: width of the incoming node features.
            hidden_dim: width of the GCN/GAT outputs.
            gat_heads: number of GAT heads (averaged, so width stays hidden_dim).
        """
        super(GNNPreProcessor, self).__init__()
        self.gcn = GCNConv(input_dim, hidden_dim)
        self.gat = GATConv(hidden_dim, hidden_dim, heads=gat_heads, concat=False)
        # Learnable blend weight between the GCN and GAT branches.
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, edge_index):
        """Blend GCN and GAT encodings of ``x``.

        Args:
            x: (batch, seq_len, feature_dim) node features.
            edge_index: graph connectivity; reshaped here to (2, num_edges).

        Returns:
            (batch, seq_len, hidden_dim) blended features.
        """
        n_batch, n_seq, n_feat = x.size()
        flat_nodes = x.view(n_batch * n_seq, n_feat)
        edges = edge_index.view(2, -1)
        # NOTE(review): self.gat was constructed with in_channels=hidden_dim but
        # receives input_dim-wide features here — this only works when the two
        # dimensions coincide; confirm against the config.
        gcn_out = F.relu(self.gcn(flat_nodes, edges))
        gat_out = F.relu(self.gat(flat_nodes, edges))
        blended = self.alpha * gcn_out + (1 - self.alpha) * gat_out
        return blended.view(n_batch, n_seq, -1)
class DEBERTAMultiGATAttentionModel(PreTrainedModel):
    """Siamese matcher: encodes two token sequences with a shared transformer,
    enriches each with graph convolutions, runs three stages of self/cross
    GAT-attention, and emits a sigmoid score.

    NOTE(review): this forward() contains wiring oddities (flagged inline)
    that are left untouched because the published checkpoint was trained
    against them — "fixing" them would route inference through untrained
    parameters.
    """
    config_class = BERTMultiGATAttentionConfig

    def __init__(self, config):
        super(DEBERTAMultiGATAttentionModel, self).__init__(config)
        self.config = config
        # Shared backbone encoder, loaded by name from the config.
        self.transformer =AutoModel.from_pretrained(config.transformer_model)
        # One GNN pre-processor per input branch.
        self.gnn_preprocessor1 = GNNPreProcessor(config.gnn_input_dim, config.gnn_hidden_dim)
        self.gnn_preprocessor2 = GNNPreProcessor(config.gnn_input_dim, config.gnn_hidden_dim)
        # Fuses [transformer output ; GNN output] back down to hidden_size.
        self.fc_combine = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.layer_norm_combine = nn.LayerNorm(config.hidden_size)
        self.dropout_combine = nn.Dropout(config.dropout)
        # Stage 1: per-branch self-attention + cross-attention.
        self.self_attention1 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention2 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        # Stage 2.
        self.self_attention3 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention4 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention_ = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        # Stage 3.
        self.self_attention5 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.self_attention6 = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        self.cross_attention__ = MultiHeadGATAttention(config.hidden_size, config.num_heads, config.dropout)
        # Per-stage projection heads (hidden*2 -> 256).
        # NOTE(review): fc1 is defined but never called in forward() — stage 1
        # uses fc2 instead. Confirm whether stage 1 was meant to use fc1.
        self.fc1 = nn.Linear(config.hidden_size * 2, 256)
        self.fc2 = nn.Linear(config.hidden_size * 2, 256)
        self.fc3 = nn.Linear(config.hidden_size * 2, 256)
        self.layer_norm_fc1 = nn.LayerNorm(256)
        self.layer_norm_fc2 = nn.LayerNorm(256)
        # NOTE(review): layer_norm_fc3 is defined but never called in forward()
        # — stage 3 reuses layer_norm_fc1.
        self.layer_norm_fc3 = nn.LayerNorm(256)
        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)
        self.dropout3 = nn.Dropout(config.dropout)
        self.dropout4 = nn.Dropout(config.dropout)
        # Final projection and single-logit classifier head.
        self.fc_proj = nn.Linear(256, 256)
        self.layer_norm_proj = nn.LayerNorm(256)
        self.fc_final = nn.Linear(256, 1)

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2, edge_index1, edge_index2):
        """Score the pair of inputs.

        Args:
            input_ids1/attention_mask1: tokenized first sequence.
            input_ids2/attention_mask2: tokenized second sequence.
            edge_index1/edge_index2: per-branch graph connectivity; flattened
                to (2, num_edges) below.

        Returns:
            Sigmoid score tensor in (0, 1).
        """
        # [0] selects the last hidden state from the backbone output tuple.
        output1_bert = self.transformer(input_ids1, attention_mask1)[0]
        output2_bert = self.transformer(input_ids2, attention_mask2)[0]
        edge_index1 = edge_index1.view(2, -1)  # Flatten the batch dimension
        edge_index2 = edge_index2.view(2, -1)  # Flatten the batch dimension
        output1_gnn = self.gnn_preprocessor1(output1_bert, edge_index1)
        output2_gnn = self.gnn_preprocessor2(output2_bert, edge_index2)
        # Concatenate backbone and GNN features, then project back to hidden_size.
        combined_output1 = torch.cat([output1_bert, output1_gnn], dim=2)
        combined_output2 = torch.cat([output2_bert, output2_gnn], dim=2)
        combined_output1 = self.layer_norm_combine(self.fc_combine(combined_output1))
        combined_output2 = self.layer_norm_combine(self.fc_combine(combined_output2))
        combined_output1 = self.dropout_combine(F.relu(combined_output1))
        combined_output2 = self.dropout_combine(F.relu(combined_output2))
        # Stage 1: self-attention per branch, then branch-1-queries-branch-2.
        # NOTE(review): cross-attention receives edge_index1 even though its
        # keys/values come from branch 2 — confirm this is intentional.
        output1 = self.self_attention1(combined_output1, combined_output1, combined_output1, edge_index1)
        output2 = self.self_attention2(combined_output2, combined_output2, combined_output2, edge_index2)
        attended_output = self.cross_attention(output1, output2, output2, edge_index1)
        combined_output = torch.cat([output1, attended_output], dim=2)
        combined_output, _ = torch.max(combined_output, dim=1)  # max-pool over sequence
        # NOTE(review): uses fc2 here (fc1 is never used) — see __init__.
        combined_output = self.layer_norm_fc1(self.fc2(combined_output))
        combined_output = self.dropout1(F.relu(combined_output))
        combined_output = combined_output.unsqueeze(1)
        # Stage 2 — NOTE(review): the stage-1 result above is overwritten here
        # (never consumed), so stages 1 and 2 only contribute dead compute.
        output1 = self.self_attention3(combined_output1, combined_output1, combined_output1, edge_index1)
        output2 = self.self_attention4(combined_output2, combined_output2, combined_output2, edge_index2)
        attended_output = self.cross_attention_(output1, output2, output2, edge_index1)
        combined_output = torch.cat([output1, attended_output], dim=2)
        combined_output, _ = torch.max(combined_output, dim=1)
        combined_output = self.layer_norm_fc2(self.fc2(combined_output))
        combined_output = self.dropout2(F.relu(combined_output))
        combined_output = combined_output.unsqueeze(1)
        # Stage 3 — only this stage's result reaches the classifier head.
        output1 = self.self_attention5(combined_output1, combined_output1, combined_output1, edge_index1)
        output2 = self.self_attention6(combined_output2, combined_output2, combined_output2, edge_index2)
        attended_output = self.cross_attention__(output1, output2, output2, edge_index1)
        combined_output = torch.cat([output1, attended_output], dim=2)
        combined_output, _ = torch.max(combined_output, dim=1)
        # NOTE(review): pairs fc3 with layer_norm_fc1 (layer_norm_fc3 unused).
        combined_output = self.layer_norm_fc1(self.fc3(combined_output))
        combined_output = self.dropout3(F.relu(combined_output))
        combined_output = combined_output.unsqueeze(1)
        hidden_state_proj = self.layer_norm_proj(self.fc_proj(combined_output))
        hidden_state_proj = self.dropout4(hidden_state_proj)
        final = self.fc_final(hidden_state_proj)
        return torch.sigmoid(final)
# Register the custom config/model pair so AutoModel.from_pretrained can
# resolve BERTMultiGATAttentionConfig to this architecture.
AutoModel.register(BERTMultiGATAttentionConfig, DEBERTAMultiGATAttentionModel)