Terry Zhang
commited on
Commit
·
f7c276d
1
Parent(s):
873e38f
fix
Browse files- tasks/text.py +6 -6
tasks/text.py
CHANGED
|
@@ -118,20 +118,20 @@ def bert_classifier(test_dataset: dict, model: str):
|
|
| 118 |
|
| 119 |
def moe_classifier(test_dataset: dict, model: str):
|
| 120 |
print("Starting MoE run")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
texts = test_dataset["quote"]
|
|
|
|
| 122 |
model_path = f"tasks/text_models/0131_MoE_final.pt"
|
| 123 |
|
| 124 |
-
embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
|
| 125 |
-
embedding_model.to(device)
|
| 126 |
-
|
| 127 |
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
|
| 128 |
|
| 129 |
dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)
|
| 130 |
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
|
| 131 |
|
| 132 |
-
# Use CUDA if available
|
| 133 |
-
device, _, _ = get_backend()
|
| 134 |
-
|
| 135 |
model = MoEClassifier(3, 0.05)
|
| 136 |
model.load_state_dict(torch.load(model_path))
|
| 137 |
model = model.to(device)
|
|
|
|
| 118 |
|
| 119 |
def moe_classifier(test_dataset: dict, model: str):
|
| 120 |
print("Starting MoE run")
|
| 121 |
+
|
| 122 |
+
# Use CUDA if available
|
| 123 |
+
device, _, _ = get_backend()
|
| 124 |
+
|
| 125 |
texts = test_dataset["quote"]
|
| 126 |
+
|
| 127 |
model_path = f"tasks/text_models/0131_MoE_final.pt"
|
| 128 |
|
| 129 |
+
embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
|
|
|
|
|
|
|
| 130 |
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
|
| 131 |
|
| 132 |
dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)
|
| 133 |
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
|
| 134 |
|
|
|
|
|
|
|
|
|
|
| 135 |
model = MoEClassifier(3, 0.05)
|
| 136 |
model.load_state_dict(torch.load(model_path))
|
| 137 |
model = model.to(device)
|