marmarg2 committed
Commit c9732f9 · 1 Parent(s): f802df1

Upload Ejemplo de uso RoBERTa-k-MMG.ipynb

Files changed (1)
  1. Ejemplo de uso RoBERTa-k-MMG.ipynb +476 -0
Ejemplo de uso RoBERTa-k-MMG.ipynb ADDED
@@ -0,0 +1,476 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "fb0e0d02",
+ "metadata": {},
+ "source": [
+ "## Loading the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2eb92ef8",
+ "metadata": {},
+ "source": [
+ "Let's run an example to see how the selected model, RoBERTa-k-MMG, works.\n",
+ "First, we import the libraries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "5aeb9c06",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-08-29 23:55:16.877559: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import tensorflow as tf\n",
+ "from tensorflow import keras"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fd2d4359",
+ "metadata": {},
+ "source": [
+ "Now we load the model with `load_model`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f9f8acde",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:`compile()` was not called as part of model loading because the model's `compile()` method is custom. All subclassed Models that have `compile()` overridden should also override `get_compile_config()` and `compile_from_config(config)`. Alternatively, you can call `compile()` manually after loading.\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = tf.keras.models.load_model(\"./RoBERTa-k-MMG.keras\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "12dda010",
+ "metadata": {},
+ "source": [
+ "We check that the model loaded correctly with `summary`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6291f88c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: \"tf_roberta_for_sequence_classification_1\"\n",
+ "_________________________________________________________________\n",
+ " Layer (type)                Output Shape              Param #   \n",
+ "=================================================================\n",
+ " roberta (TFRobertaMainLaye  multiple                  124052736 \n",
+ " r)                                                              \n",
+ "                                                                 \n",
+ " classifier (TFRobertaClass  multiple                  592130    \n",
+ " ificationHead)                                                  \n",
+ "                                                                 \n",
+ "=================================================================\n",
+ "Total params: 124644866 (475.48 MB)\n",
+ "Trainable params: 124644866 (475.48 MB)\n",
+ "Non-trainable params: 0 (0.00 Byte)\n",
+ "_________________________________________________________________\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "50711238",
+ "metadata": {},
+ "source": [
+ "The warning above says that we must call `compile()` again before using the model, so let's do that."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "03e2faa9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.compile(\n",
+ "    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),\n",
+ "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+ "    metrics=tf.metrics.SparseCategoricalAccuracy(),\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cc8aa427",
+ "metadata": {},
+ "source": [
+ "## Checking its metrics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c2badbb3",
+ "metadata": {},
+ "source": [
+ "Now let's verify that the model really achieves the reported loss and accuracy. To do so, we load the test set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "cd2031f8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'label': 1,\n",
+ " 'text': 'Le recorta la comida para que no se ponga gordo o suba de peso'}"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "data_file = {\"test\": \"test.csv\"}\n",
+ "dataset = load_dataset(\"../toxic-teenage-relationships\", data_files=data_file, sep=\";\")\n",
+ "dataset['test'][8]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "04e69690",
+ "metadata": {},
+ "source": [
+ "We load the tokenizer and apply it to the dataset with `.map`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "541d8cd7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# In this example, we use the generic AutoTokenizer\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"PlanTL-GOB-ES/roberta-base-bne\")\n",
+ "\n",
+ "\n",
+ "def tokenize_function(examples):\n",
+ "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
+ "\n",
+ "\n",
+ "# To tokenize the dataset, we use map, which speeds up tokenization by batching\n",
+ "tokenized_datasets = dataset.map(tokenize_function, batched=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fc088df2",
+ "metadata": {},
+ "source": [
+ "Now let's convert the dataset to TensorFlow format. For that we use DefaultDataCollator, which collates the tensors into batches the model can consume. We must specify the argument return_tensors=\"tf\".\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "0106a482",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import DefaultDataCollator\n",
+ "data_collator = DefaultDataCollator(return_tensors=\"tf\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "814f2aa5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "eval_dataset = tokenized_datasets[\"test\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "02dd4812",
+ "metadata": {},
+ "source": [
+ "Now we convert the tokenized dataset into a TensorFlow dataset with the `.to_tf_dataset` method. The inputs go in `columns` and the label in `label_cols`. The batch size is the number of examples fed to the network at a time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "dc19171f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tf_validation_dataset = eval_dataset.to_tf_dataset(\n",
+ "    columns=[\"attention_mask\", \"input_ids\"],\n",
+ "    label_cols=\"labels\",\n",
+ "    shuffle=False,\n",
+ "    collate_fn=data_collator,\n",
+ "    batch_size=8,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a832ba0b",
+ "metadata": {},
+ "source": [
+ "We evaluate the model and, indeed, we get the figures we expected."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "4d2bef7d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "('loss', 0.39804112911224365)\n",
+ "('sparse_categorical_accuracy', 0.9090909361839294)\n"
+ ]
+ }
+ ],
+ "source": [
+ "scores = model.evaluate(tf_validation_dataset, verbose=0)\n",
+ "print((model.metrics_names[0], scores[0]))\n",
+ "print((model.metrics_names[1], scores[1]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "494bee3b",
+ "metadata": {},
+ "source": [
+ "## Predictions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "50182958",
+ "metadata": {},
+ "source": [
+ "First, we read the file containing the texts we want to classify."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "0aaf73da",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('ejemplo.txt', 'r', encoding='utf-8') as file:\n",
+ "    lines = file.readlines()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "156713a2",
+ "metadata": {},
+ "source": [
+ "Now we tokenize the texts and make predictions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "7e878eb3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Tokenize the data\n",
+ "encoded_data = tokenizer(lines, padding=\"max_length\", truncation=True, return_tensors=\"tf\")\n",
+ "# input_ids is the integer matrix representing the tokens of the input text\n",
+ "input_ids = encoded_data['input_ids']\n",
+ "# attention_mask indicates which tokens are real (1) and which are padding (0)\n",
+ "attention_mask = encoded_data['attention_mask']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "03bb7d36",
+ "metadata": {},
+ "source": [
+ "Now let's make the predictions, passing the model these two components of the tokenized data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "44e5b58d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1/1 [==============================] - 25s 25s/step\n"
+ ]
+ }
+ ],
+ "source": [
+ "predictions = model.predict({'input_ids': input_ids, 'attention_mask': attention_mask})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c3b72027",
+ "metadata": {},
+ "source": [
+ "We extract the score of the positive (toxic) class. Note that these values are raw logits, not normalized probabilities, since the model outputs logits (we compiled with from_logits=True)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "5c195a22",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "probs_toxic = predictions.logits[:, 1]"
+ ]
+ },
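+ {
+ "cell_type": "markdown",
+ "id": "b1f2a3c4",
+ "metadata": {},
+ "source": [
+ "Aside (a sketch added for clarity, not part of the original run): because the model was compiled with from_logits=True, the values above are unnormalized logits rather than probabilities. A softmax over the two class logits would yield proper probabilities in [0, 1]:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b1f2a3c5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: normalize the two-class logits into probabilities with a softmax\n",
+ "probs = tf.nn.softmax(predictions.logits, axis=-1)\n",
+ "# Probability of the positive (toxic) class, now between 0 and 1\n",
+ "probs_toxic_norm = probs[:, 1]"
+ ]
+ },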
+ {
+ "cell_type": "markdown",
+ "id": "111b7a90",
+ "metadata": {},
+ "source": [
+ "Now, to classify, we apply this decision threshold to the scores."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "13abeef0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "threshold = 0.5\n",
+ "decoded_predictions = [1 if prob >= threshold else 0 for prob in probs_toxic]\n"
+ ]
+ },
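+ {
+ "cell_type": "markdown",
+ "id": "c2d3e4f5",
+ "metadata": {},
+ "source": [
+ "Aside (a sketch, not part of the original run): thresholding the raw class-1 logit at 0.5 is not identical to picking the most likely class; an alternative is an argmax over both class logits:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c2d3e4f6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: choose the class with the highest logit for each text\n",
+ "argmax_predictions = tf.argmax(predictions.logits, axis=-1).numpy().tolist()"
+ ]
+ },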
+ {
+ "cell_type": "markdown",
+ "id": "1ae09b99",
+ "metadata": {},
+ "source": [
+ "Finally, we display the predictions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "c3ae29b8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Text: No se que piensas, pareces tonto.\n",
+ " Predicted label: Toxic\n",
+ " Toxicity score (logit): 2.6194\n",
+ "====================\n",
+ "\n",
+ "Text: Maquillada así pareces una puta.\n",
+ " Predicted label: Toxic\n",
+ " Toxicity score (logit): 2.6675\n",
+ "====================\n",
+ "\n",
+ "Text: Me encanta verte sonreir.\n",
+ " Predicted label: Healthy\n",
+ " Toxicity score (logit): -0.1724\n",
+ "====================\n",
+ "\n",
+ "Text: Me deja de hablar tres días si hago algo que no le gusta.\n",
+ " Predicted label: Toxic\n",
+ " Toxicity score (logit): 2.4431\n",
+ "====================\n"
+ ]
+ }
+ ],
+ "source": [
+ "for line, prediction, logit_toxic in zip(lines, decoded_predictions, probs_toxic):\n",
+ "    label = \"Toxic\" if prediction == 1 else \"Healthy\"\n",
+ "    print()\n",
+ "    print(f\"Text: {line.strip()}\")\n",
+ "    print(f\" Predicted label: {label}\")\n",
+ "    print(f\" Toxicity score (logit): {logit_toxic:.4f}\")\n",
+ "    print(\"=\"*20)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }