#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo for the WABA intent text-classification model.

Model artifacts are downloaded from the Hugging Face Hub on demand, kept in a
short-lived in-memory cache, and exposed through a small web UI.
"""
import argparse
import json
import os
from pathlib import Path
import platform
import tempfile
import time
from typing import List, Dict
import zipfile

from cacheout import Cache
import gradio as gr
import huggingface_hub
import numpy as np
import torch

from project_settings import project_path, environment
from toolbox.torch.utils.data.tokenizers.pretrained_bert_tokenizer import PretrainedBertTokenizer
from toolbox.torch.utils.data.vocabulary import Vocabulary


# Inputs with fewer token ids than this are padded up to it before inference.
# NOTE(review): presumably the exported model needs a minimum sequence
# length (e.g. because of its kernel sizes) - confirm against the model.
_MIN_TOKEN_COUNT = 5


def get_args():
    """Parse the command-line arguments.

    Returns:
        argparse.Namespace with ``waba_intent_examples_file`` (JSON file of
        example inputs) and ``waba_intent_md_file`` (markdown file shown in
        the UI); both default to files under ``project_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--waba_intent_examples_file",
        default=(project_path / "waba_intent_examples.json").as_posix(),
        type=str
    )
    parser.add_argument(
        "--waba_intent_md_file",
        default=(project_path / "waba_intent.md").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args


# Loaded model bundles keyed by repo id. Entries expire after 60 seconds, so
# a model that has not been requested for a minute is loaded again (including
# the snapshot download call) on its next use.
model_cache = Cache(maxsize=256, ttl=1 * 60, timer=time.time)


def load_waba_intent_model(repo_id: str) -> dict:
    """Download the snapshot of *repo_id* and load its model artifacts.

    The snapshot is stored under ``project_path / "trained_models" / repo_id``.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. ``"nxcloud/waba_intent_en"``.

    Returns:
        A dict with the keys ``"model"`` (TorchScript module loaded from
        ``final.zip``), ``"vocabulary"`` and ``"tokenizer"``.
    """
    model_local_dir = project_path / "trained_models/{}".format(repo_id)
    model_local_dir.mkdir(parents=True, exist_ok=True)

    hf_token = environment.get("hf_token")
    # Only log in when a token is configured: login() without a token falls
    # back to an interactive prompt, which cannot be answered by a server
    # process. Without a token the download proceeds with whatever
    # credentials are already available (none for public repositories).
    if hf_token:
        huggingface_hub.login(token=hf_token)

    huggingface_hub.snapshot_download(
        repo_id=repo_id,
        local_dir=model_local_dir
    )

    model = torch.jit.load((model_local_dir / "final.zip").as_posix())
    vocabulary = Vocabulary.from_files((model_local_dir / "vocabulary").as_posix())
    tokenizer = PretrainedBertTokenizer(model_local_dir.as_posix())

    result = {
        "model": model,
        "vocabulary": vocabulary,
        "tokenizer": tokenizer,
    }
    return result


def click_waba_intent_button(repo_id: str, text: str):
    """Classify *text* with the model stored in *repo_id*.

    The model bundle is taken from ``model_cache`` and loaded (then cached)
    when it is missing or has expired.

    Args:
        repo_id: Hugging Face Hub repository id of the model to use.
        text: the text to classify.

    Returns:
        A ``(label, prob)`` tuple: the predicted label and its probability
        rounded to four decimal places.
    """
    model_group = model_cache.get(repo_id)
    if model_group is None:
        model_group = load_waba_intent_model(repo_id)
        model_cache.set(key=repo_id, value=model_group)

    model = model_group["model"]
    vocabulary = model_group["vocabulary"]
    tokenizer = model_group["tokenizer"]

    token_strs: List[str] = tokenizer.tokenize(text)
    token_ids: List[int] = [
        vocabulary.get_token_index(token, namespace="tokens")
        for token in token_strs
    ]
    if len(token_ids) < _MIN_TOKEN_COUNT:
        token_ids = vocabulary.pad_or_truncate_ids_by_max_length(
            token_ids, max_length=_MIN_TOKEN_COUNT
        )

    # Pin the dtype: NumPy's default integer is platform dependent (32-bit on
    # Windows with NumPy < 2), while the ids should always be 64-bit.
    batch_tokens = torch.from_numpy(np.asarray([token_ids], dtype=np.int64))

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model.forward(batch_tokens)

    probs = outputs["probs"]
    argmax = torch.argmax(probs, dim=-1)

    # The batch holds a single example, so take its first (only) row.
    probs = probs.tolist()[0]
    argmax = argmax.tolist()[0]

    label_str = vocabulary.get_token_from_index(argmax, namespace="labels")
    prob = round(probs[argmax], 4)

    return label_str, prob


def main():
    """Build the Gradio UI and start the web server on port 7860."""
    args = get_args()

    brief_description = """ ## Text Classification """

    # examples
    with open(args.waba_intent_examples_file, "r", encoding="utf-8") as f:
        waba_intent_examples = json.load(f)
    with open(args.waba_intent_md_file, "r", encoding="utf-8") as f:
        waba_intent_md = f.read()

    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("waba_intent"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                waba_intent_repo_id = gr.Dropdown(
                                    choices=["nxcloud/waba_intent_en"],
                                    value="nxcloud/waba_intent_en",
                                    label="repo_id"
                                )
                                waba_intent_text = gr.Textbox(label="text", max_lines=5)
                                waba_intent_button = gr.Button("predict", variant="primary")

                            with gr.Column(scale=1):
                                waba_intent_label = gr.Textbox(label="label")
                                waba_intent_prob = gr.Textbox(label="prob")

                        # examples
                        gr.Examples(
                            examples=waba_intent_examples,
                            inputs=[
                                waba_intent_repo_id,
                                waba_intent_text,
                            ],
                            outputs=[
                                waba_intent_label,
                                waba_intent_prob
                            ],
                            fn=click_waba_intent_button
                        )

                        # md
                        gr.Markdown(value=waba_intent_md)

                        # click event
                        waba_intent_button.click(
                            fn=click_waba_intent_button,
                            inputs=[
                                waba_intent_repo_id,
                                waba_intent_text,
                            ],
                            outputs=[
                                waba_intent_label,
                                waba_intent_prob
                            ],
                        )

    blocks.queue().launch(
        # Public share links are disabled on every platform.
        share=False,
        # Bind to loopback on Windows, to all interfaces elsewhere.
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )


if __name__ == '__main__':
    main()