File size: 618 Bytes
8fa3acc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# pylint: disable=method-hidden

from abc import abstractmethod, ABC
from typing import List

from datasets import Dataset
from datasets.formatting.formatting import LazyRow

from src.task.task import Task


class LanguageModel(ABC):
    """Abstract interface shared by every language-model wrapper.

    A concrete subclass must override :meth:`predict`, :meth:`infer` and
    :meth:`generate`. The base implementations only raise
    ``NotImplementedError``. Because of the ``@abstractmethod`` markers, the
    class cannot be instantiated until all three are provided.
    """

    def __init__(self, model_name: str):
        """Record the identifier of the wrapped model.

        Args:
            model_name: Name of the model, kept verbatim on ``self.name``.
        """
        self.name = model_name

    # ------------------------------------------------------------------
    # Abstract API: each hook below must be supplied by a subclass.
    # ------------------------------------------------------------------

    @abstractmethod
    def predict(self, evaluation_dataset: Dataset, task: Task) -> List:
        """Run the model over ``evaluation_dataset`` for ``task``.

        Args:
            evaluation_dataset: Dataset to evaluate on.
            task: Task definition the predictions are produced for.

        Returns:
            A list of results; presumably one prediction per dataset row
            (TODO confirm against concrete subclasses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @abstractmethod
    def infer(self, rows: LazyRow):
        """Perform inference on ``rows``.

        NOTE(review): the parameter is named ``rows`` but annotated as a
        single ``LazyRow``. Confirm whether a row or a batch is expected.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @abstractmethod
    def generate(self, rows: LazyRow):
        """Produce generated output for ``rows``.

        NOTE(review): same row-vs-batch ambiguity as :meth:`infer`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()