from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn
from transformers import AutoModel


@dataclass(frozen=True)
class MultiTaskLabelSizes:
    """Number of classes for each classification head."""

    intent_type: int
    intent_subtype: int
    decision_phase: int


class MultiTaskIntentModel(nn.Module):
    """Shared transformer encoder with three independent linear heads."""

    def __init__(self, base_model_name: str, label_sizes: MultiTaskLabelSizes):
        super().__init__()
        self.base_model_name = base_model_name
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden_size = int(self.encoder.config.hidden_size)
        # seq_classif_dropout is a DistilBERT-specific config field;
        # fall back to 0.2 for encoders that do not define it.
        self.dropout = nn.Dropout(float(getattr(self.encoder.config, "seq_classif_dropout", 0.2)))
        self.intent_type_head = nn.Linear(hidden_size, label_sizes.intent_type)
        self.intent_subtype_head = nn.Linear(hidden_size, label_sizes.intent_subtype)
        self.decision_phase_head = nn.Linear(hidden_size, label_sizes.decision_phase)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict[str, torch.Tensor]:
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # CLS-style pooling: take the hidden state of the first token.
        pooled = outputs.last_hidden_state[:, 0]
        pooled = self.dropout(pooled)
        return {
            "intent_type_logits": self.intent_type_head(pooled),
            "intent_subtype_logits": self.intent_subtype_head(pooled),
            "decision_phase_logits": self.decision_phase_head(pooled),
        }
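

# Minimal usage sketch, not part of the original module. Assumptions:
# "distilbert-base-uncased" as the base checkpoint (any AutoModel-compatible
# encoder works), and illustrative label sizes; swap in your own values.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    sizes = MultiTaskLabelSizes(intent_type=5, intent_subtype=12, decision_phase=3)
    model = MultiTaskIntentModel("distilbert-base-uncased", sizes)
    model.eval()  # disable dropout for deterministic inference

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    batch = tokenizer(["can you reset my password?"], return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    for name, tensor in logits.items():
        print(name, tuple(tensor.shape))  # e.g. intent_type_logits (1, 5)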