Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import html | |
| import string | |
| import torch | |
| import config | |
| import unicodedata | |
| from nltk.tokenize import word_tokenize | |
| from dataset import XRayDataset | |
| from model import EncoderDecoderNet | |
| from torch.utils.data import Subset | |
| from sklearn.model_selection import train_test_split as sklearn_train_test_split | |
def load_dataset(raw_caption=False):
    """Build the X-ray dataset from the settings held in ``config``.

    Args:
        raw_caption: Forwarded unchanged to ``XRayDataset`` as its
            ``raw_caption`` argument.

    Returns:
        A newly constructed ``XRayDataset``.
    """
    dataset_kwargs = {
        'root': config.DATASET_PATH,
        'transform': config.basic_transforms,
        'freq_threshold': config.VOCAB_THRESHOLD,
        'raw_caption': raw_caption,
    }
    return XRayDataset(**dataset_kwargs)
def get_model_instance(vocabulary):
    """Create the encoder-decoder network and move it to ``config.DEVICE``.

    Args:
        vocabulary: Passed through to ``EncoderDecoderNet`` as ``vocabulary``.

    Returns:
        The result of calling ``.to(config.DEVICE)`` on the new network.
    """
    net = EncoderDecoderNet(
        features_size=config.FEATURES_SIZE,
        embed_size=config.EMBED_SIZE,
        hidden_size=config.HIDDEN_SIZE,
        vocabulary=vocabulary,
        encoder_checkpoint='./weights/chexnet.pth.tar',
    )
    return net.to(config.DEVICE)
def train_test_split(dataset, test_size=0.25, random_state=44):
    """Split *dataset* into train and test ``Subset`` views.

    The split is performed on the dataset's integer indices using
    scikit-learn's ``train_test_split``.

    Args:
        dataset: Any object supporting ``len()``; wrapped by ``Subset``.
        test_size: Forwarded to scikit-learn's ``train_test_split``.
        random_state: Forwarded to scikit-learn's ``train_test_split``.

    Returns:
        A ``(train_subset, test_subset)`` tuple of ``Subset`` objects.
    """
    all_indices = list(range(len(dataset)))
    train_indices, test_indices = sklearn_train_test_split(
        all_indices,
        test_size=test_size,
        random_state=random_state,
    )
    train_subset = Subset(dataset, train_indices)
    test_subset = Subset(dataset, test_indices)
    return train_subset, test_subset
def save_checkpoint(checkpoint):
    """Write *checkpoint* to ``config.CHECKPOINT_FILE`` with ``torch.save``.

    Args:
        checkpoint: The object to serialize.
    """
    print('=> Saving checkpoint')
    target_path = config.CHECKPOINT_FILE
    torch.save(checkpoint, target_path)
def load_checkpoint(model, optimizer=None):
    """Restore model (and optionally optimizer) state from the checkpoint file.

    The checkpoint is read from ``config.CHECKPOINT_FILE`` with
    ``map_location`` set to the CPU device.

    Args:
        model: Object whose ``load_state_dict`` receives the checkpoint's
            ``'state_dict'`` entry.
        optimizer: Optional object whose ``load_state_dict`` receives the
            checkpoint's ``'optimizer'`` entry. Skipped when ``None``.

    Returns:
        The checkpoint's ``'epoch'`` entry.
    """
    print('=> Loading checkpoint')
    source_path = config.CHECKPOINT_FILE
    state = torch.load(source_path, map_location=torch.device('cpu'))
    model.load_state_dict(state['state_dict'])
    if optimizer is None:
        return state['epoch']
    optimizer.load_state_dict(state['optimizer'])
    return state['epoch']
def can_load_checkpoint():
    """Tell whether a checkpoint exists and loading is enabled.

    Returns:
        ``False`` when ``config.CHECKPOINT_FILE`` does not exist; otherwise
        the value of ``config.LOAD_MODEL``.
    """
    if not os.path.exists(config.CHECKPOINT_FILE):
        return False
    return config.LOAD_MODEL
def remove_special_chars(text):
    """Lower-case *text* and clean up leftover markup and escape artifacts.

    The text is lower-cased, a fixed list of substring substitutions is
    applied in order, HTML character references are unescaped, and every
    run of spaces is collapsed to a single space.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    # Order matters: each substitution runs on the output of the previous one.
    substitutions = (
        ('#39;', "'"),
        ('amp;', '&'),
        ('#146;', "'"),
        ('nbsp;', ' '),
        ('#36;', '$'),
        ('\\n', "\n"),
        ('quot;', "'"),
        ('<br />', "\n"),
        ('\\"', '"'),
        ('<unk>', 'u_n'),
        (' @.@ ', '.'),
        (' @-@ ', '-'),
        ('\\', ' \\ '),
    )
    cleaned = text.lower()
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)
    unescaped = html.unescape(cleaned)
    return re.sub(r' +', ' ', unescaped)
def remove_non_ascii(text):
    """Drop every character of *text* that has no ASCII representation.

    The text is NFKD-normalized first, so the ASCII base letter of a
    decomposable character survives while its combining marks are dropped.

    Args:
        text: The string to filter.

    Returns:
        A string containing only ASCII characters.
    """
    decomposed = unicodedata.normalize('NFKD', text)
    ascii_bytes = decomposed.encode('ascii', 'ignore')
    return ascii_bytes.decode('utf-8', 'ignore')
def to_lowercase(text):
    """Return *text* converted to lower case via its ``lower()`` method."""
    lowered = text.lower()
    return lowered
def remove_punctuation(text):
    """Delete every character listed in ``string.punctuation`` from *text*.

    Args:
        text: The string to filter.

    Returns:
        *text* with those punctuation characters removed.
    """
    return text.translate(str.maketrans('', '', string.punctuation))
def replace_numbers(text):
    """Remove every run of digits from *text*.

    Despite the name, digits are replaced with the empty string, i.e.
    deleted.

    Args:
        text: The string to filter.

    Returns:
        *text* with all digit runs removed.
    """
    digit_runs = re.compile(r'\d+')
    return digit_runs.sub('', text)
def text2words(text):
    """Tokenize *text* with NLTK's ``word_tokenize`` and return the result."""
    tokens = word_tokenize(text)
    return tokens
def normalize_text(text):
    """Run *text* through the module's cleaning steps, in a fixed order.

    Steps, in order: ``remove_special_chars``, ``remove_non_ascii``,
    ``remove_punctuation``, ``to_lowercase``, ``replace_numbers``.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string.
    """
    pipeline = (
        remove_special_chars,
        remove_non_ascii,
        remove_punctuation,
        to_lowercase,
        replace_numbers,
    )
    for step in pipeline:
        text = step(text)
    return text
def normalize_corpus(corpus):
    """Apply ``normalize_text`` to every item of *corpus*.

    Args:
        corpus: An iterable of strings.

    Returns:
        A list with the normalized strings, in the original order.
    """
    return list(map(normalize_text, corpus))