Update tasks/text.py
Browse files- tasks/text.py +2 -2
tasks/text.py
CHANGED
|
@@ -106,9 +106,9 @@ def bert_classifier(test_dataset: dict, model: str):
|
|
| 106 |
|
| 107 |
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
| 108 |
|
| 109 |
-
if model
|
| 110 |
model = AutoModelForSequenceClassification.from_pretrained(model_repo)
|
| 111 |
-
elif model
|
| 112 |
model = SentenceBERTMultiClass.from_pretrained(model_repo)
|
| 113 |
else:
|
| 114 |
raise(ValueError)
|
|
|
|
| 106 |
|
| 107 |
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
| 108 |
|
| 109 |
+
if model in ['bert_base_pruned']:
|
| 110 |
model = AutoModelForSequenceClassification.from_pretrained(model_repo)
|
| 111 |
+
elif model in ['sbert_distilroberta']:
|
| 112 |
model = SentenceBERTMultiClass.from_pretrained(model_repo)
|
| 113 |
else:
|
| 114 |
raise(ValueError)
|