Upload PIBot Joint BERT model with full reproducibility
Browse files- README.md +52 -97
- __init__.py +1 -0
- __pycache__/modeling_jointbert.cpython-312.pyc +0 -0
- __pycache__/module.cpython-312.pyc +0 -0
- modeling_jointbert.py +125 -0
- module.py +101 -0
README.md
CHANGED
|
@@ -1,120 +1,75 @@
|
|
| 1 |
-
-
|
| 2 |
-
language: es
|
| 3 |
-
tags:
|
| 4 |
-
- intent-classification
|
| 5 |
-
- slot-filling
|
| 6 |
-
- joint-bert
|
| 7 |
-
- spanish
|
| 8 |
-
- economics
|
| 9 |
-
- chile
|
| 10 |
-
license: mit
|
| 11 |
-
---
|
| 12 |
|
| 13 |
-
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
- **Arquitectura**: BERT (dccuchile/bert-base-spanish-wwm-cased)
|
| 26 |
-
- **Idioma**: Español
|
| 27 |
-
- **Task**: pibimacec
|
| 28 |
-
- **Épocas de entrenamiento**: 20.0
|
| 29 |
|
| 30 |
## Uso
|
| 31 |
|
| 32 |
-
###
|
| 33 |
-
|
| 34 |
-
```bash
|
| 35 |
-
pip install torch transformers pytorch-crf
|
| 36 |
-
```
|
| 37 |
-
|
| 38 |
-
### Ejemplo de Uso
|
| 39 |
|
| 40 |
```python
|
| 41 |
-
from
|
| 42 |
-
from modeling_jointbert import JointBERT
|
| 43 |
-
import torch
|
| 44 |
-
|
| 45 |
-
# Cargar modelo y tokenizer
|
| 46 |
-
model_dir = "smenaaliaga/pibot-jointbert-beto" # Cambiar por tu repo
|
| 47 |
-
tokenizer = BertTokenizer.from_pretrained(model_dir)
|
| 48 |
-
|
| 49 |
-
# Cargar labels
|
| 50 |
-
intent_labels = ["methodology", "value"]
|
| 51 |
-
slot_labels = ["O", "B-indicator", "I-indicator", "B-period", "I-period", ...]
|
| 52 |
-
|
| 53 |
-
# Inicializar modelo (requiere código personalizado de JointBERT)
|
| 54 |
-
model = JointBERT.from_pretrained(
|
| 55 |
-
model_dir,
|
| 56 |
-
intent_label_lst=intent_labels,
|
| 57 |
-
slot_label_lst=slot_labels
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
# Predecir
|
| 61 |
-
text = "cual fue el imacec de agosto 2024"
|
| 62 |
-
# ... (código de predicción)
|
| 63 |
-
```
|
| 64 |
-
|
| 65 |
-
## Datos de Entrenamiento
|
| 66 |
-
|
| 67 |
-
El modelo fue entrenado en un dataset especializado de consultas sobre:
|
| 68 |
-
- **IMACEC**: Indicador Mensual de Actividad Económica
|
| 69 |
-
- **PIB**: Producto Interno Bruto
|
| 70 |
-
- Sectores económicos (minería, comercio, industria, etc.)
|
| 71 |
-
- Períodos temporales (meses, trimestres, años)
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
- `methodology`: Consultas sobre metodología/definiciones
|
| 78 |
|
| 79 |
-
|
| 80 |
-
- `indicator`: Indicador económico (IMACEC, PIB)
|
| 81 |
-
- `period`: Período temporal
|
| 82 |
-
- `measure_type`: Tipo de medida (variación, índice, etc.)
|
| 83 |
-
- `sector`: Sector económico
|
| 84 |
-
- `series_type`: Tipo de serie (original, desestacionalizada, tendencia-ciclo)
|
| 85 |
|
| 86 |
-
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
|
| 96 |
-
- Mejor rendimiento en consultas cortas-medianas (< 50 tokens)
|
| 97 |
-
- Puede tener dificultades con consultas muy ambiguas o fuera de dominio
|
| 98 |
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
-
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
year = {2025},
|
| 108 |
-
publisher = {Hugging Face},
|
| 109 |
-
howpublished = {\url{https://huggingface.co/smenaaliaga/pibot-jointbert-beto}}
|
| 110 |
-
}
|
| 111 |
-
```
|
| 112 |
|
| 113 |
## Licencia
|
| 114 |
|
| 115 |
-
|
| 116 |
|
| 117 |
-
##
|
| 118 |
|
| 119 |
-
|
| 120 |
-
- Implementación base: [JointBERT](https://github.com/monologg/JointBERT)
|
|
|
|
| 1 |
+
# PIBot Joint BERT - 7 Heads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
Modelo Joint BERT para clasificación multi-cabeza de consultas sobre indicadores económicos.
|
| 4 |
|
| 5 |
+
## Cabezas de Clasificación
|
| 6 |
|
| 7 |
+
El modelo predice simultáneamente 7 atributos:
|
| 8 |
+
- **indicator**: Indicador económico (ej: imacec, pib)
|
| 9 |
+
- **metric_type**: Tipo de métrica (ej: index, level)
|
| 10 |
+
- **calc_mode**: Modo de cálculo (ej: yoy, mom)
|
| 11 |
+
- **seasonal**: Ajuste estacional (ej: sa, nsa)
|
| 12 |
+
- **req_form**: Forma de solicitud (ej: latest, historical)
|
| 13 |
+
- **frequency**: Frecuencia (ej: m, q, a)
|
| 14 |
+
- **activity**: Actividad/Sector (ej: total, agriculture)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
## Uso
|
| 17 |
|
| 18 |
+
### Opción 1: Local (Recomendado para máxima compatibilidad)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
```python
|
| 21 |
+
from load_local_model import PIBotPredictor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
predictor = PIBotPredictor("path/to/model")
|
| 24 |
+
result = predictor.predict("cual fue el pib del último trimestre")
|
| 25 |
+
print(result)
|
| 26 |
+
```
|
|
|
|
| 27 |
|
| 28 |
+
### Opción 2: Desde Hugging Face Hub
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
```python
|
| 31 |
+
from load_local_model import PIBotPredictor
|
| 32 |
|
| 33 |
+
# Descargar y usar
|
| 34 |
+
predictor = PIBotPredictor("username/pibot-jointbert")
|
| 35 |
+
result = predictor.predict("cual fue el imacec")
|
| 36 |
+
print(result)
|
| 37 |
+
```
|
| 38 |
|
| 39 |
+
### Línea de comandos
|
| 40 |
|
| 41 |
+
```bash
|
| 42 |
+
python load_local_model.py --model_dir path/to/model --text "tu consulta"
|
| 43 |
+
```
|
| 44 |
|
| 45 |
+
## Estructura del Checkpoint
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
```
|
| 48 |
+
model_dir/
|
| 49 |
+
├── model.safetensors # Pesos del modelo
|
| 50 |
+
├── config.json # Configuración de BERT
|
| 51 |
+
├── training_args.bin # Argumentos de entrenamiento
|
| 52 |
+
├── tokenizer.json # Tokenizer rápido
|
| 53 |
+
├── tokenizer_config.json
|
| 54 |
+
├── vocab.txt
|
| 55 |
+
├── modeling_jointbert.py # Arquitectura custom
|
| 56 |
+
├── module.py # Clasificadores custom
|
| 57 |
+
├── __init__.py
|
| 58 |
+
├── *_label.txt # Labels para cada cabeza (7 archivos)
|
| 59 |
+
└── README.md
|
| 60 |
+
```
|
| 61 |
|
| 62 |
+
## Detalles Técnicos
|
| 63 |
|
| 64 |
+
- **Base Model**: dccuchile/bert-base-spanish-wwm-cased (BETO)
|
| 65 |
+
- **Framework**: PyTorch + Transformers
|
| 66 |
+
- **Formato de pesos**: SafeTensors
|
| 67 |
+
- **Tokenizer**: AutoTokenizer con use_fast=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
## Licencia
|
| 70 |
|
| 71 |
+
[Especificar licencia]
|
| 72 |
|
| 73 |
+
## Autor
|
| 74 |
|
| 75 |
+
[Tu información]
|
|
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .modeling_jointbert import JointBERT
|
__pycache__/modeling_jointbert.cpython-312.pyc
ADDED
|
Binary file (6.25 kB). View file
|
|
|
__pycache__/module.cpython-312.pyc
ADDED
|
Binary file (5.58 kB). View file
|
|
|
modeling_jointbert.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Use:
|
| 3 |
+
python load_local_model.py --model_dir model_out/pibot_model_v3 --text "cual fue el pib del ultimo trimestre"
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import BertPreTrainedModel, BertModel, BertConfig
|
| 9 |
+
from torchcrf import CRF
|
| 10 |
+
from module import IndicatorClassifier, MetricTypeClassifier, CalcModeClassifier, SeasonalClassifier, ReqFormClassifier, FrequencyClassifier, ActivityClassifier #, SlotClassifier
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class JointBERT(BertPreTrainedModel):
|
| 14 |
+
def __init__(self, config, args, indicator_label_lst, metric_type_label_lst, calc_mode_label_lst,
|
| 15 |
+
seasonal_label_lst, req_form_label_lst, frequency_label_lst, activity_label_lst): #, slot_label_lst):
|
| 16 |
+
super(JointBERT, self).__init__(config)
|
| 17 |
+
self.args = args
|
| 18 |
+
|
| 19 |
+
self.num_indicator_labels = len(indicator_label_lst)
|
| 20 |
+
self.num_metric_type_labels = len(metric_type_label_lst)
|
| 21 |
+
self.num_calc_mode_labels = len(calc_mode_label_lst)
|
| 22 |
+
self.num_seasonal_labels = len(seasonal_label_lst)
|
| 23 |
+
self.num_req_form_labels = len(req_form_label_lst)
|
| 24 |
+
self.num_frequency_labels = len(frequency_label_lst)
|
| 25 |
+
self.num_activity_labels = len(activity_label_lst)
|
| 26 |
+
# self.num_slot_labels = len(slot_label_lst)
|
| 27 |
+
|
| 28 |
+
self.bert = BertModel(config=config) # Load pretrained bert
|
| 29 |
+
|
| 30 |
+
self.indicator_classifier = IndicatorClassifier(config.hidden_size, self.num_indicator_labels, args.dropout_rate)
|
| 31 |
+
self.metric_type_classifier = MetricTypeClassifier(config.hidden_size, self.num_metric_type_labels, args.dropout_rate)
|
| 32 |
+
self.calc_mode_classifier = CalcModeClassifier(config.hidden_size, self.num_calc_mode_labels, args.dropout_rate)
|
| 33 |
+
self.seasonal_classifier = SeasonalClassifier(config.hidden_size, self.num_seasonal_labels, args.dropout_rate)
|
| 34 |
+
self.req_form_classifier = ReqFormClassifier(config.hidden_size, self.num_req_form_labels, args.dropout_rate)
|
| 35 |
+
self.frequency_classifier = FrequencyClassifier(config.hidden_size, self.num_frequency_labels, args.dropout_rate)
|
| 36 |
+
self.activity_classifier = ActivityClassifier(config.hidden_size, self.num_activity_labels, args.dropout_rate)
|
| 37 |
+
# self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
|
| 38 |
+
|
| 39 |
+
# if args.use_crf:
|
| 40 |
+
# self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
|
| 41 |
+
|
| 42 |
+
def forward(self, input_ids, attention_mask, token_type_ids=None, indicator_label_ids=None, metric_type_label_ids=None,
|
| 43 |
+
calc_mode_label_ids=None, seasonal_label_ids=None, req_form_label_ids=None, frequency_label_ids=None, activity_label_ids=None): #, slot_labels_ids=None):
|
| 44 |
+
outputs = self.bert(input_ids, attention_mask=attention_mask,
|
| 45 |
+
token_type_ids=token_type_ids) # sequence_output, pooled_output, (hidden_states), (attentions)
|
| 46 |
+
sequence_output = outputs[0]
|
| 47 |
+
pooled_output = outputs[1] # [CLS]
|
| 48 |
+
|
| 49 |
+
indicator_logits = self.indicator_classifier(pooled_output)
|
| 50 |
+
metric_type_logits = self.metric_type_classifier(pooled_output)
|
| 51 |
+
calc_mode_logits = self.calc_mode_classifier(pooled_output)
|
| 52 |
+
seasonal_logits = self.seasonal_classifier(pooled_output)
|
| 53 |
+
req_form_logits = self.req_form_classifier(pooled_output)
|
| 54 |
+
frequency_logits = self.frequency_classifier(pooled_output)
|
| 55 |
+
activity_logits = self.activity_classifier(pooled_output)
|
| 56 |
+
# slot_logits = self.slot_classifier(sequence_output)
|
| 57 |
+
|
| 58 |
+
total_loss = 0
|
| 59 |
+
# 1. Indicator CrossEntropy
|
| 60 |
+
if indicator_label_ids is not None:
|
| 61 |
+
indicator_loss_fct = nn.CrossEntropyLoss()
|
| 62 |
+
indicator_loss = indicator_loss_fct(indicator_logits.view(-1, self.num_indicator_labels), indicator_label_ids.view(-1))
|
| 63 |
+
total_loss += indicator_loss
|
| 64 |
+
|
| 65 |
+
# 2. Metric Type CrossEntropy
|
| 66 |
+
if metric_type_label_ids is not None:
|
| 67 |
+
metric_type_loss_fct = nn.CrossEntropyLoss()
|
| 68 |
+
metric_type_loss = metric_type_loss_fct(metric_type_logits.view(-1, self.num_metric_type_labels), metric_type_label_ids.view(-1))
|
| 69 |
+
total_loss += metric_type_loss
|
| 70 |
+
|
| 71 |
+
# 3. Calc Mode CrossEntropy
|
| 72 |
+
if calc_mode_label_ids is not None:
|
| 73 |
+
calc_mode_loss_fct = nn.CrossEntropyLoss()
|
| 74 |
+
calc_mode_loss = calc_mode_loss_fct(calc_mode_logits.view(-1, self.num_calc_mode_labels), calc_mode_label_ids.view(-1))
|
| 75 |
+
total_loss += calc_mode_loss
|
| 76 |
+
|
| 77 |
+
# 4. Seasonal CrossEntropy
|
| 78 |
+
if seasonal_label_ids is not None:
|
| 79 |
+
seasonal_loss_fct = nn.CrossEntropyLoss()
|
| 80 |
+
seasonal_loss = seasonal_loss_fct(seasonal_logits.view(-1, self.num_seasonal_labels), seasonal_label_ids.view(-1))
|
| 81 |
+
total_loss += seasonal_loss
|
| 82 |
+
|
| 83 |
+
# 5. Req Form CrossEntropy
|
| 84 |
+
if req_form_label_ids is not None:
|
| 85 |
+
req_form_loss_fct = nn.CrossEntropyLoss()
|
| 86 |
+
req_form_loss = req_form_loss_fct(req_form_logits.view(-1, self.num_req_form_labels), req_form_label_ids.view(-1))
|
| 87 |
+
total_loss += req_form_loss
|
| 88 |
+
|
| 89 |
+
# 6. Frequency CrossEntropy
|
| 90 |
+
if frequency_label_ids is not None:
|
| 91 |
+
frequency_loss_fct = nn.CrossEntropyLoss()
|
| 92 |
+
frequency_loss = frequency_loss_fct(frequency_logits.view(-1, self.num_frequency_labels), frequency_label_ids.view(-1))
|
| 93 |
+
total_loss += frequency_loss
|
| 94 |
+
|
| 95 |
+
# 7. Activity CrossEntropy
|
| 96 |
+
if activity_label_ids is not None:
|
| 97 |
+
activity_loss_fct = nn.CrossEntropyLoss()
|
| 98 |
+
activity_loss = activity_loss_fct(activity_logits.view(-1, self.num_activity_labels), activity_label_ids.view(-1))
|
| 99 |
+
total_loss += activity_loss
|
| 100 |
+
|
| 101 |
+
# # 8. Slot Softmax
|
| 102 |
+
# if slot_labels_ids is not None and self.args.slot_loss_coef != 0:
|
| 103 |
+
# if self.args.use_crf:
|
| 104 |
+
# # CRF doesn't handle ignore_index (-100), so we replace it with PAD (0)
|
| 105 |
+
# slot_labels_ids_crf = slot_labels_ids.clone()
|
| 106 |
+
# slot_labels_ids_crf[slot_labels_ids_crf == self.args.ignore_index] = 0
|
| 107 |
+
# slot_loss = self.crf(slot_logits, slot_labels_ids_crf, mask=attention_mask.bool(), reduction='mean')
|
| 108 |
+
# slot_loss = -1 * slot_loss # negative log-likelihood
|
| 109 |
+
# else:
|
| 110 |
+
# slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
|
| 111 |
+
# # Only keep active parts of the loss
|
| 112 |
+
# if attention_mask is not None:
|
| 113 |
+
# active_loss = attention_mask.view(-1) == 1
|
| 114 |
+
# active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
|
| 115 |
+
# active_labels = slot_labels_ids.view(-1)[active_loss]
|
| 116 |
+
# slot_loss = slot_loss_fct(active_logits, active_labels)
|
| 117 |
+
# else:
|
| 118 |
+
# slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
|
| 119 |
+
# total_loss += self.args.slot_loss_coef * slot_loss
|
| 120 |
+
|
| 121 |
+
outputs = ((indicator_logits, metric_type_logits, calc_mode_logits, seasonal_logits, req_form_logits, frequency_logits, activity_logits),) + outputs[2:] # add hidden states and attention if they are here #, slot_logits
|
| 122 |
+
|
| 123 |
+
outputs = (total_loss,) + outputs
|
| 124 |
+
|
| 125 |
+
return outputs # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of all classifier logits
|
module.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
# class IntentClassifier(nn.Module):
|
| 4 |
+
# def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
|
| 5 |
+
# super(IntentClassifier, self).__init__()
|
| 6 |
+
# self.dropout = nn.Dropout(dropout_rate)
|
| 7 |
+
# self.linear = nn.Linear(input_dim, num_intent_labels)
|
| 8 |
+
|
| 9 |
+
# def forward(self, x):
|
| 10 |
+
# x = self.dropout(x)
|
| 11 |
+
# return self.linear(x)
|
| 12 |
+
|
| 13 |
+
class IndicatorClassifier(nn.Module):
|
| 14 |
+
def __init__(self, input_dim, num_indicator_labels, dropout_rate=0.):
|
| 15 |
+
super(IndicatorClassifier, self).__init__()
|
| 16 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 17 |
+
self.linear = nn.Linear(input_dim, num_indicator_labels)
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
x = self.dropout(x)
|
| 21 |
+
return self.linear(x)
|
| 22 |
+
|
| 23 |
+
class MetricTypeClassifier(nn.Module):
|
| 24 |
+
def __init__(self, input_dim, num_metric_type_labels, dropout_rate=0.):
|
| 25 |
+
super(MetricTypeClassifier, self).__init__()
|
| 26 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 27 |
+
self.linear = nn.Linear(input_dim, num_metric_type_labels)
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
x = self.dropout(x)
|
| 31 |
+
return self.linear(x)
|
| 32 |
+
|
| 33 |
+
class SeasonalClassifier(nn.Module):
|
| 34 |
+
def __init__(self, input_dim, num_seasonal_labels, dropout_rate=0.):
|
| 35 |
+
super(SeasonalClassifier, self).__init__()
|
| 36 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 37 |
+
self.linear = nn.Linear(input_dim, num_seasonal_labels)
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
x = self.dropout(x)
|
| 41 |
+
return self.linear(x)
|
| 42 |
+
|
| 43 |
+
class ActivityClassifier(nn.Module):
|
| 44 |
+
def __init__(self, input_dim, num_activity_labels, dropout_rate=0.):
|
| 45 |
+
super(ActivityClassifier, self).__init__()
|
| 46 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 47 |
+
self.linear = nn.Linear(input_dim, num_activity_labels)
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
x = self.dropout(x)
|
| 51 |
+
return self.linear(x)
|
| 52 |
+
|
| 53 |
+
class FrequencyClassifier(nn.Module):
|
| 54 |
+
def __init__(self, input_dim, num_frequency_labels, dropout_rate=0.):
|
| 55 |
+
super(FrequencyClassifier, self).__init__()
|
| 56 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 57 |
+
self.linear = nn.Linear(input_dim, num_frequency_labels)
|
| 58 |
+
|
| 59 |
+
def forward(self, x):
|
| 60 |
+
x = self.dropout(x)
|
| 61 |
+
return self.linear(x)
|
| 62 |
+
|
| 63 |
+
class CalcModeClassifier(nn.Module):
|
| 64 |
+
def __init__(self, input_dim, num_calc_mode_labels, dropout_rate=0.):
|
| 65 |
+
super(CalcModeClassifier, self).__init__()
|
| 66 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 67 |
+
self.linear = nn.Linear(input_dim, num_calc_mode_labels)
|
| 68 |
+
|
| 69 |
+
def forward(self, x):
|
| 70 |
+
x = self.dropout(x)
|
| 71 |
+
return self.linear(x)
|
| 72 |
+
|
| 73 |
+
class ReqFormClassifier(nn.Module):
|
| 74 |
+
def __init__(self, input_dim, num_req_form_labels, dropout_rate=0.):
|
| 75 |
+
super(ReqFormClassifier, self).__init__()
|
| 76 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 77 |
+
self.linear = nn.Linear(input_dim, num_req_form_labels)
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
x = self.dropout(x)
|
| 81 |
+
return self.linear(x)
|
| 82 |
+
|
| 83 |
+
# class ContextModeClassifier(nn.Module):
|
| 84 |
+
# def __init__(self, input_dim, num_context_mode_labels, dropout_rate=0.):
|
| 85 |
+
# super(ContextModeClassifier, self).__init__()
|
| 86 |
+
# self.dropout = nn.Dropout(dropout_rate)
|
| 87 |
+
# self.linear = nn.Linear(input_dim, num_context_mode_labels)
|
| 88 |
+
|
| 89 |
+
# def forward(self, x):
|
| 90 |
+
# x = self.dropout(x)
|
| 91 |
+
# return self.linear(x)
|
| 92 |
+
|
| 93 |
+
# class SlotClassifier(nn.Module):
|
| 94 |
+
# def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
|
| 95 |
+
# super(SlotClassifier, self).__init__()
|
| 96 |
+
# self.dropout = nn.Dropout(dropout_rate)
|
| 97 |
+
# self.linear = nn.Linear(input_dim, num_slot_labels)
|
| 98 |
+
|
| 99 |
+
# def forward(self, x):
|
| 100 |
+
# x = self.dropout(x)
|
| 101 |
+
# return self.linear(x)
|