| |
| |
| import argparse |
| import os |
| import platform |
| import time |
|
|
| from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer |
| from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer |
| from allennlp.data.vocabulary import Vocabulary |
| from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder |
| from allennlp.modules.token_embedders.embedding import Embedding |
| from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder |
| from allennlp.models.archival import archive_model, load_archive |
| from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder |
| from allennlp.predictors.predictor import Predictor |
| from allennlp.predictors.text_classifier import TextClassifierPredictor |
| import gradio as gr |
| import torch |
|
|
| from project_settings import project_path |
| from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier |
| from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader |
|
|
|
|
def get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``cn_archive_file``, ``en_archive_file``,
        ``jp_archive_file``, ``vi_archive_file`` (each defaulting to a path
        under ``project_path / "trained_models"``) and ``predictor_name``
        (default ``"text_classifier"``).
    """
    # (flag, path relative to project_path) for every per-language model archive.
    archive_options = (
        ("--cn_archive_file", "trained_models/telemarketing_intent_classification_cn"),
        ("--en_archive_file", "trained_models/telemarketing_intent_classification_en"),
        ("--jp_archive_file", "trained_models/telemarketing_intent_classification_jp"),
        ("--vi_archive_file", "trained_models/telemarketing_intent_classification_vi"),
    )

    parser = argparse.ArgumentParser()
    for flag, relative_path in archive_options:
        default_path = (project_path / relative_path).as_posix()
        parser.add_argument(flag, default=default_path, type=str)
    parser.add_argument("--predictor_name", default="text_classifier", type=str)

    return parser.parse_args()
|
|
|
|
# Markdown shown above the Gradio interface (user-facing text, kept verbatim).
_DESCRIPTION = """
电销场景意图识别.
语言: 汉语, 英语, 日语, 越南语.
数据集是私有的.

model: selfattention-cnn
dataset: telemarketing_intent (https://huggingface.co/datasets/qgyd2021/telemarketing_intent)

accuracy:
chinese: 0.8002
english: 0.7011
japanese: 0.8154
vietnamese: 0.8168

"""

# Clickable [text, language] examples shown in the interface.
_EXAMPLES = [
    ["你找谁", "chinese"],
    ["你是谁啊", "chinese"],
    ["不好意思我现在很忙", "chinese"],
    ["对不起, 不需要哈", "chinese"],
    ["u have got the wrong number", "english"],
    ["sure, thank a lot", "english"],
    ["please leave your message for 95688496", "english"],
    ["yes well", "english"],
    ["失礼の", "japanese"],
    ["ビートいう発表の後に、お名前とご用件をお話ください。", "japanese"],
    ["わかんない。", "japanese"],
    ["に出ることができません", "japanese"],
    ["À không phải em nha.", "vietnamese"],
    ["Dạ nhầm số rồi ạ?", "vietnamese"],
    ["Ừ, cảm ơn em nhá.", "vietnamese"],
    ["Không, chị không có tiền.", "vietnamese"],
]


def _load_predictors(args):
    """Load one AllenNLP predictor per supported language.

    Args:
        args: parsed CLI namespace from ``get_args()``; reads the four
            ``*_archive_file`` paths and ``predictor_name``.

    Returns:
        dict mapping the language name used in the UI ("chinese", "english",
        "japanese", "vietnamese") to its predictor. Archives are loaded in
        that order.
    """
    archive_files = {
        "chinese": args.cn_archive_file,
        "english": args.en_archive_file,
        "japanese": args.jp_archive_file,
        "vietnamese": args.vi_archive_file,
    }
    return {
        language: Predictor.from_archive(
            load_archive(archive_file=archive_file),
            predictor_name=args.predictor_name,
        )
        for language, archive_file in archive_files.items()
    }


def main():
    """Load the per-language intent classifiers and serve them in a Gradio UI."""
    args = get_args()

    predictor_map = _load_predictors(args)
    # Any language not present in predictor_map falls back to the Chinese model.
    default_predictor = predictor_map["chinese"]

    def classify(text: str, language: str):
        """Classify ``text`` with the predictor selected by ``language``.

        Returns:
            (label, prob): the first entries of the decoded ``label`` and
            ``prob`` outputs, with ``prob`` rounded to 4 decimal places.
        """
        predictor = predictor_map.get(language, default_predictor)
        outputs = predictor.predict_json({"sentence": text})
        # NOTE(review): uses the private `_model` attribute and decodes the
        # prediction dict again; presumably the model's decode() is what adds
        # the list-valued `label`/`prob` keys — confirm against the model class.
        outputs = predictor._model.decode(outputs)
        label = outputs["label"][0]
        prob = round(outputs["prob"][0], 4)
        return label, prob

    demo = gr.Interface(
        fn=classify,
        inputs=[
            gr.Text(label="text"),
            gr.Dropdown(choices=sorted(predictor_map), label="language"),
        ],
        outputs=[gr.Text(label="intent"), gr.Number(label="prob")],
        examples=_EXAMPLES,
        examples_per_page=50,
        title="Telemarketing Intent Classification",
        description=_DESCRIPTION,
    )
    # A public share link is requested only when running on Windows.
    demo.launch(share=platform.system() == "Windows")
|
|
|
|
if __name__ == "__main__":
    # Start the demo only when this file is executed directly, not on import.
    main()
|
|