File size: 2,234 Bytes
e34b94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from datasets import DatasetDict
from typing import Dict
from abc import ABC, abstractmethod
from transformers import (
PreTrainedTokenizerBase
)
import shutil
import os
import glob
from torch.utils.tensorboard import SummaryWriter
from larm.common.config import Config
from .base_model import BaseModel
class BaseRunner(ABC):
    """Abstract base for runners built around one model, one dataset and one env.

    Subclasses must implement ``train``, ``_static_evaluate`` and
    ``_dynamic_evaluate``; ``evaluate`` dispatches to one of the latter two
    based on the env's ``ENV_CARD`` attribute.
    """

    def __init__(
        self,
        model: BaseModel,
        processing_class: PreTrainedTokenizerBase,
        datasets_dict: DatasetDict,
        configs: Config,
        env_and_gen_dict: Dict
    ):
        """Store the model/config and unpack the single dataset and env entries.

        Args:
            model: The model this runner operates on (stored as ``self.model``).
            processing_class: Tokenizer/processor, stored as
                ``self.processing_class``.
            datasets_dict: Mapping with exactly one entry. Its key becomes
                ``self.dataset_name`` and its value ``self.dataset_dict``.
                NOTE(review): the value is treated as a ``DatasetDict``, so the
                outer mapping is presumably name -> DatasetDict — confirm.
            configs: Run configuration (stored as ``self.configs``).
            env_and_gen_dict: Mapping with exactly one entry whose value is an
                ``(env_cls, gen_cls)`` pair.

        Raises:
            ValueError: If either mapping does not contain exactly one entry.
        """
        self.model = model
        self.configs = configs

        # Parse dataset: only a single named dataset is supported. Explicit
        # checks (rather than ``assert``) so validation survives ``python -O``.
        if len(datasets_dict) != 1:
            raise ValueError(
                "Expected exactly one dataset, got "
                f"{len(datasets_dict)}: {list(datasets_dict)}"
            )
        self.dataset_name = next(iter(datasets_dict))
        self.dataset_dict: DatasetDict = datasets_dict[self.dataset_name]

        # Prepare env: only a single (env class, generator class) pair is supported.
        if len(env_and_gen_dict) != 1:
            raise ValueError(
                "Expected exactly one env entry, got "
                f"{len(env_and_gen_dict)}: {list(env_and_gen_dict)}"
            )
        env_name = next(iter(env_and_gen_dict))
        self.env_cls, self.gen_cls = env_and_gen_dict[env_name]

        # Tokenizer/processor; no chat template is built in this base class.
        self.processing_class = processing_class

    @abstractmethod
    def train(self):
        """Run training. Must be implemented by subclasses."""
        raise NotImplementedError("Should be implemented by subclasses")

    def evaluate(self):
        """Run evaluation, dispatching on the env's ``ENV_CARD``.

        ``"STATIC"`` selects ``_static_evaluate`` and ``"DYNAMIC"`` selects
        ``_dynamic_evaluate``. The card is read from ``self.env`` when a
        subclass has set one. Otherwise it is read from ``self.env_cls``; this
        base class only sets ``self.env_cls`` and never sets ``self.env``.

        Returns:
            Whatever the selected evaluation method returns.

        Raises:
            ValueError: If ``ENV_CARD`` is missing or not one of the known cards.
        """
        evaluate_func_mapping = {
            "STATIC": self._static_evaluate,
            "DYNAMIC": self._dynamic_evaluate,
        }
        env = getattr(self, "env", None)
        card_source = env if env is not None else self.env_cls
        env_card = getattr(card_source, "ENV_CARD", None)
        evaluate_func = evaluate_func_mapping.get(env_card)
        if evaluate_func is None:
            raise ValueError(
                f"The env has unrecognized ENV_CARD attribute: {env_card!r}"
            )
        return evaluate_func()

    @abstractmethod
    def _static_evaluate(self):
        """Evaluate against an env whose ``ENV_CARD`` is ``"STATIC"``."""
        raise NotImplementedError("Should be implemented by subclasses")

    @abstractmethod
    def _dynamic_evaluate(self):
        """Evaluate against an env whose ``ENV_CARD`` is ``"DYNAMIC"``."""
        raise NotImplementedError("Should be implemented by subclasses")

    def _create_tensorboard(self, mode: str):
        """Create a TensorBoard ``SummaryWriter`` that logs to ``<save_dir>/runs``.

        ``self.save_dir`` is not set by this base class and must be provided by
        the subclass before this is called.

        NOTE(review): ``mode`` is accepted but unused, so every mode logs to the
        same directory. Confirm whether a per-mode subdirectory was intended.
        """
        log_dir = os.path.join(self.save_dir, "runs")
        writer = SummaryWriter(log_dir=log_dir)
        return writer

    def _remove_trainer_ckpts(self, output_dir: str):
        """Delete every ``checkpoint-*`` entry directly under ``output_dir``.

        This is best-effort: ``shutil.rmtree`` is called with
        ``ignore_errors=True``, so failures are silently skipped.
        """
        ckpt_paths = glob.glob(os.path.join(output_dir, "checkpoint-*"))
        for ckpt in ckpt_paths:
            shutil.rmtree(ckpt, ignore_errors=True)