File size: 1,207 Bytes
461f64f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
"""
Always-no-answer baseline: returns a standardized null Prediction for every question.
"""
from typing import Dict, Optional
from src.models.base_qa_model import QAModel
from src.etl.types import QAExample, Prediction
from src.config.model_configs import AlwaysNoAnswerModelConfig
class AlwaysNoAnswerQAModel(QAModel):
    """
    Minimal baseline that abstains on every question.

    For each input question it emits ``Prediction.null(question_id=qid)``,
    the standardized no-answer prediction (intended to represent the empty
    answer ""; see ``Prediction.null`` to confirm). The model has no
    learnable state, so ``train`` is a no-op.
    """

    def __init__(self, config: AlwaysNoAnswerModelConfig) -> None:
        """
        Store the model configuration.

        Args:
            config: Configuration for this baseline; must be an
                ``AlwaysNoAnswerModelConfig`` instance.

        Raises:
            TypeError: If ``config`` is not an ``AlwaysNoAnswerModelConfig``.
        """
        super().__init__()
        # Explicit check rather than ``assert`` so the validation still runs
        # when Python is started with -O (which strips assert statements).
        if not isinstance(config, AlwaysNoAnswerModelConfig):
            raise TypeError("Incompatible configuration object.")
        self.config = config

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """
        No-op: this baseline has nothing to fit.

        Preserved for API consistency with the ``QAModel`` super-class; both
        arguments are accepted and ignored.
        """

    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """
        Return a no-answer prediction for every question.

        Args:
            examples: Dict from question id to ``QAExample``. Only the keys
                are used; the example contents are ignored.

        Returns:
            Dict with the same keys as ``examples`` (in the same order), each
            mapped to ``Prediction.null(question_id=qid)``.

        Raises:
            TypeError: If ``examples`` is not a ``dict``.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not isinstance(examples, dict):
            raise TypeError("Incompatible input examples type.")
        return {qid: Prediction.null(question_id=qid) for qid in examples}
|