| | |
| | |
| | import argparse |
| | import json |
| |
|
| | from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer |
| | from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer |
| | from allennlp.data.vocabulary import Vocabulary |
| | from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder |
| | from allennlp.modules.token_embedders.embedding import Embedding |
| | from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder |
| | from allennlp.models.archival import archive_model, load_archive |
| | from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder |
| | from allennlp.predictors.predictor import Predictor |
| | from allennlp.predictors.text_classifier import TextClassifierPredictor |
| | import gradio as gr |
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | from tqdm import tqdm |
| |
|
| | from project_settings import project_path |
| | from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier |
| | from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader |
| |
|
| |
|
def get_args():
    """Parse command-line options for the top-k intent export script.

    Options: --excel_file (input spreadsheet), --archive_file (trained
    AllenNLP model archive), --predictor_name, --top_k, --output_file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--excel_file",
        default=r"D:\Users\tianx\PycharmProjects\telemarketing_intent\data\excel\telemarketing_intent_vi.xlsx",
        type=str,
    )
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        type=str,
    )
    parser.add_argument("--predictor_name", default="text_classifier", type=str)
    parser.add_argument("--top_k", default=10, type=int)
    parser.add_argument("--output_file", default="intent_top_k.jsonl", type=str)
    return parser.parse_args()
| |
|
| |
|
def _to_optional_int(value):
    """Coerce an excel cell to ``int``; return ``None`` for missing or unparseable cells."""
    if pd.isna(value):
        return None
    try:
        return int(value)
    except Exception:
        print(type(value))  # debug aid: surface the unexpected cell type
        return None


def main():
    """Predict top-k intent labels for each spreadsheet row and write JSONL.

    Loads the archived text classifier given by ``--archive_file``, runs it
    over the ``text`` column of ``--excel_file``, and writes one JSON object
    per row — the original columns plus ``predict_label_top_k`` and
    ``predict_prob_top_k`` (most probable first, ``;``-joined) — to
    ``--output_file``.
    """
    args = get_args()

    archive = load_archive(archive_file=args.archive_file)
    predictor = Predictor.from_archive(archive, predictor_name=args.predictor_name)

    df = pd.read_excel(args.excel_file)

    # NOTE(review): rows before 26976 are skipped (manual resume point), yet
    # the file is opened in "w" mode, which truncates any earlier output —
    # confirm whether "a" (append) was intended when resuming.
    with open(args.output_file, "w", encoding="utf-8") as f:
        for i, row in tqdm(df.iterrows(), total=len(df)):
            if i < 26976:
                continue

            text = row["text"]
            if pd.isna(text):
                # Nothing to classify without text; skip the row entirely.
                continue
            text = str(text)

            # pd.isna(None) is True, so a single isna check normalizes
            # missing cells (NaN or None) to None.
            source = None if pd.isna(row["source"]) else row["source"]
            label0 = None if pd.isna(row["label0"]) else row["label0"]
            label1 = None if pd.isna(row["label1"]) else row["label1"]
            selected = _to_optional_int(row["selected"])
            checked = _to_optional_int(row["checked"])

            outputs = predictor.predict_json({'sentence': text})
            probs = outputs["probs"]
            arg_idx = np.argsort(probs)

            # BUGFIX: honor --top_k; the original hard-coded 10 here even
            # though the option was parsed (default is 10, so behavior is
            # unchanged for the default invocation).
            arg_idx_top_k = arg_idx[-args.top_k:]
            label_top_k = [
                predictor._model.vocab.get_token_from_index(index=idx, namespace="labels").split("_")[-1]
                for idx in arg_idx_top_k
            ]
            prob_top_k = [str(round(probs[idx], 5)) for idx in arg_idx_top_k]

            row_ = {
                "source": source,
                "text": text,
                "label0": label0,
                "label1": label1,
                "selected": selected,
                "checked": checked,
                # argsort is ascending, so reverse to put the most
                # probable label/probability first.
                "predict_label_top_k": ";".join(reversed(label_top_k)),
                "predict_prob_top_k": ";".join(reversed(prob_top_k)),
            }
            f.write("{}\n".format(json.dumps(row_, ensure_ascii=False)))

    return
| |
|
| |
|
# Script entry point: run the export when invoked directly.
if __name__ == "__main__":
    main()
| |
|