# medico2025 / qtype.py
# Last commit: "restructured files" (62305fe) -- alvikhan
import torch
from torch import nn
from transformers import DistilBertModel
class QuestionTypeClassifier(nn.Module):
    """Classify a tokenized question into one of ``num_types`` question types.

    A DistilBERT encoder produces one embedding per input token; the embedding
    at sequence position 0 is fed through a single linear layer that emits one
    unnormalized score (logit) per question type. No softmax is applied here.

    Args:
        num_types: Number of question-type classes, i.e. the width of the
            output layer. Must be at least 1.
        model_name: Checkpoint name or local path handed to
            ``DistilBertModel.from_pretrained``. Keyword-only; defaults to
            ``"distilbert-base-uncased"``.

    Raises:
        ValueError: If ``num_types`` is smaller than 1.
    """

    def __init__(
        self,
        num_types: int,
        *,
        model_name: str = "distilbert-base-uncased",
    ) -> None:
        super().__init__()

        # Fail fast: reject a meaningless head size before the (slow, possibly
        # network-bound) pretrained-weight load below.
        if num_types < 1:
            raise ValueError(
                f"num_types must be a positive integer, got {num_types!r}"
            )

        # NOTE: the attribute names `distilbert` and `fc` become state_dict
        # key prefixes, so they are kept stable for existing checkpoints.
        self.distilbert = DistilBertModel.from_pretrained(model_name)

        # Classification head: encoder hidden size -> one logit per type.
        self.fc = nn.Linear(self.distilbert.config.hidden_size, num_types)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return the question-type logits for a batch of token sequences.

        Args:
            input_ids: Token ids for the encoder
                (presumably shaped [B, seq_len] -- confirm against caller).
            attention_mask: Mask forwarded unchanged to the encoder
                (presumably the same shape as ``input_ids``).

        Returns:
            Tensor of shape [B, num_types] holding raw logits.
        """
        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Use the embedding at position 0 as the sequence summary; this is the
        # [CLS] token when the inputs come from the DistilBERT tokenizer.
        cls_token = outputs.last_hidden_state[:, 0, :]  # [B, hidden]

        return self.fc(cls_token)  # [B, num_types]