# sci_classifier/src/model.py
import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput


class ArticleClassifier(nn.Module):
    """Multi-task classifier that predicts a research field and a journal
    quartile from a shared transformer encoder."""

    def __init__(
        self,
        num_fields,
        num_quartiles=4,
        dropout_prob=0.3,
        field_weights=None,
        quartile_weights=None,
        model_name="allenai/scibert_scivocab_cased",
        field_head=nn.Linear,
        quartile_head=nn.Linear,
        shared_head=nn.Identity,
        quartile_loss_weight=1.0,
        by_cls=True,
    ):
        super().__init__()
        # Pretrained transformer encoder (SciBERT by default).
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        # Shared projection applied before both task-specific heads.
        self.shared = shared_head(hidden, hidden)
        # Heads may be any callable taking (in_dim, out_dim), e.g. nn.Linear or Head.
        self.field_head = field_head(hidden, num_fields)
        self.quartile_head = quartile_head(hidden, num_quartiles)
        self.dropout = nn.Dropout(dropout_prob)
        # Optional per-class weights to counter label imbalance.
        self.criterion_field = nn.CrossEntropyLoss(weight=field_weights)
        self.criterion_quartile = nn.CrossEntropyLoss(weight=quartile_weights)
        self.quartile_loss_weight = quartile_loss_weight
        # If True, pool with the [CLS] token; otherwise mean-pool over tokens.
        self.by_cls = by_cls

    def forward(
        self,
        input_ids,
        attention_mask,
        field_label=None,
        quartile_label=None,
    ):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        if self.by_cls:
            # Use the [CLS] token representation as the sequence embedding.
            cls = self.dropout(out.last_hidden_state[:, 0, :])
            shared = self.shared(cls)
        else:
            # Mean-pool token embeddings, ignoring padding positions.
            last_hidden = out.last_hidden_state
            mask = attention_mask.unsqueeze(-1)
            pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            pooled = self.dropout(pooled)
            shared = self.shared(pooled)

        field_logits = self.field_head(shared)
        quartile_logits = self.quartile_head(shared)

        loss = None
        if field_label is not None and quartile_label is not None:
            # Joint objective: field loss plus weighted quartile loss.
            loss = (
                self.criterion_field(field_logits, field_label)
                + self.quartile_loss_weight
                * self.criterion_quartile(quartile_logits, quartile_label)
            )

        output = SequenceClassifierOutput(loss=loss, logits=field_logits)
        # Attach quartile logits so callers can read both task outputs.
        output.quartile_logits = quartile_logits
        return output


class Head(nn.Module):
    """Small MLP head, usable as field_head / quartile_head in ArticleClassifier."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, 128),
            nn.GELU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, x):
        return self.net(x)
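

# Minimal usage sketch (an assumption, not part of the uploaded file): it shows
# one way the model might be instantiated with the MLP Head above and run on a
# tokenized batch. num_fields=10 and the example text are placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_cased")
    model = ArticleClassifier(num_fields=10, field_head=Head, quartile_head=Head)
    model.eval()

    batch = tokenizer(
        ["Graph neural networks for molecular property prediction."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        out = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
    print(out.logits.shape)           # (1, num_fields)
    print(out.quartile_logits.shape)  # (1, num_quartiles)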