File size: 1,207 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
Always-no-answer baseline: returns a standardized null Prediction for every question.
"""

from typing import Dict, Optional
from src.models.base_qa_model import QAModel
from src.etl.types import QAExample, Prediction
from src.config.model_configs import AlwaysNoAnswerModelConfig


class AlwaysNoAnswerQAModel(QAModel):
    """
    Trivial abstaining baseline: every question receives a null prediction.

    No learning happens here; ``train`` exists only so this class can stand in
    anywhere a ``QAModel`` is expected.
    """

    def __init__(self, config: AlwaysNoAnswerModelConfig) -> None:
        """
        Initialise the base model and keep a reference to ``config``.

        Args:
            config: Configuration object; must be an
                ``AlwaysNoAnswerModelConfig`` instance.

        Raises:
            AssertionError: If ``config`` is of another type. This check is a
                plain ``assert``, so it is skipped when Python runs with ``-O``.
        """
        # Base-class initialisation runs first, before the config is checked.
        super().__init__()
        assert isinstance(config, AlwaysNoAnswerModelConfig), "Incompatible configuration object."
        self.config = config

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """
        Do nothing: this baseline has no parameters to fit.

        Args:
            train_examples: Ignored.
            val_examples: Ignored.

        Both arguments are accepted solely to keep the signature aligned
        with the ``QAModel`` interface.
        """
        # Intentionally a no-op.
        return None

    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """
        Emit a null prediction for each question in ``examples``.

        Args:
            examples: Dict keyed by question id; only the keys are read.

        Returns:
            A new dict with the same keys as ``examples`` (in iteration
            order), each mapped to ``Prediction.null(question_id=<key>)``.

        Raises:
            AssertionError: If ``examples`` is not a ``dict`` (skipped under
                ``python -O``).
        """
        assert isinstance(examples, dict), "Incompatible input examples type."
        # NOTE: ``.keys()`` is kept deliberately: under ``-O`` the assert above
        # is gone, and a non-dict input still fails here exactly as before.
        return {
            question_id: Prediction.null(question_id=question_id)
            for question_id in examples.keys()
        }