hyoo14 commited on
Commit
83da537
·
verified ·
1 Parent(s): 9bcaf0f

Upload modeling_mp_rna.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_mp_rna.py +23 -0
modeling_mp_rna.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn as nn
3
+ from transformers import AutoModel
4
+
5
+ class CustomMPRNAForSequenceClassification(nn.Module):
6
+ def __init__(self, base_model, num_labels):
7
+ super().__init__()
8
+ self.base_model = base_model
9
+ self.num_labels = num_labels
10
+ self.classifier = nn.Linear(base_model.config.hidden_size, num_labels)
11
+ self.dropout = nn.Dropout(0.1)
12
+
13
+ def forward(self, input_ids, attention_mask=None, labels=None):
14
+ outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
15
+ pooled_output = outputs[0][:, 0, :]
16
+ pooled_output = self.dropout(pooled_output)
17
+ logits = self.classifier(pooled_output)
18
+
19
+ loss = None
20
+ if labels is not None:
21
+ loss_fct = nn.CrossEntropyLoss()
22
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
23
+ return {"logits": logits, "loss": loss} if loss is not None else {"logits": logits}