from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput


class ArticleClassifier(nn.Module):
    """Pretrained encoder with two classification heads on a shared representation.

    One head produces ``num_fields`` logits, the other ``num_quartiles``
    logits. Both heads read the same pooled (and optionally transformed)
    encoder output.

    Args:
        num_fields: Output size of the field head.
        num_quartiles: Output size of the quartile head.
        dropout_prob: Dropout probability applied to the pooled encoder output.
        field_weights: Optional ``weight`` passed to the field
            ``nn.CrossEntropyLoss``.
        quartile_weights: Optional ``weight`` passed to the quartile
            ``nn.CrossEntropyLoss``.
        model_name: Identifier handed to ``AutoModel.from_pretrained``.
        field_head: Factory called as ``field_head(hidden, num_fields)``.
        quartile_head: Factory called as ``quartile_head(hidden, num_quartiles)``.
        shared_head: Factory called as ``shared_head(hidden, hidden)``; its
            output feeds both task heads.
        quartile_loss_weight: Multiplier applied to the quartile loss term.
        by_cls: If true, pool by taking the hidden state at sequence index 0;
            otherwise use an attention-mask-weighted mean over the sequence.
    """

    def __init__(
        self,
        num_fields,
        num_quartiles=4,
        dropout_prob=0.3,
        field_weights=None,
        quartile_weights=None,
        model_name="allenai/scibert_scivocab_cased",
        field_head=nn.Linear,
        quartile_head=nn.Linear,
        shared_head=nn.Identity,
        quartile_loss_weight=1.0,
        by_cls=True,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size

        # Sub-modules are registered in the same order as before so that
        # parameter/state-dict ordering stays unchanged.
        self.shared = shared_head(hidden, hidden)
        self.field_head = field_head(hidden, num_fields)
        self.quartile_head = quartile_head(hidden, num_quartiles)
        self.dropout = nn.Dropout(dropout_prob)

        self.criterion_field = nn.CrossEntropyLoss(weight=field_weights)
        self.criterion_quartile = nn.CrossEntropyLoss(weight=quartile_weights)
        self.quartile_loss_weight = quartile_loss_weight
        self.by_cls = by_cls

    def _pool(self, last_hidden, attention_mask):
        """Reduce per-token hidden states to one vector per sequence."""
        if self.by_cls:
            # Index 0 along the sequence axis (presumably the [CLS] token for
            # BERT-style encoders -- confirm for other model_name values).
            return last_hidden[:, 0, :]

        # Masked mean: positions with mask 0 contribute nothing to the sum,
        # and the clamp avoids division by zero for an all-zero mask row.
        # Assumes attention_mask is (batch, seq) -- TODO confirm with callers.
        mask = attention_mask.unsqueeze(-1)
        return (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def forward(
        self,
        input_ids,
        attention_mask,
        field_label: Optional[torch.Tensor] = None,
        quartile_label: Optional[torch.Tensor] = None,
    ):
        """Run the encoder and both heads, optionally computing the loss.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder and used
                for mean pooling when ``by_cls`` is false.
            field_label: Optional targets for the field head.
            quartile_label: Optional targets for the quartile head.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` are the field
            logits and whose ``loss`` is the sum of the available loss terms
            (field loss plus ``quartile_loss_weight`` times quartile loss), or
            ``None`` when no labels are given. The quartile logits are attached
            as the ``quartile_logits`` attribute.
        """
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(self._pool(out.last_hidden_state, attention_mask))
        shared = self.shared(pooled)

        field_logits = self.field_head(shared)
        quartile_logits = self.quartile_head(shared)

        # Accumulate whichever loss terms have labels. With both labels the
        # result equals field_loss + quartile_loss_weight * quartile_loss;
        # with only one label that single term is returned instead of None.
        loss = None
        if field_label is not None:
            loss = self.criterion_field(field_logits, field_label)
        if quartile_label is not None:
            quartile_loss = self.quartile_loss_weight * self.criterion_quartile(
                quartile_logits, quartile_label
            )
            loss = quartile_loss if loss is None else loss + quartile_loss

        output = SequenceClassifierOutput(loss=loss, logits=field_logits)
        # NOTE(review): quartile_logits is not a declared field of
        # SequenceClassifierOutput, so it is set as a plain attribute here; it
        # may not appear in tuple/dict views of the output -- confirm whether
        # any caller relies on those views.
        output.quartile_logits = quartile_logits
        return output


class Head(nn.Module):
    """Small MLP head: LayerNorm -> Linear(in_dim, 128) -> GELU -> Linear(128, out_dim).

    Its ``(in_dim, out_dim)`` constructor signature matches how
    ``ArticleClassifier`` calls its ``field_head`` / ``quartile_head`` /
    ``shared_head`` factories.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, 128),
            nn.GELU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.net(x)