# Source: wojood-api / Nested / trainers / BaseTrainer.py
# (Hugging Face web header removed: author line "TymaaHammouda",
#  commit message "Update path", commit hash 6d1fbad)
import os
import torch
import logging
import natsort
import glob
from huggingface_hub import hf_hub_download, snapshot_download
logger = logging.getLogger(__name__)


class BaseTrainer:
    """
    Base class bundling the objects needed to train and evaluate a tagging
    model (model, optimizer, dataloaders, ...), with helpers for batch
    iteration (`tag`), prediction dumps (`segments_to_file`), and
    checkpointing (`save` / `load`).
    """

    def __init__(
        self,
        model=None,
        max_epochs=50,
        optimizer=None,
        scheduler=None,
        loss=None,
        train_dataloader=None,
        val_dataloader=None,
        test_dataloader=None,
        log_interval=10,
        summary_writer=None,
        output_path=None,
        clip=5,
        patience=5
    ):
        """
        :param model: torch.nn.Module mapping subword IDs to per-tag logits
        :param max_epochs: int - maximum number of training epochs
        :param optimizer: torch.optim.Optimizer
        :param scheduler: LR scheduler (stored for use by training loops)
        :param loss: loss callable (stored for use by training loops)
        :param train_dataloader: torch.utils.data.DataLoader
        :param val_dataloader: torch.utils.data.DataLoader
        :param test_dataloader: torch.utils.data.DataLoader
        :param log_interval: int - log every N steps
        :param summary_writer: e.g. torch.utils.tensorboard.SummaryWriter
        :param output_path: str - directory that will contain "checkpoints/"
        :param clip: gradient-clipping threshold (stored for training loops)
        :param patience: early-stopping patience (stored for training loops)
        """
        self.model = model
        self.max_epochs = max_epochs
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss = loss
        self.log_interval = log_interval
        self.summary_writer = summary_writer
        self.output_path = output_path
        # Progress counters advanced by the concrete training loop.
        self.current_timestep = 0
        self.current_epoch = 0
        self.clip = clip
        self.patience = patience

    def tag(self, dataloader, is_train=True):
        """
        Given a dataloader containing segments, predict the tags.

        :param dataloader: torch.utils.data.DataLoader yielding
                           (subwords, gold_tags, tokens, valid_len) batches
        :param is_train: boolean - True for training mode, False for evaluation
        :return: Iterator of tuples
            subwords (B x T x NUM_LABELS) - torch.Tensor - BERT subword IDs
            gold_tags (B x T x NUM_LABELS) - torch.Tensor - ground truth tag IDs
            tokens - List[Nested.data.dataset.Token] - list of tokens
            valid_len (B x 1) - int - valid length of each sequence
            logits (B x T x NUM_LABELS) - logits for each token and each tag
        """
        for subwords, gold_tags, tokens, valid_len in dataloader:
            # Re-assert train/eval mode every batch on purpose: this is a
            # generator, and the caller may switch the model's mode (e.g. to
            # run a validation pass) between yields.
            self.model.train(is_train)
            if torch.cuda.is_available():
                subwords = subwords.cuda()
                gold_tags = gold_tags.cuda()
            if is_train:
                # Reset gradients before the forward pass; the caller is
                # expected to backpropagate through the yielded logits and
                # step the optimizer.
                self.optimizer.zero_grad()
                logits = self.model(subwords)
            else:
                with torch.no_grad():
                    logits = self.model(subwords)
            yield subwords, gold_tags, tokens, valid_len, logits

    def segments_to_file(self, segments, filename):
        """
        Write tokens with their gold and predicted tags to a TSV file,
        one token per line, with a blank line between segments.

        :param segments: List[List[Nested.data.dataset.Token]] - list of list of tokens
        :param filename: str - output filename
        :return: None
        """
        with open(filename, "w", encoding="utf-8") as fh:
            results = "\n\n".join(
                "\n".join(str(t) for t in segment) for segment in segments
            )
            fh.write("Token\tGold Tag\tPredicted Tag\n")
            fh.write(results)
        # Use the module logger for consistency (the original called
        # logging.info directly, bypassing this module's logger).
        logger.info("Predictions written to %s", filename)

    def save(self):
        """
        Save a model checkpoint to
        <output_path>/checkpoints/checkpoint_<epoch>.pt.

        :return: None
        """
        checkpoint_dir = os.path.join(self.output_path, "checkpoints")
        # Create the directory on first save instead of failing with
        # FileNotFoundError when "checkpoints/" does not exist yet.
        os.makedirs(checkpoint_dir, exist_ok=True)
        filename = os.path.join(
            checkpoint_dir,
            "checkpoint_{}.pt".format(self.current_epoch),
        )
        checkpoint = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "epoch": self.current_epoch
        }
        logger.info("Saving checkpoint to %s", filename)
        torch.save(checkpoint, filename)

    def load(self, checkpoint_path):
        """
        Load model weights from the "SinaLab/Nested" Hugging Face snapshot.

        NOTE(review): despite its name, ``checkpoint_path`` is only logged —
        the checkpoint is always downloaded from the hardcoded hub repo
        (checkpoints/checkpoint_2.pt). Confirm whether local-checkpoint
        loading (the disabled natsort/glob logic) is still needed.

        :param checkpoint_path: str - path/to/checkpoints (informational only)
        :return: None
        """
        # The original natsort.natsorted(checkpoint_path) char-sorted the
        # path *string* and logged the resulting list; log the path as given.
        logger.info("Loading checkpoint %s", checkpoint_path)
        # Map tensors onto CPU when no GPU is available.
        device = None if torch.cuda.is_available() else torch.device('cpu')
        repo_path = snapshot_download(repo_id="SinaLab/Nested")
        model_file = os.path.join(repo_path, "checkpoints", "checkpoint_2.pt")
        # weights_only=False: the checkpoint stores optimizer state and the
        # epoch alongside the weights; only load checkpoints you trust.
        checkpoint = torch.load(model_file, map_location=device, weights_only=False)
        self.model.load_state_dict(checkpoint["model"])