import torch.nn as nn
from types import SimpleNamespace

from transformers import AutoModel


class MultiTaskModel(nn.Module):
    """Shared encoder with two linear classification heads: category and urgency."""

    def __init__(self, encoder_name, num_category_labels, num_urgency_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden_size = self.encoder.config.hidden_size
        # Changed from category_head to category_classifier
        self.category_classifier = nn.Linear(hidden_size, num_category_labels)
        # Changed from urgency_head to urgency_classifier
        self.urgency_classifier = nn.Linear(hidden_size, num_urgency_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the final hidden state of the first ([CLS]) token.
        pooled = outputs.last_hidden_state[:, 0]
        # Return both heads' logits as attributes on a simple namespace object.
        return SimpleNamespace(
            category_logits=self.category_classifier(pooled),
            urgency_logits=self.urgency_classifier(pooled),
        )
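

# --- Usage sketch (illustrative only) ---
# A minimal example of running a forward pass through the model above. The encoder
# name "bert-base-uncased", the label counts, and the sample texts are assumptions
# for illustration, not values from the original text; any BERT-style checkpoint
# whose first token acts as a [CLS] summary works the same way.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = MultiTaskModel(
    "bert-base-uncased", num_category_labels=5, num_urgency_labels=3
)

batch = tokenizer(
    ["My package never arrived", "How do I change my email address?"],
    padding=True,
    truncation=True,
    return_tensors="pt",
)
output = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(output.category_logits.shape)  # torch.Size([2, 5])
print(output.urgency_logits.shape)   # torch.Size([2, 3])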