|
|
from abc import ABC, abstractmethod |
|
|
from typing import Dict, Optional |
|
|
from src.etl.types import QAExample, Prediction |
|
|
|
|
|
|
|
|
class QAModel(ABC):
    """Abstract interface that every concrete question-answering model must satisfy.

    Subclasses must override both ``train`` and ``predict``. Each method takes
    a mapping from question ID to ``QAExample``.
    """

    @abstractmethod
    def train(
        self,
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """Fit the model on ``train_examples``.

        Args:
            train_examples: Training examples keyed by question ID. Each
                question ID is assumed to appear once.
            val_examples: Optional extra mapping of examples, presumably used
                for validation. Defaults to ``None``.

        Raises:
            NotImplementedError: Always, if a subclass calls this base
                implementation through ``super()``.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """Answer every question in ``examples``.

        Args:
            examples: Examples to answer, keyed by question ID.

        Returns:
            A mapping holding exactly one ``Prediction`` for each question ID.

        Raises:
            NotImplementedError: Always, if a subclass calls this base
                implementation through ``super()``.
        """
        raise NotImplementedError
|
|
|