Zero-Shot Classification
sentence-transformers
PyTorch
JAX
ONNX
Safetensors
OpenVINO
Transformers
English
roberta
text-classification
Instructions to use cross-encoder/nli-roberta-base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use cross-encoder/nli-roberta-base with sentence-transformers:
from sentence_transformers import SentenceTransformer model = SentenceTransformer("cross-encoder/nli-roberta-base") sentences = [ "The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium." ] embeddings = model.encode(sentences) similarities = model.similarity(embeddings, embeddings) print(similarities.shape) # [3, 3] - Transformers
How to use cross-encoder/nli-roberta-base with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("zero-shot-classification", model="cross-encoder/nli-roberta-base")# Load model directly from transformers import AutoTokenizer, AutoModelForSequenceClassification tokenizer = AutoTokenizer.from_pretrained("cross-encoder/nli-roberta-base") model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/nli-roberta-base") - Notebooks
- Google Colab
- Kaggle
nreimers commited on
Commit ·
ced69de
1
Parent(s): 23111f3
up
Browse files
README.md
CHANGED
|
@@ -49,4 +49,17 @@ with torch.no_grad():
|
|
| 49 |
label_mapping = ['contradiction', 'entailment', 'neutral']
|
| 50 |
labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1)]
|
| 51 |
print(labels)
|
| 52 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
label_mapping = ['contradiction', 'entailment', 'neutral']
|
| 50 |
labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1)]
|
| 51 |
print(labels)
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Zero-Shot Classification
|
| 55 |
+
This model can also be used for zero-shot-classification:
|
| 56 |
+
```
|
| 57 |
+
from transformers import pipeline
|
| 58 |
+
|
| 59 |
+
classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-roberta-base')
|
| 60 |
+
|
| 61 |
+
sent = "Apple just announced the newest iPhone X"
|
| 62 |
+
candidate_labels = ["technology", "sports", "politics"]
|
| 63 |
+
res = classifier(sent, candidate_labels)
|
| 64 |
+
print(res)
|
| 65 |
+
```
|