Commit
·
ed4339d
1
Parent(s):
ed29696
Update README.md
Browse files
README.md
CHANGED
|
@@ -95,7 +95,7 @@ class BERTClass(torch.nn.Module):
|
|
| 95 |
output = self.classifier(pooler)
|
| 96 |
return output
|
| 97 |
|
| 98 |
-
def do_predict(model, tokenizer):
|
| 99 |
test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
|
| 100 |
test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
|
| 101 |
test_loader = DataLoader(test_set, **test_params)
|
|
@@ -119,7 +119,7 @@ model_sustain.to(device)
|
|
| 119 |
model_sustain.load_state_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
|
| 120 |
|
| 121 |
tokenizer_sus = BertTokenizer.from_pretrained('roberta-base')
|
| 122 |
-
actual_predictions_sus = do_predict(model_sustain, tokenizer_sus)
|
| 123 |
|
| 124 |
test_df['sustainability'] = ['sustainable' if i==0 else 'unsustainable' for i in actual_predictions_read]
|
| 125 |
```
|
|
|
|
| 95 |
output = self.classifier(pooler)
|
| 96 |
return output
|
| 97 |
|
| 98 |
+
def do_predict(model, tokenizer, test_df):
|
| 99 |
test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
|
| 100 |
test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
|
| 101 |
test_loader = DataLoader(test_set, **test_params)
|
|
|
|
| 119 |
model_sustain.load_state_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
|
| 120 |
|
| 121 |
tokenizer_sus = BertTokenizer.from_pretrained('roberta-base')
|
| 122 |
+
actual_predictions_sus = do_predict(model_sustain, tokenizer_sus, test_df)
|
| 123 |
|
| 124 |
test_df['sustainability'] = ['sustainable' if i==0 else 'unsustainable' for i in actual_predictions_read]
|
| 125 |
```
|