Update README.md
Browse files
README.md
CHANGED
|
@@ -5,36 +5,50 @@ pipeline_tag: zero-shot-classification
|
|
| 5 |
datasets:
|
| 6 |
- HeTree/MevakerConcTree
|
| 7 |
license: apache-2.0
|
| 8 |
-
library_name: transformers
|
| 9 |
---
|
| 10 |
|
| 11 |
# Hebrew Cross-Encoder Model
|
| 12 |
|
| 13 |
## Usage
|
| 14 |
-
|
| 15 |
-
Pre-trained models can be used like this:
|
| 16 |
```python
|
| 17 |
from sentence_transformers import CrossEncoder
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
```
|
| 25 |
|
| 26 |
## Zero-Shot Classification
|
| 27 |
This model can also be used for zero-shot-classification:
|
| 28 |
```python
|
| 29 |
from transformers import pipeline
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
sent = "Apple just announced the newest iPhone X"
|
| 34 |
-
candidate_labels = ["technology", "sports", "politics"]
|
| 35 |
res = classifier(sent, candidate_labels)
|
| 36 |
print(res)
|
| 37 |
-
```
|
| 38 |
|
| 39 |
### Citing
|
| 40 |
|
|
|
|
| 5 |
datasets:
|
| 6 |
- HeTree/MevakerConcTree
|
| 7 |
license: apache-2.0
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
# Hebrew Cross-Encoder Model
|
| 11 |
|
| 12 |
## Usage
|
|
|
|
|
|
|
| 13 |
```python
|
| 14 |
from sentence_transformers import CrossEncoder
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
# Function that applies sigmoid to a score
|
| 18 |
+
def sigmoid(x):
|
| 19 |
+
return 1 / (1 + np.exp(-x))
|
| 20 |
+
|
| 21 |
+
model = CrossEncoder('HeTree/HeCross')
|
| 22 |
|
| 23 |
+
# Scores (already after sigmoid)
|
| 24 |
+
scores = model.predict([('כמה אנשים חיים בברלין?', 'ברלין מונה 3,520,031 תושבים רשומים בשטח של 891.82 קמ"ר.'), ('כמה אנשים חיים בברלין?', 'העיר ניו יורק מפורסמת בזכות מוזיאון המטרופוליטן לאומנות.')])
|
| 25 |
+
print(scores)
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Usage with Transformers AutoModel
|
| 29 |
+
You can use the model also directly with Transformers library (without SentenceTransformers library):
|
| 30 |
+
```python
|
| 31 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 32 |
+
import torch
|
| 33 |
+
model = AutoModelForSequenceClassification.from_pretrained('HeTree/HeCross')
|
| 34 |
+
tokenizer = AutoTokenizer.from_pretrained('HeTree/HeCross')
|
| 35 |
+
features = tokenizer(['כמה אנשים חיים בברלין?', 'כמה אנשים חיים בברלין?'], ['ברלין מונה 3,520,031 תושבים רשומים בשטח של 891.82 קמ"ר.', 'העיר ניו יורק מפורסמת בזכות מוזיאון המטרופוליטן לאומנות.'], padding=True, truncation=True, return_tensors="pt")
|
| 36 |
+
model.eval()
|
| 37 |
+
with torch.no_grad():
|
| 38 |
+
scores = sigmoid(model(**features).logits)
|
| 39 |
+
print(scores)
|
| 40 |
```
|
| 41 |
|
| 42 |
## Zero-Shot Classification
|
| 43 |
This model can also be used for zero-shot-classification:
|
| 44 |
```python
|
| 45 |
from transformers import pipeline
|
| 46 |
+
classifier = pipeline("zero-shot-classification", model='HeTree/HeCross')
|
| 47 |
+
sent = "בשבוע שעבר שדרגתי את גרסת הטלפון שלי ."
|
| 48 |
+
candidate_labels = ["נייד לשיחות", "אתר", "חיוב חשבון", "גישה לחשבון בנק"]
|
|
|
|
|
|
|
| 49 |
res = classifier(sent, candidate_labels)
|
| 50 |
print(res)
|
| 51 |
+
```
|
| 52 |
|
| 53 |
### Citing
|
| 54 |
|