| | --- |
| | license: apache-2.0 |
| | --- |
| | ## few_shot_intent_gpt2_base |
| |
|
| | 这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果. |
| |
|
| | (1)训练在(11000 steps)处 Early Stop。这相当于加载的 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 1 个 epoch 处。 |
| |
|
| | (2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。 |
| |
|
| |
|
| | 最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。 |
| |
|
| | 你可以在此处体验该模型 [qgyd2021/gpt2_chat](https://huggingface.co/spaces/qgyd2021/gpt2_chat)。 |
| |
|
| |
|
| | ### TensorBoard 数集 |
| |
|
| | **Eval Loss** 见下图: |
| |
|
| |  |
| |
|
| |
|
| | **Learning rate** 见下图: |
| |
|
| | 学习率从 2e-4 下降到 1.4e-4。 |
| |
|
| |  |
| |
|
| |
|
| |
|
| |
|
| | ### 讨论 |
| |
|
| | (1)最优解在不到 1 个 epoch 处得到。 |
| |
|
| | * 这可能说明 GPT2 模型大小,相对于任务复杂度来说太小了。 |
| |
|
| | * 模型进入到局部最终解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。 |
| |
|
| | (2)后续应考虑针对 prompt-response 中 response 部分进行训练。 |
| |
|
| | * 即只优化 response 部分的损失以提升识别结果与 prompt 之间的注意力机制。当前的训练有可能只是使模型拟合了 few shot 数据的格式,而并没有拟合到意图识别的目的。 |
| |
|
| | (3)模型使用中的体会。 |
| |
|
| | * 如果在使用过程中,模型生成 response 不在 prompt 中给定的选项,这可能说明模型已经过拟合了。 |
| |
|
| | * 如果模型生成 response 在 prompt 中,但答案不正确,则说明模型已学习到生成的表层模型,而没有学习到意图识别的目的。则建议在此模型基础上进一步优化 response 部分的损失。 |
| |
|
| |
|
| |
|
| | ### 其它 |
| |
|
| | 训练时加载数据集的代码 |
| | ```python |
| | #!/usr/bin/python3 |
| | # -*- coding: utf-8 -*- |
| | import argparse |
| | import json |
| | |
| | from datasets import load_dataset |
| | from datasets.download.download_manager import DownloadMode |
| | from tqdm import tqdm |
| | |
| | from project_settings import project_path |
| | |
| | |
| | def get_args(): |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str) |
| | parser.add_argument("--dataset_split", default=None, type=str) |
| | parser.add_argument( |
| | "--dataset_cache_dir", |
| | default=(project_path / "hub_datasets").as_posix(), |
| | type=str |
| | ) |
| | |
| | parser.add_argument("--num_epochs", default=1, type=int) |
| | |
| | parser.add_argument("--train_subset", default="train.jsonl", type=str) |
| | parser.add_argument("--valid_subset", default="valid.jsonl", type=str) |
| | args = parser.parse_args() |
| | return args |
| | |
| | |
| | def main(): |
| | args = get_args() |
| | |
| | name_list = [ |
| | # "a_intent_prompt", |
| | "amazon_massive_intent_en_us_prompt", |
| | "amazon_massive_intent_zh_cn_prompt", |
| | "atis_intents_prompt", |
| | "banking77_prompt", |
| | "bi_text11_prompt", |
| | "bi_text27_prompt", |
| | # "book6_prompt", |
| | "carer_prompt", |
| | "chatbots_prompt", |
| | "chinese_news_title_prompt", |
| | "cmid_4class_prompt", |
| | "cmid_36class_prompt", |
| | "coig_cqia_prompt", |
| | "conv_intent_prompt", |
| | "crosswoz_prompt", |
| | "dmslots_prompt", |
| | "dnd_style_intents_prompt", |
| | "emo2019_prompt", |
| | "finance21_prompt", |
| | "ide_intent_prompt", |
| | "intent_classification_prompt", |
| | "jarvis_intent_prompt", |
| | "mobile_assistant_prompt", |
| | "mtop_intent_prompt", |
| | "out_of_scope_prompt", |
| | "ri_sawoz_domain_prompt", |
| | "ri_sawoz_general_prompt", |
| | "small_talk_prompt", |
| | "smp2017_task1_prompt", |
| | "smp2019_task1_domain_prompt", |
| | "smp2019_task1_intent_prompt", |
| | # "snips_built_in_intents_prompt", |
| | "star_wars_prompt", |
| | "suicide_intent_prompt", |
| | "snips_built_in_intents_prompt", |
| | "telemarketing_intent_cn_prompt", |
| | "telemarketing_intent_en_prompt", |
| | "vira_intents_prompt", |
| | ] |
| | |
| | with open(args.train_subset, "w", encoding="utf-8") as f: |
| | for _ in range(args.num_epochs): |
| | for name in name_list: |
| | print(name) |
| | dataset = load_dataset( |
| | path=args.dataset_path, |
| | name=name, |
| | split="train", |
| | cache_dir=args.dataset_cache_dir, |
| | download_mode=DownloadMode.FORCE_REDOWNLOAD, |
| | ignore_verifications=True |
| | ) |
| | for sample in tqdm(dataset): |
| | row = json.dumps(sample, ensure_ascii=False) |
| | f.write("{}\n".format(row)) |
| | |
| | with open(args.valid_subset, "w", encoding="utf-8") as f: |
| | for _ in range(args.num_epochs): |
| | for name in name_list: |
| | print(name) |
| | dataset = load_dataset( |
| | path=args.dataset_path, |
| | name=name, |
| | split="test", |
| | cache_dir=args.dataset_cache_dir, |
| | download_mode=DownloadMode.FORCE_REDOWNLOAD, |
| | ignore_verifications=True |
| | ) |
| | for sample in tqdm(dataset): |
| | row = json.dumps(sample, ensure_ascii=False) |
| | f.write("{}\n".format(row)) |
| | |
| | return |
| | |
| | |
| | if __name__ == '__main__': |
| | main() |
| | |
| | ``` |
| |
|
| |
|