MishaD
commited on
Commit
·
7d9f96e
1
Parent(s):
9d1e5f8
Testing model location
Browse files
model.py
CHANGED
|
@@ -40,7 +40,7 @@ def load_models():
|
|
| 40 |
model.lm_head = nn.Identity()
|
| 41 |
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
| 42 |
my_classifier = SimpleClassifier(768, 768, 1)
|
| 43 |
-
weights_path =
|
| 44 |
my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
|
| 45 |
my_classifier.eval()
|
| 46 |
return {
|
|
|
|
| 40 |
model.lm_head = nn.Identity()
|
| 41 |
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
| 42 |
my_classifier = SimpleClassifier(768, 768, 1)
|
| 43 |
+
weights_path = "twitter_model_91_5-.pth"
|
| 44 |
my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
|
| 45 |
my_classifier.eval()
|
| 46 |
return {
|