Update README.md
Browse files
README.md
CHANGED
|
@@ -25,7 +25,7 @@ The model can be used for zero-shot text classification such sentiment analysis
|
|
| 25 |
The number of labels should be 2 ~ 20.
|
| 26 |
|
| 27 |
### How to use
|
| 28 |
-
You can try the model with the
|
| 29 |
|
| 30 |
```python
|
| 31 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
@@ -35,12 +35,12 @@ tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuni
|
|
| 35 |
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
|
| 36 |
|
| 37 |
text = "I love this place! The food is always so fresh and delicious."
|
| 38 |
-
list_label = ["negative","
|
| 39 |
|
| 40 |
list_ABC = [x for x in string.ascii_uppercase]
|
| 41 |
-
def add_prefix(text, list_label,
|
| 42 |
list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
|
| 43 |
-
list_label_new = list_label + [tokenizer.pad_token]* (
|
| 44 |
if shuffle:
|
| 45 |
random.shuffle(list_label_new)
|
| 46 |
s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
|
|
@@ -50,9 +50,12 @@ text_new, list_label_new = add_prefix(text,list_label,shuffle=False)
|
|
| 50 |
|
| 51 |
encoding = tokenizer([text_new],truncation=True, padding='max_length',max_length=512, return_tensors='pt')
|
| 52 |
with torch.no_grad():
|
| 53 |
-
logits = model(**
|
| 54 |
probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
|
| 55 |
predictions = torch.argmax(logits, dim=-1)
|
|
|
|
|
|
|
|
|
|
| 56 |
```
|
| 57 |
|
| 58 |
|
|
|
|
| 25 |
The number of labels should be 2 ~ 20.
|
| 26 |
|
| 27 |
### How to use
|
| 28 |
+
You can try the model with the Colab [Notebook](https://colab.research.google.com/drive/17bqc8cXFF-wDmZ0o8j7sbrQB9Cq7Gowr?usp=sharing).
|
| 29 |
|
| 30 |
```python
|
| 31 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
|
| 35 |
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-base")
|
| 36 |
|
| 37 |
text = "I love this place! The food is always so fresh and delicious."
|
| 38 |
+
list_label = ["negative", "positive"]
|
| 39 |
|
| 40 |
list_ABC = [x for x in string.ascii_uppercase]
|
| 41 |
+
def add_prefix(text, list_label, shuffle = False):
|
| 42 |
list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
|
| 43 |
+
list_label_new = list_label + [tokenizer.pad_token]* (20 - len(list_label))
|
| 44 |
if shuffle:
|
| 45 |
random.shuffle(list_label_new)
|
| 46 |
s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
|
|
|
|
| 50 |
|
| 51 |
encoding = tokenizer([text_new],truncation=True, padding='max_length',max_length=512, return_tensors='pt')
|
| 52 |
with torch.no_grad():
|
| 53 |
+
logits = model(**encoding).logits
|
| 54 |
probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
|
| 55 |
predictions = torch.argmax(logits, dim=-1)
|
| 56 |
+
|
| 57 |
+
print(probs)
|
| 58 |
+
print(predictions)
|
| 59 |
```
|
| 60 |
|
| 61 |
|