Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel | |
| from transformers.modeling_outputs import SequenceClassifierOutput | |
class ArticleClassifier(nn.Module):
    """Pretrained encoder feeding two classification heads from one pooled vector.

    The encoder output is pooled to a single vector per sequence, passed
    through dropout and a shared head, and then through two independent heads
    that produce "field" logits and "quartile" logits.

    Args:
        num_fields: Output size of the field head.
        num_quartiles: Output size of the quartile head.
        dropout_prob: Dropout probability applied to the pooled vector.
        field_weights: Optional tensor passed as ``weight`` to the field
            ``nn.CrossEntropyLoss``.
        quartile_weights: Optional tensor passed as ``weight`` to the quartile
            ``nn.CrossEntropyLoss``.
        model_name: Identifier handed to ``AutoModel.from_pretrained``.
        field_head: Callable invoked as ``field_head(hidden, num_fields)``
            that returns the field head module.
        quartile_head: Callable invoked as
            ``quartile_head(hidden, num_quartiles)`` that returns the quartile
            head module.
        shared_head: Callable invoked as ``shared_head(hidden, hidden)`` that
            returns the module applied before both heads.
        quartile_loss_weight: Multiplier applied to the quartile loss term.
        by_cls: If true, pool by taking the hidden state at sequence
            position 0; otherwise take the attention-mask-weighted mean over
            the sequence dimension.
    """

    def __init__(
        self,
        num_fields,
        num_quartiles=4,
        dropout_prob=0.3,
        field_weights=None,
        quartile_weights=None,
        model_name="allenai/scibert_scivocab_cased",
        field_head=nn.Linear,
        quartile_head=nn.Linear,
        shared_head=nn.Identity,
        quartile_loss_weight=1.0,
        by_cls=True,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size

        # Attribute names and registration order are unchanged so that
        # existing state_dict keys keep matching.
        self.shared = shared_head(hidden, hidden)
        self.field_head = field_head(hidden, num_fields)
        self.quartile_head = quartile_head(hidden, num_quartiles)
        self.dropout = nn.Dropout(dropout_prob)

        self.criterion_field = nn.CrossEntropyLoss(weight=field_weights)
        self.criterion_quartile = nn.CrossEntropyLoss(weight=quartile_weights)
        self.quartile_loss_weight = quartile_loss_weight
        self.by_cls = by_cls

    def _pool_tokens(self, last_hidden, attention_mask):
        """Reduce per-token hidden states to one vector per sequence."""
        if self.by_cls:
            return last_hidden[:, 0, :]
        mask = attention_mask.unsqueeze(-1)
        # clamp(min=1) keeps the division finite for a row whose mask sums to 0.
        return (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def _combined_loss(self, field_logits, quartile_logits, field_label, quartile_label):
        """Sum the loss terms for whichever labels were supplied.

        Returns ``None`` when no label is given. With both labels the value is
        ``field_loss + quartile_loss_weight * quartile_loss``.
        """
        loss = None
        if field_label is not None:
            loss = self.criterion_field(field_logits, field_label)
        if quartile_label is not None:
            quartile_term = self.quartile_loss_weight * self.criterion_quartile(
                quartile_logits, quartile_label
            )
            loss = quartile_term if loss is None else loss + quartile_term
        return loss

    def forward(
        self,
        input_ids,
        attention_mask,
        field_label=None,
        quartile_label=None,
    ):
        """Encode the inputs and return logits for both tasks.

        Args:
            input_ids: Forwarded to the encoder as ``input_ids``.
            attention_mask: Forwarded to the encoder as ``attention_mask``;
                also used as the pooling weights when ``by_cls`` is false.
            field_label: Optional targets for the field head.
            quartile_label: Optional targets for the quartile head.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` are the field
            logits and whose ``loss`` is the combined loss (``None`` if no
            label was passed). The quartile logits are attached as the
            ``quartile_logits`` attribute.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._pool_tokens(encoded.last_hidden_state, attention_mask)
        shared = self.shared(self.dropout(pooled))

        field_logits = self.field_head(shared)
        quartile_logits = self.quartile_head(shared)
        loss = self._combined_loss(
            field_logits, quartile_logits, field_label, quartile_label
        )

        output = SequenceClassifierOutput(loss=loss, logits=field_logits)
        # NOTE(review): quartile_logits is set as a plain attribute, not a
        # declared field of SequenceClassifierOutput - confirm that consumers
        # read it via attribute access rather than by iterating the output.
        output.quartile_logits = quartile_logits
        return output
class Head(nn.Module):
    """Small MLP head: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hidden_dim: Width of the intermediate linear layer. Defaults to 128,
            the value that was previously hard-coded.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=128):
        super().__init__()
        # The layer order inside `net` is unchanged so that existing
        # state_dict keys (net.0, net.1, net.3) keep matching.
        self.net = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.net(x)