from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn
from transformers import AutoModel


@dataclass(frozen=True)
class MultiTaskLabelSizes:
    """Number of target classes for each classification head."""

    intent_type: int
    intent_subtype: int
    decision_phase: int


class MultiTaskIntentModel(nn.Module):
    """Shared transformer encoder with one linear classification head per task."""

    def __init__(self, base_model_name: str, label_sizes: MultiTaskLabelSizes):
        super().__init__()
        self.base_model_name = base_model_name
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden_size = int(self.encoder.config.hidden_size)
        # DistilBERT-style configs expose their classifier dropout as
        # `seq_classif_dropout`; fall back to 0.2 for encoders that lack it.
        self.dropout = nn.Dropout(float(getattr(self.encoder.config, "seq_classif_dropout", 0.2)))
        self.intent_type_head = nn.Linear(hidden_size, label_sizes.intent_type)
        self.intent_subtype_head = nn.Linear(hidden_size, label_sizes.intent_subtype)
        self.decision_phase_head = nn.Linear(hidden_size, label_sizes.decision_phase)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict[str, torch.Tensor]:
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the hidden state of the first ([CLS]) token.
        pooled = outputs.last_hidden_state[:, 0]
        pooled = self.dropout(pooled)
        return {
            "intent_type_logits": self.intent_type_head(pooled),
            "intent_subtype_logits": self.intent_subtype_head(pooled),
            "decision_phase_logits": self.decision_phase_head(pooled),
        }
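

# A minimal usage sketch, not part of the original module: the checkpoint name,
# label sizes, and example sentence below are illustrative assumptions chosen
# only to show the expected input/output shapes of the three heads.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    sizes = MultiTaskLabelSizes(intent_type=4, intent_subtype=10, decision_phase=3)
    model = MultiTaskIntentModel("distilbert-base-uncased", sizes)
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    batch = tokenizer(["can you reset my password?"], return_tensors="pt")
    with torch.no_grad():
        logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

    # Each entry is shaped (batch_size, num_labels) for its task.
    for name, tensor in logits.items():
        print(name, tuple(tensor.shape))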