| |
| |
| |
| |
| |
| |
| |
from typing import Optional

import torch

from src.model import SegmentationNetwork
from src.dataset import SegmentationTokenizer, SentenceSegmenter
from .config import configuration
|
|
|
|
| |
def load_model(
    model_path: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
) -> tuple[SegmentationNetwork, SegmentationTokenizer, SentenceSegmenter]:
    """
    Load the trained segmentation model, tokenizer, and segmenter.

    Any path not supplied as an argument is prompted for interactively
    on stdin.

    :param model_path: Path to the trained model checkpoint; if None,
        the path is read via ``input()``.
    :param tokenizer_path: Path to the trained tokenizer; if None,
        the path is read via ``input()``.
    :return: A tuple ``(model, tokenizer, segmenter)``. The model is
        loaded onto CPU and switched to eval mode.
    """
    if model_path is None:
        model_path = input('Enter the full path of the trained model: ')
    if tokenizer_path is None:
        tokenizer_path = input('Enter the full path of the trained tokenizer: ')

    # Model hyperparameters come from the project's training configuration,
    # so the checkpoint must match the current config.
    model_config = configuration().model_config

    model = SegmentationNetwork(model_config)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources. Consider passing weights_only=True
    # if the checkpoint format permits it.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    tokenizer = SegmentationTokenizer(
        vocab_size=model_config.vocab_size,
        max_length=model_config.max_tokens
    ).load(tokenizer_path)

    segmenter = SentenceSegmenter(max_sentences=model_config.max_sentences)
    return model, tokenizer, segmenter
| |
| |
| |
|
|