File size: 754 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from abc import ABC, abstractmethod
from typing import Dict, Optional
from src.etl.types import QAExample, Prediction


class QAModel(ABC):
    """Abstract interface that every concrete QA model implementation must satisfy.

    A subclass has to override both ``train`` and ``predict``. Until it does,
    ``abc`` refuses to instantiate it.
    """

    @abstractmethod
    def train(
        self,
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """Fit the model on the given training examples.

        Args:
            train_examples: Training examples keyed by question ID. Each key is
                assumed to be a unique question ID.
            val_examples: Optional second mapping of examples with the same
                type as ``train_examples``. How it is used (e.g. for validation)
                is left to the concrete implementation.

        Returns:
            None. Implementations update the model in place.

        Raises:
            NotImplementedError: If this base implementation is reached, e.g.
                via ``super().train(...)``.
        """
        raise NotImplementedError()

    @abstractmethod
    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """Run inference over a batch of examples.

        Args:
            examples: Examples to answer, keyed by question ID.

        Returns:
            A mapping containing exactly one ``Prediction`` for each question ID.

        Raises:
            NotImplementedError: If this base implementation is reached, e.g.
                via ``super().predict(...)``.
        """
        raise NotImplementedError()