File size: 1,812 Bytes
391f233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
Joint Intent + NER model class — inference-only slice.

Original training script (with dataset, eval loop, plotting) lives in the
training repo. This file only carries the nn.Module so the deployed pipeline
can load the checkpoint without pulling in matplotlib/seqeval/sklearn/tqdm.
"""

import torch.nn as nn
from transformers import AutoModel


class JointIntentNERModel(nn.Module):
    """Joint intent-classification and NER model on one shared encoder.

    A pretrained transformer encoder (BanglaBERT in the deployed pipeline)
    is loaded with ``AutoModel`` and produces one hidden vector per token.
    Two independent MLP heads consume those vectors:

    * the intent head reads the hidden state at sequence position 0
      (the [CLS] token) and emits one logit vector per sequence;
    * the NER head is applied to every position and emits one logit
      vector per token.

    Attributes:
        encoder: The pretrained encoder returned by
            ``AutoModel.from_pretrained``.
        intent_classifier: Sequence-level classification head.
        ner_classifier: Token-level classification head.
        num_intents: Size of the intent head's output.
        num_ner_labels: Size of the NER head's output.
    """

    def __init__(self, model_name: str, num_intents: int, num_ner_labels: int,
                 dropout: float = 0.1) -> None:
        """Load the encoder and build both classification heads.

        Args:
            model_name: Model identifier or path handed to
                ``AutoModel.from_pretrained``.
            num_intents: Number of intent classes.
            num_ner_labels: Number of NER tag classes.
            dropout: Dropout probability applied before each linear layer
                of both heads.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size

        # Both heads share one architecture; only the output width differs.
        self.intent_classifier = self._build_head(
            hidden_size, num_intents, dropout)
        self.ner_classifier = self._build_head(
            hidden_size, num_ner_labels, dropout)

        self.num_intents = num_intents
        self.num_ner_labels = num_ner_labels

    @staticmethod
    def _build_head(hidden_size: int, num_labels: int,
                    dropout: float) -> nn.Sequential:
        """Return a two-layer MLP head: hidden -> hidden // 2 -> num_labels."""
        # Layer order must stay fixed: nn.Sequential names its children by
        # index, so reordering would change the checkpoint's state_dict keys.
        return nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_labels),
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run the encoder once and apply both heads.

        Args:
            input_ids: Token-id tensor passed to the encoder as ``input_ids``.
            attention_mask: Mask tensor passed to the encoder as
                ``attention_mask``.
            token_type_ids: Optional segment-id tensor. It is forwarded to
                the encoder only when given, so encoders whose ``forward``
                has no such parameter (e.g. DistilBERT) still work.

        Returns:
            A tuple ``(intent_logits, ner_logits)``. ``intent_logits`` has
            one row per sequence with ``num_intents`` columns;
            ``ner_logits`` has one ``num_ner_labels``-wide vector per token
            position.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        # Passing token_type_ids=None explicitly raises TypeError on
        # architectures that do not declare the parameter; omit it instead.
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids

        hidden_states = self.encoder(**encoder_kwargs).last_hidden_state

        # Position 0 of every sequence is the [CLS] summary token.
        intent_logits = self.intent_classifier(hidden_states[:, 0, :])
        ner_logits = self.ner_classifier(hidden_states)
        return intent_logits, ner_logits