|
|
""" |
|
|
Immutable configurations enabling to share common fields across the specific models used. |
|
|
""" |
|
|
|
|
|
from abc import ABC |
|
|
from dataclasses import dataclass |
|
|
from typing import ClassVar |
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class BaseModelConfig(ABC):
    """
    Base container for configuration values shared by every QA model.

    Frozen so that instances are immutable (and hashable) once created.
    """
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class AlwaysNoAnswerModelConfig(BaseModelConfig):
    """
    Configuration for the trivial baseline model that predicts the empty
    string (no-answer) for every question.
    """

    # Identifier for this model variant; a ClassVar, so not an init field.
    MODEL_TYPE: ClassVar[str] = "always_no_answer"
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class SentenceEmbeddingModelConfig(BaseModelConfig):
    """
    Configuration for the simpler sentence-embedding baseline model.
    """

    # Identifier for this model variant; a ClassVar, so not an init field.
    MODEL_TYPE: ClassVar[str] = "embedding_best_sentence"

    # Device string for inference (e.g. "cpu", "cuda", "mps").
    device: str = "mps"
    # Name of the sentence-embedding model to load.
    sentence_model_name: str = "all-MiniLM-L6-v2"
    # Threshold below which the model predicts no-answer — presumably a
    # similarity score; confirm against the model implementation.
    no_answer_threshold: float = 0.5
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class BertQAConfig(BaseModelConfig, ABC):
    """
    Abstract configuration shared by the BERT-style extractive QA variants.

    Concrete subclasses provide defaults for every field declared here.
    Field order matters: it defines the generated ``__init__`` signature.
    """

    # Identifier of the transformer backbone to load (Hugging Face style).
    backbone_name: str
    # Maximum token length of the model input.
    max_sequence_length: int
    # Optimizer learning rate for fine-tuning.
    learning_rate: float
    # Number of training epochs.
    num_epochs: int
    # Batch size used during training.
    batch_size: int
    # Batch size used during evaluation.
    eval_batch_size: int
    # Score threshold controlling when the model outputs no-answer.
    no_answer_threshold: float
    # Device string; defaults to GPU.
    device: str = "cuda"
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class TinyBertQAConfig(BertQAConfig):
    """
    Configuration for the TinyBERT-based extractive QA system.
    """

    # Identifier for this model variant; a ClassVar, so not an init field.
    MODEL_TYPE: ClassVar[str] = "tinybert_qa"

    backbone_name: str = "huawei-noah/TinyBERT_General_4L_312D"
    max_sequence_length: int = 256
    learning_rate: float = 2e-5
    num_epochs: int = 5
    batch_size: int = 64
    eval_batch_size: int = 2048
    no_answer_threshold: float = 0.0
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class OriginalBertQAConfig(BertQAConfig):
    """
    Configuration for an extractive QA system built on the original
    ``bert-base-uncased`` backbone.
    """

    # Identifier for this model variant; a ClassVar, so not an init field.
    MODEL_TYPE: ClassVar[str] = "original_bert_qa"

    backbone_name: str = "bert-base-uncased"
    max_sequence_length: int = 384
    learning_rate: float = 5e-5
    num_epochs: int = 2
    batch_size: int = 48
    eval_batch_size: int = 1024
    no_answer_threshold: float = 0.5
|
|
|