"""
Immutable configurations enabling to share common fields across the specific models used.
"""
from abc import ABC
from dataclasses import dataclass
from typing import ClassVar
@dataclass(frozen=True)
class BaseModelConfig(ABC):
    """
    Abstract root shared by every QA model configuration.

    Frozen so that config objects are immutable (and therefore hashable),
    making them safe to share across models and processes.
    """
@dataclass(frozen=True)
class AlwaysNoAnswerModelConfig(BaseModelConfig):
    """
    Configuration for the degenerate baseline that unconditionally predicts
    no-answer (the empty string "").
    """

    # ClassVar keeps this identifier out of the generated dataclass machinery
    # (__init__/__eq__/__repr__); it is a class-wide constant, not a field.
    MODEL_TYPE: ClassVar[str] = "always_no_answer"
@dataclass(frozen=True)
class SentenceEmbeddingModelConfig(BaseModelConfig):
    """
    Settings for the sentence-embedding baseline model.
    """

    # Declared as a ClassVar so it is shared by all instances and excluded
    # from the dataclass-generated methods (__init__, __eq__, ...).
    MODEL_TYPE: ClassVar[str] = "embedding_best_sentence"

    # TODO - consider switching to other defaults for non-Apple users
    # Device for inference; "mps" targets Apple-Silicon GPUs.
    device: str = "mps"
    # Name of the sentence-embedding model checkpoint to load.
    sentence_model_name: str = "all-MiniLM-L6-v2"
    # Score cutoff used for the no-answer decision.
    no_answer_threshold: float = 0.5
@dataclass(frozen=True)
class BertQAConfig(BaseModelConfig, ABC):
    """
    Abstract parent holding the hyper-parameter surface common to all
    BERT-family extractive QA configurations.

    The fields are declared here (and given concrete defaults by the
    subclasses) so that static type checkers such as Pylance resolve the
    attributes when accessed through a BertQAConfig reference.
    """

    # Identifier of the pretrained backbone checkpoint to load.
    backbone_name: str
    # Maximum tokenized input length fed to the model.
    max_sequence_length: int
    # Optimizer learning rate used during fine-tuning.
    learning_rate: float
    # Number of training epochs.
    num_epochs: int
    # Mini-batch size for training.
    batch_size: int
    # Mini-batch size for evaluation (can be larger: no gradients kept).
    eval_batch_size: int
    # Score cutoff used for the no-answer decision.
    no_answer_threshold: float
    # Compute device; defaults to NVIDIA GPU.
    device: str = "cuda"
@dataclass(frozen=True)
class TinyBertQAConfig(BertQAConfig):
    """
    Concrete configuration for an extractive QA system built on TinyBERT.
    """

    # Class-wide identifier; ClassVar excludes it from the dataclass fields.
    MODEL_TYPE: ClassVar[str] = "tinybert_qa"

    # General-purpose checkpoint (not QA-tuned).
    backbone_name: str = "huawei-noah/TinyBERT_General_4L_312D"
    max_sequence_length: int = 256
    learning_rate: float = 2e-5
    num_epochs: int = 5
    batch_size: int = 64
    eval_batch_size: int = 2048
    no_answer_threshold: float = 0.0
@dataclass(frozen=True)
class OriginalBertQAConfig(BertQAConfig):
    """
    Concrete configuration for an extractive QA system built on the
    original BERT base model.
    """

    # Class-wide identifier; ClassVar excludes it from the dataclass fields.
    MODEL_TYPE: ClassVar[str] = "original_bert_qa"

    # General-purpose checkpoint (not QA-tuned).
    backbone_name: str = "bert-base-uncased"
    max_sequence_length: int = 384
    learning_rate: float = 5e-5
    num_epochs: int = 2
    batch_size: int = 48
    eval_batch_size: int = 1024
    no_answer_threshold: float = 0.5
|