MishaD commited on
Commit
7d9f96e
·
1 Parent(s): 9d1e5f8

Testing model location

Browse files
Files changed (1) hide show
  1. model.py +1 -1
model.py CHANGED
@@ -40,7 +40,7 @@ def load_models():
40
  model.lm_head = nn.Identity()
41
  tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
42
  my_classifier = SimpleClassifier(768, 768, 1)
43
- weights_path = os.path.join(__file__, "..", "twitter_model_91_5-.pth")
44
  my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
45
  my_classifier.eval()
46
  return {
 
40
  model.lm_head = nn.Identity()
41
  tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
42
  my_classifier = SimpleClassifier(768, 768, 1)
43
+ weights_path = "twitter_model_91_5-.pth"
44
  my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
45
  my_classifier.eval()
46
  return {