File size: 936 Bytes
cf8c315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

from typing import NamedTuple, Optional

import torch
import torch.nn as nn
from transformers import AutoModel

class TechSupportClassifierOutput(NamedTuple):
    """Per-head logits produced by :class:`TechSupportClassifier`.

    A module-level ``NamedTuple`` keeps attribute access
    (``out.category_logits``) while also supporting tuple unpacking,
    pickling, and the tuple-aware scatter/gather of ``nn.DataParallel``.
    The earlier per-call ``type(...)`` object supported none of these.
    """

    category_logits: torch.Tensor  # shape (batch, num_category_labels)
    urgency_logits: torch.Tensor  # shape (batch, num_urgency_labels)


class TechSupportClassifier(nn.Module):
    """Multi-task classifier: one pretrained encoder shared by two linear heads.

    The final hidden state of the first token of each sequence is fed to a
    category head and to an urgency head.

    Args:
        checkpoint: Model id or local path passed to
            ``AutoModel.from_pretrained``.
        num_category_labels: Output size of the category head.
        num_urgency_labels: Output size of the urgency head.
    """

    def __init__(
        self,
        checkpoint: str = "distilbert/distilbert-base-uncased",
        num_category_labels: int = 5,
        num_urgency_labels: int = 3,
    ) -> None:
        super().__init__()
        self.encoder = AutoModel.from_pretrained(checkpoint)
        # Size both heads from the encoder's config so any checkpoint works.
        hidden_size = self.encoder.config.hidden_size
        self.category_classifier = nn.Linear(hidden_size, num_category_labels)
        self.urgency_classifier = nn.Linear(hidden_size, num_urgency_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> TechSupportClassifierOutput:
        """Encode a batch and return logits for both heads.

        Args:
            input_ids: Token ids for the batch.
            attention_mask: Padding mask forwarded unchanged to the encoder;
                ``None`` defers to the encoder's own default handling.

        Returns:
            A :class:`TechSupportClassifierOutput` holding ``category_logits``
            and ``urgency_logits``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # The first token's final hidden state summarises the sequence
        # (typically [CLS] for BERT-style tokenizers).
        pooled = outputs.last_hidden_state[:, 0, :]
        return TechSupportClassifierOutput(
            category_logits=self.category_classifier(pooled),
            urgency_logits=self.urgency_classifier(pooled),
        )