# toxicity/utils/model_classes.py
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention block with residual connection and layer norm."""

    def __init__(self, embed_dim, num_heads, dropout=0.3):
        super(MultiHeadSelfAttention, self).__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention: query, key, and value are all the input sequence.
        attn_output, _ = self.attention(x, x, x)
        # Residual connection followed by layer normalization (post-norm).
        x = self.layer_norm(x + self.dropout(attn_output))
        return x
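
# A minimal shape check for the block above (illustrative only; the 256-dim,
# 8-head configuration is an assumption matching the defaults MHSA_GRU uses
# below, not something fixed by this file):
#
#   mhsa = MultiHeadSelfAttention(embed_dim=256, num_heads=8)
#   out = mhsa(torch.randn(4, 10, 256))  # (batch=4, seq_len=10, embed_dim=256)
#   assert out.shape == (4, 10, 256)     # residual + layer norm preserve the shape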


class MHSA_GRU(nn.Module):
    """Multi-head self-attention combined with a GRU for toxicity scoring."""

    def __init__(self, input_dim, hidden_dim=256, num_heads=8, num_gru_layers=2, dropout=0.3):
        super(MHSA_GRU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Project raw input features to the shared hidden dimension.
        self.input_projection = nn.Linear(input_dim, hidden_dim)

        # Two attention blocks before the GRU, one after.
        self.mhsa1 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.mhsa2 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_gru_layers,
            batch_first=True,
            dropout=dropout if num_gru_layers > 1 else 0,
            bidirectional=False,
        )
        self.mhsa3 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)

        # Classification head: two hidden layers with batch norm, then a
        # single logit squashed to a probability by sigmoid.
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
        self.fc3 = nn.Linear(hidden_dim // 4, 1)
        self.bn1 = nn.BatchNorm1d(hidden_dim // 2)
        self.bn2 = nn.BatchNorm1d(hidden_dim // 4)

    def forward(self, x):
        # x: (batch_size, input_dim) feature vectors.
        x = self.input_projection(x)
        # Add a singleton sequence dimension so the attention blocks and the
        # GRU see a (batch, seq_len=1, hidden_dim) tensor.
        x = x.unsqueeze(1)

        x = self.mhsa1(x)
        x = self.mhsa2(x)
        gru_out, _ = self.gru(x)
        x = self.mhsa3(gru_out)

        # Take the last position of the sequence dimension.
        x = x[:, -1, :]

        x = self.dropout(x)
        x = torch.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = torch.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        return torch.sigmoid(x)
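

if __name__ == "__main__":
    # Quick smoke test (illustrative sketch, not part of the original API code):
    # input_dim=768 is an assumption, e.g. the size of a sentence-embedding
    # feature vector; the real value depends on the feature extractor the
    # toxicity API feeds into this model.
    model = MHSA_GRU(input_dim=768)
    model.eval()  # use running batch-norm statistics and disable dropout
    with torch.no_grad():
        dummy = torch.randn(4, 768)   # batch of 4 feature vectors
        probs = model(dummy)          # shape (4, 1): toxicity probabilities in [0, 1]
    print(probs.squeeze(-1).tolist())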