File size: 2,494 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
Immutable configuration objects used to share common fields across the specific QA models.
"""

from abc import ABC
from dataclasses import dataclass
from typing import ClassVar


# NOTE: frozen=True makes every config immutable (and hashable), so instances can
# be shared safely across the code base. ABC signals that this base is not meant
# to be instantiated directly; no abstract methods are declared, so that intent
# is enforced by convention only (ABC alone does not block instantiation).
@dataclass(frozen=True)
class BaseModelConfig(ABC):
    """
    Container storing configurations useful across all QA models.

    Currently declares no fields of its own; it exists to give every model
    config a common (frozen) base type that sub-classes extend with fields.
    """


@dataclass(frozen=True)
class AlwaysNoAnswerModelConfig(BaseModelConfig):
    """
    Configuration for the trivial baseline whose prediction is always the
    empty string, i.e. "no answer". Carries no per-instance fields.
    """

    # ClassVar keeps MODEL_TYPE out of the generated __init__/__eq__/__repr__;
    # it identifies the model variant rather than configuring one instance.
    MODEL_TYPE: ClassVar[str] = "always_no_answer"


@dataclass(frozen=True)
class SentenceEmbeddingModelConfig(BaseModelConfig):
    """
    Configuration for the simpler sentence-embedding baseline model.
    """

    # Common to every instance of this config, so declared as a ClassVar
    # rather than a dataclass field (excluded from __init__/__eq__/etc.).
    MODEL_TYPE: ClassVar[str] = "embedding_best_sentence"
    # TODO - consider switching to other defaults for non-Apple users
    device: str = "mps"  # Apple-Silicon GPU backend by default
    sentence_model_name: str = "all-MiniLM-L6-v2"  # embedding checkpoint name
    # NOTE(review): presumably a similarity cut-off below which the model
    # predicts no-answer — confirm against the model implementation.
    no_answer_threshold: float = 0.5


@dataclass(frozen=True)
class BertQAConfig(BaseModelConfig, ABC):
    """
    Shared super-class config to be sub-classed by BERT model variants.

    Sub-classes re-declare these fields with concrete defaults; the dataclass
    machinery keeps this declaration order in the generated __init__, so the
    order below is part of the contract — do not reorder.
    """

    # Specifying fields to be materialized by sub-classes to avoid Pylance complaints
    backbone_name: str  # checkpoint identifier, e.g. "bert-base-uncased"
    max_sequence_length: int  # maximum input length in tokens
    learning_rate: float
    num_epochs: int
    batch_size: int  # training batch size
    eval_batch_size: int  # batch size used during evaluation
    # NOTE(review): presumably a score cut-off for predicting no-answer ("") —
    # confirm against the model implementation.
    no_answer_threshold: float
    device: str = "cuda"  # default compute device; overridable per instance


@dataclass(frozen=True)
class TinyBertQAConfig(BertQAConfig):
    """
    Hyper-parameters for the TinyBERT-based extractive QA system.
    """

    # Variant tag shared by all instances; ClassVar keeps it out of __init__.
    MODEL_TYPE: ClassVar[str] = "tinybert_qa"
    # General-purpose checkpoint (not QA-tuned)
    backbone_name: str = "huawei-noah/TinyBERT_General_4L_312D"
    max_sequence_length: int = 256
    learning_rate: float = 2e-5
    num_epochs: int = 5
    batch_size: int = 64
    eval_batch_size: int = 2048
    no_answer_threshold: float = 0.0


@dataclass(frozen=True)
class OriginalBertQAConfig(BertQAConfig):
    """
    Hyper-parameters for the extractive QA system built on the original
    BERT-base model.
    """

    # Variant tag shared by all instances; ClassVar keeps it out of __init__.
    MODEL_TYPE: ClassVar[str] = "original_bert_qa"
    # General-purpose checkpoint (not QA-tuned)
    backbone_name: str = "bert-base-uncased"
    max_sequence_length: int = 384
    learning_rate: float = 5e-5
    num_epochs: int = 2
    batch_size: int = 48
    eval_batch_size: int = 1024
    no_answer_threshold: float = 0.5