smenaaliaga committed on
Commit 4a7e280 · verified · 1 Parent(s): 460d62d

Upload PIBot Joint BERT model with full reproducibility
README.md CHANGED
@@ -1,120 +1,75 @@
- ---
- language: es
- tags:
- - intent-classification
- - slot-filling
- - joint-bert
- - spanish
- - economics
- - chile
- license: mit
- ---
-
- # PIBot Joint BERT - BETO
-
- Joint BERT model trained for intent classification and entity extraction (slot filling) on queries about economic indicators from the Banco Central de Chile.
-
- ## Model Description
-
- This model is based on the Joint BERT architecture, which simultaneously performs:
- 1. **Intent Classification**: determines whether the query asks for values (`value`) or methodological information (`methodology`)
- 2. **Slot Filling**: identifies and extracts entities such as indicators, periods, measure types, sectors, etc.
-
- ### Base Model
-
- - **Architecture**: BERT (dccuchile/bert-base-spanish-wwm-cased)
- - **Language**: Spanish
- - **Task**: pibimacec
- - **Training epochs**: 20.0
  ## Usage

- ### Installation
-
- ```bash
- pip install torch transformers pytorch-crf
- ```
-
- ### Usage Example

  ```python
- from transformers import BertTokenizer
- from modeling_jointbert import JointBERT
- import torch
-
- # Load the model and tokenizer
- model_dir = "smenaaliaga/pibot-jointbert-beto"  # Replace with your repo
- tokenizer = BertTokenizer.from_pretrained(model_dir)
-
- # Load the labels
- intent_labels = ["methodology", "value"]
- slot_labels = ["O", "B-indicator", "I-indicator", "B-period", "I-period", ...]
-
- # Initialize the model (requires the custom JointBERT code)
- model = JointBERT.from_pretrained(
-     model_dir,
-     intent_label_lst=intent_labels,
-     slot_label_lst=slot_labels
- )
-
- # Predict
- text = "cual fue el imacec de agosto 2024"
- # ... (prediction code)
- ```
- ## Training Data
-
- The model was trained on a specialized dataset of queries about:
- - **IMACEC**: Indicador Mensual de Actividad Económica (Monthly Economic Activity Indicator)
- - **PIB**: Producto Interno Bruto (Gross Domestic Product)
- - Economic sectors (mining, commerce, industry, etc.)
- - Time periods (months, quarters, years)
-
- ### Labels
-
- **Intents:**
- - `value`: queries about specific values/data
- - `methodology`: queries about methodology/definitions
-
- **Slots (entities):**
- - `indicator`: economic indicator (IMACEC, PIB)
- - `period`: time period
- - `measure_type`: measure type (variation, index, etc.)
- - `sector`: economic sector
- - `series_type`: series type (original, seasonally adjusted, trend-cycle)
-
- ## Performance
-
- - **Intent Accuracy**: ~95%+
- - **Slot F1-Score**: ~90%+
-
- (Approximate values; see the training logs for exact metrics)
-
- ## Limitations
-
- - Trained specifically for queries about Chilean economic indicators
- - Best performance on short-to-medium queries (< 50 tokens)
- - May struggle with highly ambiguous or out-of-domain queries
- ## Citation
-
- If you use this model, please cite:
-
- ```bibtex
- @misc{pibot-jointbert,
-   author = {Banco Central de Chile},
-   title = {PIBot Joint BERT - Modelo de Clasificación de Intención y Slot Filling},
-   year = {2025},
-   publisher = {Hugging Face},
-   howpublished = {\url{https://huggingface.co/smenaaliaga/pibot-jointbert-beto}}
- }
- ```
  ## License

- MIT License

- ## More Information

- - Original paper: [BERT for Joint Intent Classification and Slot Filling](https://arxiv.org/abs/1902.10909)
- - Base implementation: [JointBERT](https://github.com/monologg/JointBERT)
+ # PIBot Joint BERT - 7 Heads
+
+ Joint BERT model for multi-head classification of queries about economic indicators.
+
+ ## Classification Heads
+
+ The model predicts 7 attributes simultaneously:
+ - **indicator**: economic indicator (e.g. imacec, pib)
+ - **metric_type**: metric type (e.g. index, level)
+ - **calc_mode**: calculation mode (e.g. yoy, mom)
+ - **seasonal**: seasonal adjustment (e.g. sa, nsa)
+ - **req_form**: request form (e.g. latest, historical)
+ - **frequency**: frequency (e.g. m, q, a)
+ - **activity**: activity/sector (e.g. total, agriculture)
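The seven-head setup can be sketched in plain PyTorch (a minimal illustration, not this repository's actual code; the `HEAD_SIZES` counts and the hidden size 768 are made-up stand-ins — the real label counts come from the checkpoint's `*_label.txt` files):

```python
import torch
import torch.nn as nn

# Hypothetical label counts per head, for illustration only
HEAD_SIZES = {
    "indicator": 4, "metric_type": 3, "calc_mode": 3, "seasonal": 2,
    "req_form": 2, "frequency": 3, "activity": 5,
}

class MultiHead(nn.Module):
    """One linear classifier per attribute, all fed from the same pooled [CLS] vector."""
    def __init__(self, hidden_size):
        super().__init__()
        self.heads = nn.ModuleDict({
            name: nn.Linear(hidden_size, n) for name, n in HEAD_SIZES.items()
        })

    def forward(self, pooled):
        # Each head produces its own logits over its own label set
        return {name: head(pooled) for name, head in self.heads.items()}

pooled = torch.randn(1, 768)          # stand-in for BERT's pooled [CLS] output
logits = MultiHead(768)(pooled)
print({k: tuple(v.shape) for k, v in logits.items()})
```

Because every head reads the same pooled representation, adding or removing an attribute only changes the dictionary of heads, not the encoder.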
  ## Usage

+ ### Option 1: Local (recommended for maximum compatibility)

  ```python
+ from load_local_model import PIBotPredictor
+
+ predictor = PIBotPredictor("path/to/model")
+ result = predictor.predict("cual fue el pib del último trimestre")
+ print(result)
+ ```
+ ### Option 2: From the Hugging Face Hub

+ ```python
+ from load_local_model import PIBotPredictor
+
+ # Download and use
+ predictor = PIBotPredictor("username/pibot-jointbert")
+ result = predictor.predict("cual fue el imacec")
+ print(result)
+ ```

+ ### Command line

+ ```bash
+ python load_local_model.py --model_dir path/to/model --text "your query"
+ ```
+ ## Checkpoint Structure

+ ```
+ model_dir/
+ ├── model.safetensors       # Model weights
+ ├── config.json             # BERT configuration
+ ├── training_args.bin       # Training arguments
+ ├── tokenizer.json          # Fast tokenizer
+ ├── tokenizer_config.json
+ ├── vocab.txt
+ ├── modeling_jointbert.py   # Custom architecture
+ ├── module.py               # Custom classifier heads
+ ├── __init__.py
+ ├── *_label.txt             # Labels for each head (7 files)
+ └── README.md
+ ```
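The `*_label.txt` files can be read back into per-head label lists with standard-library Python (a sketch assuming one label per line and the `<head>_label.txt` naming pattern; the demo writes a throwaway directory rather than using a real checkpoint):

```python
import tempfile
from pathlib import Path

HEADS = ["indicator", "metric_type", "calc_mode", "seasonal",
         "req_form", "frequency", "activity"]

def load_labels(model_dir):
    """Return {head: [labels...]} read from <head>_label.txt, one label per line."""
    model_dir = Path(model_dir)
    return {
        h: [ln for ln in (model_dir / f"{h}_label.txt").read_text(encoding="utf-8").splitlines() if ln]
        for h in HEADS
    }

# Demo against a temporary directory standing in for a real checkpoint
with tempfile.TemporaryDirectory() as d:
    for h in HEADS:
        (Path(d) / f"{h}_label.txt").write_text("UNK\nlabel_a\nlabel_b\n", encoding="utf-8")
    labels = load_labels(d)

print(labels["indicator"])  # ['UNK', 'label_a', 'label_b']
```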
+ ## Technical Details

+ - **Base model**: dccuchile/bert-base-spanish-wwm-cased (BETO)
+ - **Framework**: PyTorch + Transformers
+ - **Weight format**: SafeTensors
+ - **Tokenizer**: AutoTokenizer with use_fast=True
 
 
 
 
 
  ## License

+ [Specify license]

+ ## Author

+ [Your information]
 
__init__.py ADDED
@@ -0,0 +1 @@
+ from .modeling_jointbert import JointBERT
__pycache__/modeling_jointbert.cpython-312.pyc ADDED
Binary file (6.25 kB)

__pycache__/module.cpython-312.pyc ADDED
Binary file (5.58 kB)
modeling_jointbert.py ADDED
@@ -0,0 +1,125 @@
+ """
+ JointBERT model definition: BERT with 7 classification heads.
+
+ Example use via the helper script:
+     python load_local_model.py --model_dir model_out/pibot_model_v3 --text "cual fue el pib del ultimo trimestre"
+ """
+
+ import torch
+ import torch.nn as nn
+ from transformers import BertPreTrainedModel, BertModel, BertConfig
+ from torchcrf import CRF
+
+ try:  # relative import when loaded as a package (matches __init__.py)
+     from .module import IndicatorClassifier, MetricTypeClassifier, CalcModeClassifier, SeasonalClassifier, ReqFormClassifier, FrequencyClassifier, ActivityClassifier  # , SlotClassifier
+ except ImportError:  # fallback when the files are used as plain scripts
+     from module import IndicatorClassifier, MetricTypeClassifier, CalcModeClassifier, SeasonalClassifier, ReqFormClassifier, FrequencyClassifier, ActivityClassifier  # , SlotClassifier
+
+
+ class JointBERT(BertPreTrainedModel):
+     def __init__(self, config, args, indicator_label_lst, metric_type_label_lst, calc_mode_label_lst,
+                  seasonal_label_lst, req_form_label_lst, frequency_label_lst, activity_label_lst):  # , slot_label_lst):
+         super(JointBERT, self).__init__(config)
+         self.args = args
+
+         self.num_indicator_labels = len(indicator_label_lst)
+         self.num_metric_type_labels = len(metric_type_label_lst)
+         self.num_calc_mode_labels = len(calc_mode_label_lst)
+         self.num_seasonal_labels = len(seasonal_label_lst)
+         self.num_req_form_labels = len(req_form_label_lst)
+         self.num_frequency_labels = len(frequency_label_lst)
+         self.num_activity_labels = len(activity_label_lst)
+         # self.num_slot_labels = len(slot_label_lst)
+
+         self.bert = BertModel(config=config)  # Load pretrained BERT
+
+         self.indicator_classifier = IndicatorClassifier(config.hidden_size, self.num_indicator_labels, args.dropout_rate)
+         self.metric_type_classifier = MetricTypeClassifier(config.hidden_size, self.num_metric_type_labels, args.dropout_rate)
+         self.calc_mode_classifier = CalcModeClassifier(config.hidden_size, self.num_calc_mode_labels, args.dropout_rate)
+         self.seasonal_classifier = SeasonalClassifier(config.hidden_size, self.num_seasonal_labels, args.dropout_rate)
+         self.req_form_classifier = ReqFormClassifier(config.hidden_size, self.num_req_form_labels, args.dropout_rate)
+         self.frequency_classifier = FrequencyClassifier(config.hidden_size, self.num_frequency_labels, args.dropout_rate)
+         self.activity_classifier = ActivityClassifier(config.hidden_size, self.num_activity_labels, args.dropout_rate)
+         # self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
+
+         # if args.use_crf:
+         #     self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
+
+     def forward(self, input_ids, attention_mask, token_type_ids=None, indicator_label_ids=None, metric_type_label_ids=None,
+                 calc_mode_label_ids=None, seasonal_label_ids=None, req_form_label_ids=None, frequency_label_ids=None, activity_label_ids=None):  # , slot_labels_ids=None):
+         outputs = self.bert(input_ids, attention_mask=attention_mask,
+                             token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
+         sequence_output = outputs[0]
+         pooled_output = outputs[1]  # [CLS]
+
+         indicator_logits = self.indicator_classifier(pooled_output)
+         metric_type_logits = self.metric_type_classifier(pooled_output)
+         calc_mode_logits = self.calc_mode_classifier(pooled_output)
+         seasonal_logits = self.seasonal_classifier(pooled_output)
+         req_form_logits = self.req_form_classifier(pooled_output)
+         frequency_logits = self.frequency_classifier(pooled_output)
+         activity_logits = self.activity_classifier(pooled_output)
+         # slot_logits = self.slot_classifier(sequence_output)
+
+         total_loss = 0
+         # 1. Indicator CrossEntropy
+         if indicator_label_ids is not None:
+             indicator_loss_fct = nn.CrossEntropyLoss()
+             indicator_loss = indicator_loss_fct(indicator_logits.view(-1, self.num_indicator_labels), indicator_label_ids.view(-1))
+             total_loss += indicator_loss
+
+         # 2. Metric Type CrossEntropy
+         if metric_type_label_ids is not None:
+             metric_type_loss_fct = nn.CrossEntropyLoss()
+             metric_type_loss = metric_type_loss_fct(metric_type_logits.view(-1, self.num_metric_type_labels), metric_type_label_ids.view(-1))
+             total_loss += metric_type_loss
+
+         # 3. Calc Mode CrossEntropy
+         if calc_mode_label_ids is not None:
+             calc_mode_loss_fct = nn.CrossEntropyLoss()
+             calc_mode_loss = calc_mode_loss_fct(calc_mode_logits.view(-1, self.num_calc_mode_labels), calc_mode_label_ids.view(-1))
+             total_loss += calc_mode_loss
+
+         # 4. Seasonal CrossEntropy
+         if seasonal_label_ids is not None:
+             seasonal_loss_fct = nn.CrossEntropyLoss()
+             seasonal_loss = seasonal_loss_fct(seasonal_logits.view(-1, self.num_seasonal_labels), seasonal_label_ids.view(-1))
+             total_loss += seasonal_loss
+
+         # 5. Req Form CrossEntropy
+         if req_form_label_ids is not None:
+             req_form_loss_fct = nn.CrossEntropyLoss()
+             req_form_loss = req_form_loss_fct(req_form_logits.view(-1, self.num_req_form_labels), req_form_label_ids.view(-1))
+             total_loss += req_form_loss
+
+         # 6. Frequency CrossEntropy
+         if frequency_label_ids is not None:
+             frequency_loss_fct = nn.CrossEntropyLoss()
+             frequency_loss = frequency_loss_fct(frequency_logits.view(-1, self.num_frequency_labels), frequency_label_ids.view(-1))
+             total_loss += frequency_loss
+
+         # 7. Activity CrossEntropy
+         if activity_label_ids is not None:
+             activity_loss_fct = nn.CrossEntropyLoss()
+             activity_loss = activity_loss_fct(activity_logits.view(-1, self.num_activity_labels), activity_label_ids.view(-1))
+             total_loss += activity_loss
+
+         # # 8. Slot Softmax
+         # if slot_labels_ids is not None and self.args.slot_loss_coef != 0:
+         #     if self.args.use_crf:
+         #         # CRF doesn't handle ignore_index (-100), so we replace it with PAD (0)
+         #         slot_labels_ids_crf = slot_labels_ids.clone()
+         #         slot_labels_ids_crf[slot_labels_ids_crf == self.args.ignore_index] = 0
+         #         slot_loss = self.crf(slot_logits, slot_labels_ids_crf, mask=attention_mask.bool(), reduction='mean')
+         #         slot_loss = -1 * slot_loss  # negative log-likelihood
+         #     else:
+         #         slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
+         #         # Only keep active parts of the loss
+         #         if attention_mask is not None:
+         #             active_loss = attention_mask.view(-1) == 1
+         #             active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
+         #             active_labels = slot_labels_ids.view(-1)[active_loss]
+         #             slot_loss = slot_loss_fct(active_logits, active_labels)
+         #         else:
+         #             slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
+         #     total_loss += self.args.slot_loss_coef * slot_loss
+
+         outputs = ((indicator_logits, metric_type_logits, calc_mode_logits, seasonal_logits, req_form_logits, frequency_logits, activity_logits),) + outputs[2:]  # add hidden states and attentions if present  # , slot_logits
+
+         outputs = (total_loss,) + outputs
+
+         return outputs  # (loss), logits, (hidden_states), (attentions); logits is a tuple of all head logits
module.py ADDED
@@ -0,0 +1,101 @@
+ import torch.nn as nn
+
+
+ # class IntentClassifier(nn.Module):
+ #     def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
+ #         super(IntentClassifier, self).__init__()
+ #         self.dropout = nn.Dropout(dropout_rate)
+ #         self.linear = nn.Linear(input_dim, num_intent_labels)
+
+ #     def forward(self, x):
+ #         x = self.dropout(x)
+ #         return self.linear(x)
+
+
+ class IndicatorClassifier(nn.Module):
+     def __init__(self, input_dim, num_indicator_labels, dropout_rate=0.):
+         super(IndicatorClassifier, self).__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(input_dim, num_indicator_labels)
+
+     def forward(self, x):
+         x = self.dropout(x)
+         return self.linear(x)
+
+
+ class MetricTypeClassifier(nn.Module):
+     def __init__(self, input_dim, num_metric_type_labels, dropout_rate=0.):
+         super(MetricTypeClassifier, self).__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(input_dim, num_metric_type_labels)
+
+     def forward(self, x):
+         x = self.dropout(x)
+         return self.linear(x)
+
+
+ class SeasonalClassifier(nn.Module):
+     def __init__(self, input_dim, num_seasonal_labels, dropout_rate=0.):
+         super(SeasonalClassifier, self).__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(input_dim, num_seasonal_labels)
+
+     def forward(self, x):
+         x = self.dropout(x)
+         return self.linear(x)
+
+
+ class ActivityClassifier(nn.Module):
+     def __init__(self, input_dim, num_activity_labels, dropout_rate=0.):
+         super(ActivityClassifier, self).__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(input_dim, num_activity_labels)
+
+     def forward(self, x):
+         x = self.dropout(x)
+         return self.linear(x)
+
+
+ class FrequencyClassifier(nn.Module):
+     def __init__(self, input_dim, num_frequency_labels, dropout_rate=0.):
+         super(FrequencyClassifier, self).__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(input_dim, num_frequency_labels)
+
+     def forward(self, x):
+         x = self.dropout(x)
+         return self.linear(x)
+
+
+ class CalcModeClassifier(nn.Module):
+     def __init__(self, input_dim, num_calc_mode_labels, dropout_rate=0.):
+         super(CalcModeClassifier, self).__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(input_dim, num_calc_mode_labels)
+
+     def forward(self, x):
+         x = self.dropout(x)
+         return self.linear(x)
+
+
+ class ReqFormClassifier(nn.Module):
+     def __init__(self, input_dim, num_req_form_labels, dropout_rate=0.):
+         super(ReqFormClassifier, self).__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.linear = nn.Linear(input_dim, num_req_form_labels)
+
+     def forward(self, x):
+         x = self.dropout(x)
+         return self.linear(x)
+
+
+ # class ContextModeClassifier(nn.Module):
+ #     def __init__(self, input_dim, num_context_mode_labels, dropout_rate=0.):
+ #         super(ContextModeClassifier, self).__init__()
+ #         self.dropout = nn.Dropout(dropout_rate)
+ #         self.linear = nn.Linear(input_dim, num_context_mode_labels)
+
+ #     def forward(self, x):
+ #         x = self.dropout(x)
+ #         return self.linear(x)
+
+
+ # class SlotClassifier(nn.Module):
+ #     def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
+ #         super(SlotClassifier, self).__init__()
+ #         self.dropout = nn.Dropout(dropout_rate)
+ #         self.linear = nn.Linear(input_dim, num_slot_labels)
+
+ #     def forward(self, x):
+ #         x = self.dropout(x)
+ #         return self.linear(x)