---
language:
- en
license: apache-2.0
tags:
- dialogue state tracking
- task-oriented dialog
---

# roberta-base-trippy-dst-multiwoz21

This is a TripPy model trained on [MultiWOZ 2.1](https://github.com/budzianowski/multiwoz) for use in [ConvLab-3](https://github.com/ConvLab/ConvLab-3).
This model predicts informable slots, requestable slots, general actions, and domain indicator slots.
Expected joint goal accuracy on MultiWOZ 2.1 is in the range of 55-56%.

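
Joint goal accuracy counts a turn as correct only if the predicted dialogue state matches the gold state on every slot, so a single wrong slot value makes the whole turn wrong. The sketch below shows the metric. It is not the official TripPy or ConvLab-3 evaluation code, which may also normalize values or accept value variants. The slot names (`hotel-area`, `hotel-stars`) and the convention that `"none"` marks an unfilled slot are illustrative assumptions.

```python
from typing import Dict, List

# Hypothetical representation: a dialogue state maps slot names to values.
DialogueState = Dict[str, str]


def joint_goal_accuracy(predicted: List[DialogueState], gold: List[DialogueState]) -> float:
    """Fraction of turns whose predicted state equals the gold state exactly."""
    if len(predicted) != len(gold):
        raise ValueError("predicted and gold must cover the same turns")
    if not gold:
        return 0.0
    correct = 0
    for pred_state, gold_state in zip(predicted, gold):
        # Treat "none" as unfilled so {"hotel-area": "none"} equals {}.
        pred_filled = {s: v for s, v in pred_state.items() if v != "none"}
        gold_filled = {s: v for s, v in gold_state.items() if v != "none"}
        if pred_filled == gold_filled:
            correct += 1
    return correct / len(gold)


gold = [
    {"hotel-area": "north"},
    {"hotel-area": "north", "hotel-stars": "4"},
]
pred = [
    {"hotel-area": "north"},
    {"hotel-area": "north", "hotel-stars": "3"},  # one wrong slot: turn counts as wrong
]
print(joint_goal_accuracy(pred, gold))
```

Here only the first of the two turns is fully correct, so the score is 0.5.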
For information about TripPy DST, refer to [TripPy: A Triple Copy Strategy for Value Independent Neural Dialog State Tracking](https://aclanthology.org/2020.sigdial-1.4/).

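
In brief, the paper describes a per-slot slot gate that assigns each slot one of the classes `none`, `dontcare`, `span`, `inform` or `refer` at every turn. It then fills the slot from one of three copy sources: a span of the user utterance, the system inform memory, or another slot already in the dialogue state. The sketch below shows only that update rule, under simplifying assumptions. The function `update_state` and its argument names are hypothetical. Span values arrive as already-extracted strings rather than predicted start/end positions. `refer` copies from the previous turn's state.

```python
from typing import Dict

# Slot-gate classes as described in the TripPy paper (simplified).
SLOT_GATE_CLASSES = ("none", "dontcare", "span", "inform", "refer")


def update_state(
    prev_state: Dict[str, str],
    slot_classes: Dict[str, str],
    span_values: Dict[str, str],
    system_informs: Dict[str, str],
    refer_targets: Dict[str, str],
) -> Dict[str, str]:
    """Apply one turn of simplified TripPy-style slot-gate decisions."""
    state = dict(prev_state)
    for slot, cls in slot_classes.items():
        if cls not in SLOT_GATE_CLASSES:
            raise ValueError(f"unknown slot-gate class: {cls}")
        if cls == "none":
            continue  # slot not mentioned: keep the previous value, if any
        if cls == "dontcare":
            state[slot] = "dontcare"
        elif cls == "span":
            # Copy 1: value extracted from the user utterance.
            state[slot] = span_values[slot]
        elif cls == "inform":
            # Copy 2: value the system offered earlier (system inform memory).
            state[slot] = system_informs[slot]
        else:  # "refer"
            # Copy 3: value of another slot already in the dialogue state.
            state[slot] = prev_state[refer_targets[slot]]
    return state


# Illustrative turn: "Book a taxi from the hotel, leaving at 17:15."
prev = {"hotel-name": "acorn guest house", "hotel-area": "north"}
new = update_state(
    prev,
    slot_classes={"hotel-area": "none", "taxi-departure": "refer", "taxi-leaveat": "span"},
    span_values={"taxi-leaveat": "17:15"},
    system_informs={},
    refer_targets={"taxi-departure": "hotel-name"},
)
print(new)
```

Because the copy sources are values seen in the dialogue rather than entries of a fixed candidate list, the tracker is not tied to a closed value vocabulary. This is the "value independent" property named in the paper's title.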
The training and evaluation code is available at the official [TripPy repository](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public).

## Training procedure

The model was trained on MultiWOZ 2.1 data via supervised learning using the [TripPy codebase](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public).
MultiWOZ 2.1 data was loaded via ConvLab-3's unified data format dataloader.
The pre-trained encoder is [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta) (base).
The encoder was fine-tuned and the DST-specific classification heads were trained for 10 epochs.

### Training hyperparameters

```
python3 run_dst.py \
--task_name="unified" \
--model_type="roberta" \
--model_name_or_path="roberta-base" \
--dataset_config=dataset_config/unified_multiwoz21.json \
--do_lower_case \
--learning_rate=1e-4 \
--num_train_epochs=10 \
--max_seq_length=180 \
--per_gpu_train_batch_size=24 \
--per_gpu_eval_batch_size=32 \
--output_dir=results \
--save_epochs=2 \
--eval_all_checkpoints \
--warmup_proportion=0.1 \
--adam_epsilon=1e-6 \
--weight_decay=0.01 \
--fp16 \
--do_train \
--predict_type=dummy \
--seed=42
```

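
The flags `--learning_rate`, `--warmup_proportion`, `--num_train_epochs` and `--per_gpu_train_batch_size` together shape the learning-rate schedule. The sketch below assumes a linear warmup over the first 10% of optimizer steps followed by linear decay to zero, with no gradient accumulation. This mirrors the common Hugging Face `get_linear_schedule_with_warmup` pattern. It is an assumption about how `run_dst.py` uses these flags, not a description taken from the TripPy code. The helper names and the example count of 10,000 training instances are hypothetical. The example count is not the real MultiWOZ 2.1 training-set size.

```python
import math


def total_training_steps(num_train_examples: int, per_gpu_batch_size: int,
                         n_gpu: int, num_epochs: int) -> int:
    """Optimizer steps for the whole run, assuming no gradient accumulation."""
    batch_size = per_gpu_batch_size * max(1, n_gpu)
    steps_per_epoch = math.ceil(num_train_examples / batch_size)
    return steps_per_epoch * num_epochs


def linear_warmup_lr(step: int, total_steps: int, peak_lr: float = 1e-4,
                     warmup_proportion: float = 0.1) -> float:
    """Linear warmup from 0 to peak_lr, then linear decay to 0 at total_steps."""
    warmup_steps = int(total_steps * warmup_proportion)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Hypothetical run: 10,000 training instances on a single GPU.
total = total_training_steps(num_train_examples=10_000, per_gpu_batch_size=24,
                             n_gpu=1, num_epochs=10)
for step in (0, total // 20, int(total * 0.1), total // 2, total):
    print(step, f"{linear_warmup_lr(step, total):.2e}")
```

With these made-up numbers the run has 4,170 steps, and the learning rate peaks at 1e-4 after 417 warmup steps. The schedule scales with the actual dataset size and GPU count. Settings such as `--adam_epsilon`, `--weight_decay` and `--fp16` affect the optimizer and numeric precision, not this schedule.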