# seqxgpt-detector / model.py
# Uploaded by zcahjl3 with huggingface_hub (commit 0a7c540, verified).
import torch
import torch.nn as nn
from typing import List, Tuple
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers.models.bert import BertModel
from fastNLP.modules.torch import MLP,ConditionalRandomField,allowed_transitions
from torch.nn import CrossEntropyLoss
class ConvFeatureExtractionModel(nn.Module):
    """Sequential stack of 1-D convolution blocks.

    Args:
        conv_layers: one ``(out_channels, kernel_size, stride)`` triple per block.
            The first block reads a single input channel; each later block reads
            the previous block's ``out_channels``.
        conv_dropout: dropout probability applied after every convolution.
        conv_bias: whether the ``Conv1d`` layers carry a bias term.

    Each block is ``Conv1d -> Dropout -> ReLU`` with ``padding = kernel_size // 2``.
    At stride 1, an odd kernel keeps the sequence length unchanged and an even
    kernel adds one position.

    Raises:
        AssertionError: if an entry of ``conv_layers`` does not have exactly
            three elements.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        conv_dropout: float = 0.0,
        conv_bias: bool = False,
    ):
        super().__init__()
        self.conv_layers = nn.ModuleList()
        channels_in = 1
        for spec in conv_layers:
            assert len(spec) == 3, "invalid conv definition: " + str(spec)
            channels_out, kernel_size, stride = spec
            self.conv_layers.append(
                self._make_block(
                    channels_in, channels_out, kernel_size, stride, conv_dropout, conv_bias
                )
            )
            channels_in = channels_out

    @staticmethod
    def _make_block(n_in, n_out, kernel_size, stride, dropout, bias):
        """Build one ``Conv1d -> Dropout -> ReLU`` block (submodule indices 0, 1, 2)."""
        conv = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=bias,
        )
        return nn.Sequential(conv, nn.Dropout(dropout), nn.ReLU())

    def forward(self, x):
        """Apply every conv block in order and return the final activation.

        ``x`` is passed straight to the first ``Conv1d`` (``in_channels=1``); no
        channel axis is added here. It presumably has shape
        ``(batch, 1, length)``. TODO: confirm against the callers.
        """
        out = x
        for block in self.conv_layers:
            out = block(out)
        return out
class ModelWiseCNNClassifier(nn.Module):
    """Per-token classifier built from one shared CNN applied to four feature channels.

    After ``x.transpose(1, 2)``, the first four rows along axis 1 are each fed,
    as a single-channel signal, through the same ``self.conv`` stack. The four
    64-dim outputs are concatenated into a 256-dim feature per position, then
    layer-normed, passed through dropout and classified linearly.

    Args:
        id2labels: label mapping; its length gives the number of tags and it is
            passed to ``allowed_transitions`` for the CRF.
        dropout_rate: dropout applied before the linear classifier.

    ``forward(x, labels)`` returns ``{'loss', 'logits'}`` in training mode. The
    loss is token-level cross-entropy that ignores label -1. In eval mode it
    returns ``{'preds', 'logits'}``, where ``preds`` is the CRF Viterbi path
    with positions whose label is -1 set to -1.
    """

    def __init__(self, id2labels, dropout_rate=0.1):
        super().__init__()
        conv_spec = [(64, 5, 1)] + [(128, 3, 1)] * 3 + [(64, 3, 1)]
        self.conv = ConvFeatureExtractionModel(
            conv_layers=conv_spec,
            conv_dropout=0.0,
            conv_bias=False,
        )
        # Four channels, each mapped to 64 dims by the last conv layer above.
        hidden_size = 4 * 64
        self.norm = nn.LayerNorm(hidden_size)
        self.label_num = len(id2labels)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(hidden_size, self.label_num))
        self.crf = ConditionalRandomField(
            num_tags=self.label_num,
            allowed_transitions=allowed_transitions(id2labels),
        )
        # Start decoding from an all-zero transition matrix.
        self.crf.trans_m.data *= 0

    def conv_feat_extract(self, x):
        """Run ``self.conv`` on ``x`` and move the channel axis last."""
        return self.conv(x).transpose(1, 2)

    def forward(self, x, labels):
        # NOTE(review): x is presumably (batch, seq_len, 4); labels (batch, seq_len)
        # with -1 marking padding. TODO: confirm against the data loader.
        channels_first = x.transpose(1, 2)
        per_channel = [
            self.conv_feat_extract(channels_first[:, c:c + 1, :]) for c in range(4)
        ]
        features = self.norm(torch.cat(per_channel, dim=2))
        logits = self.classifier(self.dropout(features))

        if not self.training:
            mask = labels.gt(-1)
            paths, _ = self.crf.viterbi_decode(logits=logits, mask=mask)
            paths[mask == 0] = -1
            return {'preds': paths, 'logits': logits}

        criterion = CrossEntropyLoss(ignore_index=-1)
        loss = criterion(logits.view(-1, self.label_num), labels.view(-1))
        return {'loss': loss, 'logits': logits}
class ModelWiseTransformerClassifier(nn.Module):
    """Per-token classifier: a shared CNN over four feature channels plus a Transformer encoder.

    After ``x.transpose(1, 2)``, the first four rows along axis 1 are each fed,
    as a single-channel signal, through the same ``self.conv`` stack. The four
    64-dim outputs are concatenated into a 256-dim token embedding. A fixed
    sinusoidal table is added, the result is layer-normed, encoded by
    ``self.encoder`` (padding positions masked out) and classified linearly.

    Args:
        id2labels: label mapping; its length gives the number of tags and it is
            passed to ``allowed_transitions`` for the CRF.
        seq_len: maximum sequence length. The position table has this many rows,
            and inputs may be at most this long.
        intermediate_size: feed-forward width of each encoder layer.
        num_layers: number of encoder layers.
        dropout_rate: dropout used inside the encoder and before the classifier.

    ``forward(x, labels)`` returns ``{'loss', 'logits'}`` in training mode. The
    loss is token-level cross-entropy that ignores label -1. In eval mode it
    returns ``{'preds', 'logits'}``, where ``preds`` is the CRF Viterbi path
    with positions whose label is -1 set to -1.
    """

    def __init__(self, id2labels, seq_len, intermediate_size=512, num_layers=2, dropout_rate=0.1):
        super(ModelWiseTransformerClassifier, self).__init__()
        feature_enc_layers = [(64, 5, 1)] + [(128, 3, 1)] * 3 + [(64, 3, 1)]
        self.conv = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            conv_dropout=0.0,
            conv_bias=False,
        )
        self.seq_len = seq_len  # maximum supported sequence length
        # Four channels, each mapped to 64 dims by the last conv layer above.
        embedding_size = 4 * 64
        self.encoder_layer = TransformerEncoderLayer(
            d_model=embedding_size,
            nhead=16,
            dim_feedforward=intermediate_size,
            dropout=dropout_rate,
            batch_first=True)
        self.encoder = TransformerEncoder(encoder_layer=self.encoder_layer,
                                          num_layers=num_layers)
        # Non-persistent buffer: it follows model.to()/.cuda()/.half() but is not
        # written to state_dict, so existing checkpoints still load with strict=True.
        self.register_buffer(
            "position_encoding",
            self._sinusoidal_position_encoding(seq_len, embedding_size),
            persistent=False,
        )
        self.norm = nn.LayerNorm(embedding_size)
        self.label_num = len(id2labels)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(embedding_size, self.label_num))
        self.crf = ConditionalRandomField(num_tags=self.label_num, allowed_transitions=allowed_transitions(id2labels))
        # Start decoding from an all-zero transition matrix.
        self.crf.trans_m.data *= 0

    @staticmethod
    def _sinusoidal_position_encoding(seq_len, embedding_size):
        """Return the fixed ``(seq_len, embedding_size)`` position table.

        Column ``j`` uses the angle ``pos / 10000 ** (2 * j / embedding_size)``.
        Even columns take the sine of that angle and odd columns the cosine.

        NOTE(review): this is NOT the Vaswani et al. formula, which shares one
        frequency per sin/cos pair via the exponent ``2 * (j // 2) / d``. It is
        kept as-is because weights trained with this module depend on these
        exact values. Do not "fix" it without retraining.
        """
        dtype = torch.get_default_dtype()
        positions = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)
        columns = torch.arange(embedding_size, dtype=torch.float64)
        # Form angles in float64, then round to the default dtype before sin/cos.
        angles = (positions / torch.pow(10000.0, 2.0 * columns / embedding_size)).to(dtype)
        table = torch.empty(seq_len, embedding_size, dtype=dtype)
        table[:, 0::2] = torch.sin(angles[:, 0::2])
        table[:, 1::2] = torch.cos(angles[:, 1::2])
        return table

    def conv_feat_extract(self, x):
        """Run ``self.conv`` on ``x`` and move the channel axis last."""
        out = self.conv(x)
        out = out.transpose(1, 2)
        return out

    def forward(self, x, labels):
        # NOTE(review): x is presumably (batch, length, 4); labels (batch, length)
        # with -1 marking padding. TODO: confirm against the data loader.
        # Label -1 marks padding: excluded from attention, loss and decoding.
        mask = labels.gt(-1)
        padding_mask = ~mask
        channels_first = x.transpose(1, 2)
        # Shared CNN on each of the four channels, concatenated -> (batch, length, 256).
        out = torch.cat(
            [self.conv_feat_extract(channels_first[:, c:c + 1, :]) for c in range(4)],
            dim=2,
        )
        # Slice the table so sequences shorter than self.seq_len are accepted too.
        # Longer sequences still fail to broadcast.
        position_encoding = self.position_encoding[:out.size(1)].to(out.device)
        outputs = self.norm(out + position_encoding)
        outputs = self.encoder(outputs, src_key_padding_mask=padding_mask)
        logits = self.classifier(self.dropout(outputs))
        if self.training:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(logits.view(-1, self.label_num), labels.view(-1))
            return {'loss': loss, 'logits': logits}
        paths, _ = self.crf.viterbi_decode(logits=logits, mask=mask)
        paths[mask == 0] = -1
        return {'preds': paths, 'logits': logits}
class TransformerOnlyClassifier(nn.Module):
    """Per-token classifier: a fixed sinusoidal position table plus a Transformer encoder.

    The inputs are added to the position table, layer-normed, encoded by
    ``self.encoder`` (padding positions masked out) and classified linearly.

    Args:
        id2labels: label mapping; its length gives the number of tags and it is
            passed to ``allowed_transitions`` for the CRF.
        seq_len: maximum sequence length. The position table has this many rows,
            and inputs may be at most this long.
        embedding_size: feature size of each input position (``d_model``).
        num_heads: attention heads per encoder layer.
        intermediate_size: feed-forward width of each encoder layer.
        num_layers: number of encoder layers.
        dropout_rate: dropout used inside the encoder and before the classifier.

    ``forward(inputs, labels)`` returns ``{'loss', 'logits'}`` in training mode.
    The loss is token-level cross-entropy that ignores label -1. In eval mode it
    returns ``{'preds', 'logits'}``, where ``preds`` is the CRF Viterbi path
    with positions whose label is -1 set to -1.
    """

    def __init__(self, id2labels, seq_len, embedding_size=4, num_heads=2, intermediate_size=64, num_layers=2, dropout_rate=0.1):
        super(TransformerOnlyClassifier, self).__init__()
        self.encoder_layer = TransformerEncoderLayer(
            d_model=embedding_size,
            nhead=num_heads,
            dim_feedforward=intermediate_size,
            dropout=dropout_rate,
            batch_first=True)
        self.encoder = TransformerEncoder(encoder_layer=self.encoder_layer,
                                          num_layers=num_layers)
        # Non-persistent buffer: it follows model.to()/.cuda()/.half() but is not
        # written to state_dict, so existing checkpoints still load with strict=True.
        self.register_buffer(
            "position_encoding",
            self._sinusoidal_position_encoding(seq_len, embedding_size),
            persistent=False,
        )
        self.norm = nn.LayerNorm(embedding_size)
        self.label_num = len(id2labels)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(embedding_size, self.label_num))
        self.crf = ConditionalRandomField(num_tags=self.label_num, allowed_transitions=allowed_transitions(id2labels))
        # Start decoding from an all-zero transition matrix.
        self.crf.trans_m.data *= 0

    @staticmethod
    def _sinusoidal_position_encoding(seq_len, embedding_size):
        """Return the fixed ``(seq_len, embedding_size)`` position table.

        Column ``j`` uses the angle ``pos / 10000 ** (2 * j / embedding_size)``.
        Even columns take the sine of that angle and odd columns the cosine.
        An odd ``embedding_size`` is allowed; its last column is a sine.

        NOTE(review): this is NOT the Vaswani et al. formula, which shares one
        frequency per sin/cos pair via the exponent ``2 * (j // 2) / d``. It is
        kept as-is because weights trained with this module depend on these
        exact values. Do not "fix" it without retraining.
        """
        dtype = torch.get_default_dtype()
        positions = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)
        columns = torch.arange(embedding_size, dtype=torch.float64)
        # Form angles in float64, then round to the default dtype before sin/cos.
        angles = (positions / torch.pow(10000.0, 2.0 * columns / embedding_size)).to(dtype)
        table = torch.empty(seq_len, embedding_size, dtype=dtype)
        table[:, 0::2] = torch.sin(angles[:, 0::2])
        table[:, 1::2] = torch.cos(angles[:, 1::2])
        return table

    def forward(self, inputs, labels):
        # NOTE(review): inputs presumably (batch, length, embedding_size); labels
        # (batch, length) with -1 marking padding. TODO: confirm against the caller.
        # Label -1 marks padding: excluded from attention, loss and decoding.
        mask = labels.gt(-1)
        padding_mask = ~mask
        # Slice the table so sequences shorter than seq_len are accepted too.
        # Longer sequences still fail to broadcast.
        position_encoding = self.position_encoding[:inputs.size(1)].to(inputs.device)
        outputs = self.norm(inputs + position_encoding)
        outputs = self.encoder(outputs, src_key_padding_mask=padding_mask)
        logits = self.classifier(self.dropout(outputs))
        if self.training:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(logits.view(-1, self.label_num), labels.view(-1))
            return {'loss': loss, 'logits': logits}
        paths, _ = self.crf.viterbi_decode(logits=logits, mask=mask)
        paths[mask == 0] = -1
        return {'preds': paths, 'logits': logits}