File size: 754 Bytes
461f64f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from abc import ABC, abstractmethod
from typing import Dict, Optional
from src.etl.types import QAExample, Prediction
class QAModel(ABC):
    """Abstract interface that every concrete QA model implementation must satisfy.

    Both ``train`` and ``predict`` are abstract, so a subclass has to override
    the two of them before it can be instantiated.
    """

    @abstractmethod
    def train(
        self,
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """Fit the model on the given training examples.

        Args:
            train_examples: Mapping keyed by question ID. The keys are
                assumed to be unique question IDs.
            val_examples: Optional mapping of the same form; defaults to
                ``None``. Presumably a validation set -- how it is used is
                up to the implementation (confirm against subclasses).

        Raises:
            NotImplementedError: Always, in this base definition.
        """
        raise NotImplementedError()

    @abstractmethod
    def predict(
        self,
        examples: Dict[str, QAExample],
    ) -> Dict[str, Prediction]:
        """Return one ``Prediction`` for each question ID in ``examples``.

        Args:
            examples: Mapping from question ID to the example to answer.

        Returns:
            A mapping with one ``Prediction`` per question ID.

        Raises:
            NotImplementedError: Always, in this base definition.
        """
        raise NotImplementedError()
|