File size: 3,471 Bytes
d173119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461b6d6
 
 
 
 
 
 
 
 
 
d173119
461b6d6
d173119
 
 
 
 
e734ddd
461b6d6
 
 
 
 
 
 
 
 
e734ddd
461b6d6
 
 
 
e734ddd
d173119
 
 
 
 
 
 
 
 
e734ddd
d173119
 
 
 
 
 
 
 
 
e734ddd
d173119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5338af6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModel, AutoTokenizer\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3e5e27fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# trust_remote_code=True runs the custom modeling code stored in the Hub repository; review that code before enabling it.\n",
    "model = AutoModel.from_pretrained(\n",
    "    \"UF-NLPC-Lab/bart-stance-mixed\",\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eccdbd4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer published in the same Hub repository as the model.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    pretrained_model_name_or_path=\"UF-NLPC-Lab/bart-stance-mixed\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86cf1631",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One claim target and one noun-phrase target.\n",
    "# The first prediction should be FAVOR; the second should be NONE, since the author says nothing about their opinion of salad itself.\n",
    "samples = [\"Carrots are the superior vegetable.\", \"I don't like ranch on my salad.\"]\n",
    "targets = [\"Broccoli is inferior to carrots.\", \"salad\"]\n",
    "# Each sample is paired with the target at the same index.\n",
    "assert len(samples) == len(targets), \"each sample needs exactly one target\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6965a16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model was trained with target-then-sample, so the target goes in `text` and the sample in `text_pair`.\n",
    "# You can probably get away with tokenizing like this...\n",
    "# truncation=True keeps long samples within the tokenizer's maximum length instead of exceeding it.\n",
    "encoding = tokenizer(text=targets, text_pair=samples, return_tensors='pt', padding=True, truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82673e84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ... but this is what we actually did during our own training and evaluation:\n",
    "# truncate every target and every sample on its own, decode each back to text,\n",
    "# then pair them target-then-sample (the order the model was trained with).\n",
    "max_context_length = 256\n",
    "max_target_length = 64\n",
    "\n",
    "def truncate_text(text, max_length):\n",
    "    \"\"\"Return `text` cut to at most `max_length` tokens (special tokens excluded), decoded back to a string.\"\"\"\n",
    "    token_ids = tokenizer.encode(text, add_special_tokens=False, truncation=True, max_length=max_length)\n",
    "    return tokenizer.decode(token_ids)\n",
    "\n",
    "# Truncate per example: passing the whole list to `encode` yields one id list (one input), not a batch.\n",
    "context_trunc = [truncate_text(sample, max_context_length) for sample in samples]\n",
    "target_trunc = [truncate_text(target, max_target_length) for target in targets]\n",
    "combined = tokenizer(text=target_trunc, text_pair=context_trunc, return_tensors='pt', padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5d93f032",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the tokenizer's BatchEncoding into keyword arguments for the model call.\n",
    "model_inputs = dict(encoding)\n",
    "output = model(**model_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "6a83f378",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['FAVOR', 'NONE']"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick the highest-scoring class per example and map its id to a label name via the model config.\n",
    "logits = output.logits.detach().cpu().numpy()\n",
    "preds = logits.argmax(axis=-1)\n",
    "str_preds = [model.config.id2label[int(label_id)] for label_id in preds]\n",
    "str_preds"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "min_transformers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.22"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}