from typing import NamedTuple, Optional

import torch
import torch.nn as nn
from transformers import AutoModel


class TechSupportOutput(NamedTuple):
    """Logits returned by :meth:`TechSupportClassifier.forward`.

    A single module-level NamedTuple is used instead of a class created on
    every forward pass. It is one stable type, supports both attribute access
    (``out.category_logits``) and tuple unpacking/indexing, and is understood
    by PyTorch utilities that traverse (named) tuples of tensors, such as
    ``nn.DataParallel`` output gathering.
    """

    # Leading dim matches the input batch; trailing dim is num_category_labels.
    category_logits: torch.Tensor
    # Leading dim matches the input batch; trailing dim is num_urgency_labels.
    urgency_logits: torch.Tensor


class TechSupportClassifier(nn.Module):
    """Two-headed classifier over a shared pretrained transformer encoder.

    The encoder is loaded with ``AutoModel.from_pretrained``. The hidden state
    of the first token of each sequence feeds two independent linear heads:
    one for ticket category and one for urgency.

    Args:
        checkpoint: Hugging Face model id or local path passed to
            ``AutoModel.from_pretrained``.
        num_category_labels: Output size of the category head.
        num_urgency_labels: Output size of the urgency head.
    """

    def __init__(
        self,
        checkpoint="distilbert/distilbert-base-uncased",
        num_category_labels=5,
        num_urgency_labels=3,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(checkpoint)
        hidden_size = self.encoder.config.hidden_size
        self.category_classifier = nn.Linear(hidden_size, num_category_labels)
        self.urgency_classifier = nn.Linear(hidden_size, num_urgency_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> TechSupportOutput:
        """Encode the inputs and compute logits for both heads.

        Args:
            input_ids: Token ids for the encoder. Presumably shaped
                (batch, seq_len); confirm against the tokenizer/caller.
            attention_mask: Optional mask forwarded unchanged to the encoder.
                ``None`` is passed through as-is.

        Returns:
            A :class:`TechSupportOutput` with ``category_logits`` and
            ``urgency_logits``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Take the first token's final hidden state as the sequence summary.
        # For BERT-style tokenizers this is the [CLS] position; that depends on
        # the tokenizer paired with `checkpoint`. The [:, 0, :] indexing turns
        # the rank-3 hidden states into a (batch, hidden_size) tensor.
        pooled = outputs.last_hidden_state[:, 0, :]
        return TechSupportOutput(
            category_logits=self.category_classifier(pooled),
            urgency_logits=self.urgency_classifier(pooled),
        )