"""Load a trained CBOW model and print its prediction for a sample context."""

import torch

from trainer import CBOW, TextPreProcessor, make_context_vector


def main(
    artist_names="data/artist-names-per-row.csv",
    model_path="data/cbow-model-weights",
    context=None,
):
    """Load saved CBOW weights and print the predicted word for *context*.

    Args:
        artist_names: Path handed to ``TextPreProcessor`` to rebuild the
            vocabulary the model is constructed with. Presumably this must be
            the same data the weights were trained on so that vocabulary
            indices line up — TODO confirm.
        model_path: Path of the saved ``state_dict`` to load into the model.
        context: List of context words to predict from; defaults to
            ``["ana roxanne", "bjork"]``.
    """
    if context is None:
        context = ["ana roxanne", "bjork"]

    text = TextPreProcessor(artist_names)
    vocab = text.build_vocab()
    model = CBOW(vocab)
    # map_location="cpu" lets weights saved from a GPU run load on a CPU-only
    # machine; this script never moves the model to a GPU, and
    # load_state_dict copies the tensors into the model's own parameters.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    print("Loaded model")

    context_vector = make_context_vector(context, model.word_to_ix)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        scores = model(context_vector)
    # Pick the highest-scoring index in the first row of the output and map it
    # back to a word. Assumes the output is shaped (1, vocab_size) — TODO confirm.
    prediction = model.ix_to_word[torch.argmax(scores[0]).item()]
    print(f"Context: {context}\n")
    print(f"Prediction: {prediction}")


if __name__ == "__main__":
    main()