# tech-support-classifier / modeling_tech_support.py
# Uploaded by Sandei — commit cf8c315 (verified): "Upload trained MultiTaskModel with custom class"
# Size as uploaded: 936 bytes
from typing import NamedTuple, Optional

import torch
import torch.nn as nn
from transformers import AutoModel
class TechSupportClassifierOutput(NamedTuple):
    """Logits returned by ``TechSupportClassifier.forward``.

    A NamedTuple keeps the original attribute access (``out.category_logits``,
    ``out.urgency_logits``). It also supports tuple unpacking, equality and a
    readable repr, and it is defined once instead of being rebuilt as a fresh
    class on every forward pass.
    """

    category_logits: torch.Tensor  # (batch, num_category_labels), unnormalized
    urgency_logits: torch.Tensor  # (batch, num_urgency_labels), unnormalized


class TechSupportClassifier(nn.Module):
    """Multi-task classifier: one shared transformer encoder, two linear heads.

    The encoder's hidden state at the first token position is fed to two
    independent ``nn.Linear`` heads. One head scores category labels and the
    other scores urgency labels.

    Args:
        checkpoint: Model id or local path passed to
            ``transformers.AutoModel.from_pretrained`` to build the encoder.
        num_category_labels: Output size of the category head.
        num_urgency_labels: Output size of the urgency head.
    """

    def __init__(
        self,
        checkpoint: str = "distilbert/distilbert-base-uncased",
        num_category_labels: int = 5,
        num_urgency_labels: int = 3,
    ) -> None:
        super().__init__()
        # NOTE: the submodule attribute names below determine the state_dict
        # keys; renaming them breaks loading previously saved weights.
        self.encoder = AutoModel.from_pretrained(checkpoint)
        hidden_size = self.encoder.config.hidden_size
        self.category_classifier = nn.Linear(hidden_size, num_category_labels)
        self.urgency_classifier = nn.Linear(hidden_size, num_urgency_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **encoder_kwargs,
    ) -> TechSupportClassifierOutput:
        """Encode a batch and return category and urgency logits.

        Args:
            input_ids: Token ids; presumably shaped (batch, seq_len).
            attention_mask: Optional mask passed through to the encoder as-is
                (``None`` when omitted).
            **encoder_kwargs: Extra keyword arguments forwarded unchanged to
                the encoder (e.g. ``token_type_ids`` for checkpoints that
                accept it), so ``model(**tokenizer_output)`` works.

        Returns:
            TechSupportClassifierOutput holding unnormalized logits for each
            head; no softmax is applied.
        """
        outputs = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, **encoder_kwargs
        )
        # First-token pooling over (batch, seq_len, hidden): position 0 is the
        # [CLS] token for BERT-style tokenizers.
        pooled = outputs.last_hidden_state[:, 0, :]
        return TechSupportClassifierOutput(
            category_logits=self.category_classifier(pooled),
            urgency_logits=self.urgency_classifier(pooled),
        )