# squad2-qa/src/models/base_qa_model.py
# Kimis Perros
# Initial deployment
# 461f64f
from abc import ABC, abstractmethod
from typing import Dict, Optional
from src.etl.types import QAExample, Prediction
class QAModel(ABC):
    """Basic contract dictating specific QA model implementation requirements.

    Concrete QA models subclass this and must override both ``train`` and
    ``predict``; because both are abstract, a subclass that leaves either one
    out cannot be instantiated.
    """

    @abstractmethod
    def train(
        self,
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """Train the model on the given examples.

        Assumes uniqueness of the keys of ``train_examples`` (unique
        question IDs).

        Args:
            train_examples: Training examples keyed by question ID.
            val_examples: Optional second set of examples with the same
                mapping type; defaults to ``None``. How (or whether) it is
                used is left to the concrete implementation.

        Raises:
            NotImplementedError: Always, when this base implementation is
                invoked (e.g. through ``super()``).
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, examples: Dict[str, QAExample]) -> Dict[str, Prediction]:
        """Produce one Prediction per question ID.

        Args:
            examples: Examples to answer, keyed by question ID.

        Returns:
            A mapping from question ID to its ``Prediction``.

        Raises:
            NotImplementedError: Always, when this base implementation is
                invoked (e.g. through ``super()``).
        """
        raise NotImplementedError