""" Immutable configurations enabling to share common fields across the specific models used. """ from abc import ABC from dataclasses import dataclass from typing import ClassVar @dataclass(frozen=True) class BaseModelConfig(ABC): """ Container storing configurations useful across all QA models. """ @dataclass(frozen=True) class AlwaysNoAnswerModelConfig(BaseModelConfig): """ Trivial baseline that always predicts no-answer (""). """ MODEL_TYPE: ClassVar[str] = "always_no_answer" @dataclass(frozen=True) class SentenceEmbeddingModelConfig(BaseModelConfig): """ Config object for the simpler baseline model. """ # Ensuring that MODEL_TYPE is not treated as an object field (e.g., not added to __eq__() etc.) # as it is common across all objects of the dataclass MODEL_TYPE: ClassVar[str] = "embedding_best_sentence" # TODO - consider switching to other defaults for non-Apple users device: str = "mps" sentence_model_name: str = "all-MiniLM-L6-v2" no_answer_threshold: float = 0.5 @dataclass(frozen=True) class BertQAConfig(BaseModelConfig, ABC): """ Shared super-class config to be sub-classed by BERT model variants. """ # Specifying fields to be materialized by sub-classes to avoid Pylance complaints backbone_name: str max_sequence_length: int learning_rate: float num_epochs: int batch_size: int eval_batch_size: int no_answer_threshold: float device: str = "cuda" @dataclass(frozen=True) class TinyBertQAConfig(BertQAConfig): """ Config for a Tiny BERT-based extractive QA system. """ MODEL_TYPE: ClassVar[str] = "tinybert_qa" backbone_name: str = ( "huawei-noah/TinyBERT_General_4L_312D" # General-purpose checkpoint (not QA-tuned) ) max_sequence_length: int = 256 learning_rate: float = 2e-5 num_epochs: int = 5 batch_size: int = 64 eval_batch_size: int = 2048 no_answer_threshold: float = 0.0 @dataclass(frozen=True) class OriginalBertQAConfig(BertQAConfig): """ Config for a BERT-based extractive QA system (original BERT model). """ MODEL_TYPE: ClassVar[str] = "original_bert_qa" backbone_name: str = ( "bert-base-uncased" # General-purpose checkpoint (not QA-tuned) ) max_sequence_length: int = 384 learning_rate: float = 5e-5 num_epochs: int = 2 batch_size: int = 48 eval_batch_size: int = 1024 no_answer_threshold: float = 0.5