# squad2-qa/src/config/model_configs.py
# Author: Kimis Perros
# Initial deployment (commit 461f64f)
"""
Immutable configurations enabling to share common fields across the specific models used.
"""
from abc import ABC
from dataclasses import dataclass
from typing import ClassVar
@dataclass(frozen=True)
class BaseModelConfig(ABC):
    """
    Container storing configurations useful across all QA models.

    Frozen so that config objects are immutable (and hashable) once built.
    """

    # Every concrete model config defines MODEL_TYPE as a class-level string
    # identifier. Declared here (as an annotation only, no value) so static
    # type checkers see it on the shared base — mirroring how BertQAConfig
    # declares its sub-class fields. ClassVar keeps it out of the generated
    # __init__/__eq__ of the dataclass machinery.
    MODEL_TYPE: ClassVar[str]
@dataclass(frozen=True)
class AlwaysNoAnswerModelConfig(BaseModelConfig):
    """
    Config for the trivial baseline that predicts the empty string
    (i.e. "no answer") for every question.
    """

    # Class-level identifier shared by all instances; ClassVar keeps it out
    # of the dataclass-generated __init__/__eq__.
    MODEL_TYPE: ClassVar[str] = "always_no_answer"
@dataclass(frozen=True)
class SentenceEmbeddingModelConfig(BaseModelConfig):
    """
    Settings for the sentence-embedding baseline model.
    """

    # ClassVar keeps MODEL_TYPE out of the generated dataclass machinery
    # (__init__, __eq__, ...) since it is shared by every instance.
    MODEL_TYPE: ClassVar[str] = "embedding_best_sentence"

    # TODO - consider switching to other defaults for non-Apple users
    device: str = "mps"
    # Sentence-transformers checkpoint used to embed candidate sentences.
    sentence_model_name: str = "all-MiniLM-L6-v2"
    # Similarity cutoff below which the model predicts no-answer.
    no_answer_threshold: float = 0.5
@dataclass(frozen=True)
class BertQAConfig(BaseModelConfig, ABC):
    """
    Abstract parent config shared by every BERT-style extractive QA variant.
    """

    # Declared here (and given concrete defaults by each sub-class) so static
    # type checkers see the full field set on the shared base.
    backbone_name: str
    max_sequence_length: int
    learning_rate: float
    num_epochs: int
    batch_size: int
    eval_batch_size: int
    no_answer_threshold: float
    # Shared default; sub-classes/instances may override per environment.
    device: str = "cuda"
@dataclass(frozen=True)
class TinyBertQAConfig(BertQAConfig):
    """
    Concrete configuration for the TinyBERT-based extractive QA system.
    """

    # Shared class-level identifier (excluded from dataclass fields).
    MODEL_TYPE: ClassVar[str] = "tinybert_qa"

    # General-purpose checkpoint (not QA-tuned).
    backbone_name: str = "huawei-noah/TinyBERT_General_4L_312D"
    max_sequence_length: int = 256
    learning_rate: float = 2e-5
    num_epochs: int = 5
    batch_size: int = 64
    eval_batch_size: int = 2048
    # 0.0 cutoff: predict no-answer whenever its score beats the best span.
    no_answer_threshold: float = 0.0
@dataclass(frozen=True)
class OriginalBertQAConfig(BertQAConfig):
    """
    Concrete configuration for the extractive QA system built on the
    original BERT-base model.
    """

    # Shared class-level identifier (excluded from dataclass fields).
    MODEL_TYPE: ClassVar[str] = "original_bert_qa"

    # General-purpose checkpoint (not QA-tuned).
    backbone_name: str = "bert-base-uncased"
    max_sequence_length: int = 384
    learning_rate: float = 5e-5
    num_epochs: int = 2
    batch_size: int = 48
    eval_batch_size: int = 1024
    # Score margin required before committing to an extracted answer span.
    no_answer_threshold: float = 0.5