# alverciito
# upload safetensors and refactor research files
# dbd79bd
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
from src.model import SegmentationNetwork
from src.dataset import SegmentationTokenizer, SentenceSegmenter
from .config import configuration
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
def load_model(model_path: str | None = None,
               tokenizer_path: str | None = None) -> tuple[SegmentationNetwork, SegmentationTokenizer, SentenceSegmenter]:
    """
    Load the trained segmentation model, tokenizer, and segmenter.

    If a path argument is omitted (None), the user is prompted for it
    interactively via ``input()`` — a deliberate side effect for CLI use.

    :param model_path: Path to the trained model checkpoint (a ``torch.save``
        dict containing a ``'model_state_dict'`` entry). Prompted if None.
    :param tokenizer_path: Path to the trained tokenizer file. Prompted if None.
    :return: A tuple ``(model, tokenizer, segmenter)``; the model is in eval mode.
    :raises KeyError: If the checkpoint has no ``'model_state_dict'`` key.
    :raises FileNotFoundError: If either path does not exist.
    """
    # Resolve paths, prompting interactively for any that were not provided:
    full_model_path = model_path if model_path is not None else \
        input('Enter the full path of the trained model: ')
    full_tokenizer_path = tokenizer_path if tokenizer_path is not None else \
        input('Enter the full path of the trained tokenizer: ')
    # Load configs (model hyper-parameters come from the training configuration):
    train_config = configuration()
    model_config = train_config.model_config
    # Load model weights onto CPU so loading works on GPU-less machines;
    # the caller can move the model to another device afterwards.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True if the
    # checkpoint format allows it).
    model = SegmentationNetwork(model_config)
    model_dict = torch.load(full_model_path, map_location='cpu')
    model.load_state_dict(model_dict['model_state_dict'])
    model.eval()  # inference mode: disables dropout / batch-norm updates
    # Load tokenizer with the same vocabulary / length limits the model was trained with:
    tokenizer = SegmentationTokenizer(
        vocab_size=model_config.vocab_size,
        max_length=model_config.max_tokens
    ).load(full_tokenizer_path)
    # Load segmenter:
    segmenter = SentenceSegmenter(max_sentences=model_config.max_sentences)
    return model, tokenizer, segmenter
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #