File size: 1,063 Bytes
901d29e
 
 
 
 
157af11
901d29e
 
157af11
 
 
901d29e
157af11
 
 
 
901d29e
157af11
 
901d29e
157af11
901d29e
157af11
 
 
 
 
 
901d29e
157af11
 
 
 
 
 
 
 
901d29e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
---
library_name: transformers
tags: []
---

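# USAGE

The model treats image-intent detection as sentence-pair (NLI-style) classification. The user message is paired with the `[IMG]` marker, and an *entailment* prediction (class 0) means the message asks for an image. Loading the model in 8-bit requires the `bitsandbytes` package and typically a CUDA GPU.
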
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "swarogthehater/IMAGE_INTENT"
# 8-bit weights need the bitsandbytes package and (typically) a CUDA GPU.
# BitsAndBytesConfig replaces the deprecated `load_in_8bit=True` keyword
# argument of `from_pretrained`.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

text = "show me yourself"
IMAGE_INTENT = "[IMG]"

# Message and intent marker joined with [SEP] tokens, as in the original example.
input_text = text + "[SEP]" + IMAGE_INTENT + "[SEP]"

# Put the inputs on the same device as the model. A hard-coded CPU device
# does not match an 8-bit model, which is placed on the GPU.
batch = tokenizer(input_text, return_tensors="pt").to(model.device)

with torch.no_grad():  # inference only: no autograd graph needed
    # **batch forwards input_ids, attention_mask and, only if the tokenizer
    # produced them, token_type_ids.
    outputs = model(**batch)

# argmax over the class dimension (not the flattened tensor); batch size is 1.
predicted = outputs.logits.argmax(dim=-1).item()

# labels in common NLI terms (0: entailment, 1: neutral, 2: contradiction)
print(predicted)
# image-intent label: 1 when the pair is classified as entailment, otherwise 0
label = 1 if predicted == 0 else 0
print(label)
# class probabilities; move to CPU first because NumPy cannot read CUDA tensors
print(outputs.logits.float().softmax(dim=-1).cpu().numpy())
```