"""Load trained CBOW weights and print the model's prediction for a sample context."""
import torch

from trainer import CBOW, TextPreProcessor, make_context_vector

# CSV the vocabulary is rebuilt from (one artist name per row, judging by the name).
ARTIST_NAMES_PATH = "data/artist-names-per-row.csv"
# Saved state_dict of a trained CBOW model.
MODEL_PATH = "data/cbow-model-weights"


def main():
    """Rebuild the vocabulary, restore the CBOW weights, and predict for a fixed context.

    Prints "Loaded model", then the context and the predicted word.
    """
    text = TextPreProcessor(ARTIST_NAMES_PATH)
    vocab = text.build_vocab()

    # NOTE(review): the vocab presumably has to match the one used at training
    # time; load_state_dict (strict by default) raises if parameter shapes differ.
    model = CBOW(vocab)

    # map_location="cpu" lets weights saved from a GPU run load on a CPU-only
    # machine; the model is built on CPU, so the restored values are the same.
    # torch.load unpickles the file -- only point MODEL_PATH at trusted files.
    state_dict = torch.load(MODEL_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    print("Loaded model")

    context = ["ana roxanne", "bjork"]
    context_vector = make_context_vector(context, model.word_to_ix)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        scores = model(context_vector)

    # NOTE(review): assumes the first row of the output holds one score per
    # vocabulary index (e.g. shape (1, vocab_size)) -- confirm against CBOW.forward.
    prediction = model.ix_to_word[torch.argmax(scores[0]).item()]

    print(f"Context: {context}\n")
    print(f"Prediction: {prediction}")


if __name__ == "__main__":
    main()