andre156 commited on
Commit
446d7b4
·
verified ·
1 Parent(s): 00eb080

Upload inferenceNotebook.ipynb

Browse files
Files changed (1) hide show
  1. inferenceNotebook.ipynb +145 -0
inferenceNotebook.ipynb ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e12b9784-0a73-447c-bd95-5c4db12213ec",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Load "
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 11,
14
+ "id": "94c34109-799b-4094-934b-85df33a3be99",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import transformers\n",
19
+ "import pandas as pd\n",
20
+ "import numpy as np\n",
21
+ "import torch\n",
22
+ "from transformers import BertTokenizer\n",
23
+ "\n",
24
+ "# Path of bert model\n",
25
+ "path = '/home/colombo_phd/ItalianLaws/Data/BERT-Domains/'\n",
26
+ "\n",
27
+ "# label df to convert token to string\n",
28
+ "label = pd.read_csv(path +'label_tokens.csv', sep = ';')\n",
29
+ "\n",
30
+ "# Load model\n",
31
+ "if torch.cuda.is_available():\n",
32
+ " model = torch.load('bert_model')\n",
33
+ "else:\n",
34
+ " model = torch.load(path +'bert_model', map_location=torch.device('cpu'))"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "id": "f8df905b-9a7b-46ec-8aab-adb15b50aad5",
40
+ "metadata": {},
41
+ "source": [
42
+ "## String to evaluate - title of the law"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 19,
48
+ "id": "5868b342-3161-4862-b269-1d4959359d48",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "title = 'Regolamento per il commercio di prodotti agricoli in europa'"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "id": "3c2d173a-4702-4fb3-93f8-c1ea366bdc41",
58
+ "metadata": {},
59
+ "source": [
60
+ "## Run model"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 20,
66
+ "id": "d57af582-61bc-4be9-b305-63a40ede1311",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
71
+ "encoded_dict = tokenizer.encode_plus(\n",
72
+ " title,\n",
73
+ " add_special_tokens = True,\n",
74
+ " max_length = 389,\n",
75
+ " truncation=True,\n",
76
+ " pad_to_max_length = True,\n",
77
+ " return_attention_mask = True,\n",
78
+ " return_tensors = 'pt',\n",
79
+ " )\n",
80
+ "test_input_ids = torch.cat([encoded_dict['input_ids']], dim=0)\n",
81
+ "test_attention_masks = torch.cat([encoded_dict['attention_mask']], dim=0)\n",
82
+ "\n",
83
+ "b_input_ids = test_input_ids.to(device)\n",
84
+ "b_input_mask = test_attention_masks.to(device)\n",
85
+ "with torch.no_grad():\n",
86
+ " output= model(b_input_ids,\n",
87
+ " token_type_ids=None,\n",
88
+ " attention_mask=b_input_mask)\n",
89
+ " logits = output.logits\n",
90
+ " logits = logits.detach().cpu().numpy()\n",
91
+ " pred_flat = np.argmax(logits, axis=1).flatten()"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "markdown",
96
+ "id": "bbf3114a-c0ad-411c-b35b-1d7ec922035d",
97
+ "metadata": {},
98
+ "source": [
99
+ "## Derive domain"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 31,
105
+ "id": "c2e3004b-d735-4462-bbda-6bdb02586102",
106
+ "metadata": {},
107
+ "outputs": [
108
+ {
109
+ "data": {
110
+ "text/plain": [
111
+ "'economia'"
112
+ ]
113
+ },
114
+ "execution_count": 31,
115
+ "metadata": {},
116
+ "output_type": "execute_result"
117
+ }
118
+ ],
119
+ "source": [
120
+ "label[label['label']== pred_flat[0]]['Ministries'].iloc[0]"
121
+ ]
122
+ }
123
+ ],
124
+ "metadata": {
125
+ "kernelspec": {
126
+ "display_name": "Python 3 (ipykernel)",
127
+ "language": "python",
128
+ "name": "python3"
129
+ },
130
+ "language_info": {
131
+ "codemirror_mode": {
132
+ "name": "ipython",
133
+ "version": 3
134
+ },
135
+ "file_extension": ".py",
136
+ "mimetype": "text/x-python",
137
+ "name": "python",
138
+ "nbconvert_exporter": "python",
139
+ "pygments_lexer": "ipython3",
140
+ "version": "3.10.14"
141
+ }
142
+ },
143
+ "nbformat": 4,
144
+ "nbformat_minor": 5
145
+ }