File size: 836 Bytes
62305fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torch import nn
from transformers import DistilBertModel

class QuestionTypeClassifier(nn.Module):
    """Classify questions into ``num_types`` categories using a DistilBERT encoder.

    The encoder's hidden state at sequence position 0 is fed through a single
    linear layer to produce one unnormalized logit per question type.
    """

    def __init__(
        self,
        num_types: int,
        pretrained_model_name_or_path: str = "distilbert-base-uncased",
    ):
        """Build the encoder and the classification head.

        Args:
            num_types: Number of output classes (size of the logit vector).
            pretrained_model_name_or_path: Checkpoint name or local path passed
                to ``DistilBertModel.from_pretrained``. The default keeps the
                original hard-coded checkpoint.
        """
        super().__init__()

        # Load pre-trained DistilBERT weights.
        self.distilbert = DistilBertModel.from_pretrained(pretrained_model_name_or_path)

        # Classification head: encoder hidden size -> one logit per type.
        self.fc = nn.Linear(self.distilbert.config.hidden_size, num_types)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Return classification logits for a batch of token sequences.

        Args:
            input_ids: Token ids for the encoder; presumably shaped
                (batch, seq_len) as DistilBertModel expects — confirm at callers.
            attention_mask: Optional mask forwarded unchanged to the encoder.
                ``None`` lets DistilBertModel treat every position as real.

        Returns:
            Unnormalized logits of shape (batch, num_types).
        """
        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Use the first token's embedding as the sequence summary. This is
        # [CLS] only if the tokenizer prepended special tokens — TODO confirm.
        cls_token = outputs.last_hidden_state[:, 0, :]  # [B, hidden]

        logits = self.fc(cls_token)  # [B, num_types]
        return logits