Update README.md
Browse files
README.md
CHANGED
|
@@ -9,7 +9,7 @@ We introduce the model for multilabel ESG risks classification. There is 47 clas
|
|
| 9 |
|
| 10 |
## Usage
|
| 11 |
```python
|
| 12 |
-
|
| 13 |
from transformers import MPNetPreTrainedModel, MPNetModel, AutoTokenizer
|
| 14 |
import torch
|
| 15 |
#Mean Pooling - Take attention mask into account for correct averaging
|
|
@@ -45,10 +45,11 @@ class ESGify(MPNetPreTrainedModel):
|
|
| 45 |
outputs = self.mpnet(input_ids=input_ids,
|
| 46 |
attention_mask=attention_mask)
|
| 47 |
|
| 48 |
-
# mean pooling dataset
|
| 49 |
logits = self.classifier( mean_pooling(outputs['last_hidden_state'],attention_mask))
|
| 50 |
-
|
| 51 |
-
|
|
|
|
| 52 |
return logits
|
| 53 |
|
| 54 |
model = ESGify.from_pretrained('ai-lab/ESGify')
|
|
|
|
| 9 |
|
| 10 |
## Usage
|
| 11 |
```python
|
| 12 |
+
from collections import OrderedDict
|
| 13 |
from transformers import MPNetPreTrainedModel, MPNetModel, AutoTokenizer
|
| 14 |
import torch
|
| 15 |
#Mean Pooling - Take attention mask into account for correct averaging
|
|
|
|
| 45 |
outputs = self.mpnet(input_ids=input_ids,
|
| 46 |
attention_mask=attention_mask)
|
| 47 |
|
| 48 |
+
# mean pooling dataset and eed input to classifier to compute logits
|
| 49 |
logits = self.classifier( mean_pooling(outputs['last_hidden_state'],attention_mask))
|
| 50 |
+
|
| 51 |
+
# apply sigmoid
|
| 52 |
+
logits = 1.0 / (1.0 + torch.exp(-logits))
|
| 53 |
return logits
|
| 54 |
|
| 55 |
model = ESGify.from_pretrained('ai-lab/ESGify')
|