Spaces:
Runtime error
Runtime error
ashish rai
commited on
Commit
·
13315ca
1
Parent(s):
81f2986
updated zeroshot
Browse files- zeroshot_clf.py +2 -2
zeroshot_clf.py
CHANGED
|
@@ -8,7 +8,7 @@ import plotly.express as px
|
|
| 8 |
model=AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/')
|
| 9 |
tokenizer=AutoTokenizer.from_pretrained('zero_shot_clf/')
|
| 10 |
|
| 11 |
-
def zero_shot_classification(premise:str,labels:str,model=model,tokenizer=tokenizer):
|
| 12 |
try:
|
| 13 |
labels=labels.split(',')
|
| 14 |
labels=[l.lower() for l in labels]
|
|
@@ -27,7 +27,7 @@ def zero_shot_classification(premise:str,labels:str,model=model,tokenizer=tokeni
|
|
| 27 |
return_tensors='pt',
|
| 28 |
truncation_strategy='only_first')
|
| 29 |
output = model(input)
|
| 30 |
-
entail_contra_prob = output['logits'][:,[0,2]].softmax(dim=1)[:,1].item()
|
| 31 |
labels_prob.append(entail_contra_prob)
|
| 32 |
|
| 33 |
labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
|
|
|
|
| 8 |
model=AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/')
|
| 9 |
tokenizer=AutoTokenizer.from_pretrained('zero_shot_clf/')
|
| 10 |
|
| 11 |
+
def zero_shot_classification(premise: str, labels: str, model= model, tokenizer= tokenizer):
|
| 12 |
try:
|
| 13 |
labels=labels.split(',')
|
| 14 |
labels=[l.lower() for l in labels]
|
|
|
|
| 27 |
return_tensors='pt',
|
| 28 |
truncation_strategy='only_first')
|
| 29 |
output = model(input)
|
| 30 |
+
entail_contra_prob = output['logits'][:,[0,2]].softmax(dim=1)[:,1].item() #only normalizing entail & contradict probabilties
|
| 31 |
labels_prob.append(entail_contra_prob)
|
| 32 |
|
| 33 |
labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
|