Spaces:
Sleeping
Sleeping
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import json | |
| import os | |
| from pathlib import Path | |
| import platform | |
| import tempfile | |
| import time | |
| from typing import List, Dict | |
| import zipfile | |
| from cacheout import Cache | |
| import gradio as gr | |
| import huggingface_hub | |
| import numpy as np | |
| import torch | |
| from project_settings import project_path, environment | |
| from toolbox.torch.utils.data.tokenizers.pretrained_bert_tokenizer import PretrainedBertTokenizer | |
| from toolbox.torch.utils.data.vocabulary import Vocabulary | |
def get_args():
    """Parse the demo's command-line options.

    Each option names a file and defaults to that file inside ``project_path``.

    Returns:
        ``argparse.Namespace`` with ``waba_intent_examples_file`` and
        ``waba_intent_md_file`` as POSIX-style path strings.
    """
    # (flag, default file name relative to the project root)
    file_options = (
        ("--waba_intent_examples_file", "waba_intent_examples.json"),
        ("--waba_intent_md_file", "waba_intent.md"),
    )

    parser = argparse.ArgumentParser()
    for flag, file_name in file_options:
        parser.add_argument(
            flag,
            default=(project_path / file_name).as_posix(),
            type=str,
        )
    return parser.parse_args()
# Lifetime of a cached model bundle; measured in seconds because the cache
# clock below is ``time.time``.
_MODEL_CACHE_TTL_SECONDS = 60

# Loaded model bundles keyed by repo_id (at most 256 entries).
# NOTE(review): with a 60 s TTL each bundle is re-downloaded/re-loaded about
# once a minute while in use — confirm this short lifetime is intended.
model_cache = Cache(maxsize=256, ttl=_MODEL_CACHE_TTL_SECONDS, timer=time.time)
def load_waba_intent_model(repo_id: str):
    """Fetch ``repo_id`` from the Hugging Face Hub and load its model bundle.

    The repo snapshot is mirrored into
    ``<project_path>/trained_models/<repo_id>``.

    Args:
        repo_id: Hugging Face repo id, e.g. ``"nxcloud/waba_intent_en"``.

    Returns:
        dict with keys ``"model"`` (TorchScript module loaded from
        ``final.zip``), ``"vocabulary"`` (loaded from the ``vocabulary``
        directory) and ``"tokenizer"`` (built from the snapshot directory).
    """
    model_local_dir = project_path / "trained_models/{}".format(repo_id)
    model_local_dir.mkdir(parents=True, exist_ok=True)

    # Hand the token straight to the download rather than calling
    # ``huggingface_hub.login`` on every load: login() with a missing token
    # falls back to an interactive prompt (which blocks a server process),
    # and it re-validates and persists the token each time it is called.
    hf_token = environment.get("hf_token")
    huggingface_hub.snapshot_download(
        repo_id=repo_id,
        local_dir=model_local_dir,
        token=hf_token,
    )

    # Load onto the CPU: click_waba_intent_button feeds CPU tensors, and this
    # also lets an archive saved from a GPU load on a CPU-only host.
    model = torch.jit.load(
        (model_local_dir / "final.zip").as_posix(),
        map_location="cpu",
    )
    vocabulary = Vocabulary.from_files((model_local_dir / "vocabulary").as_posix())
    tokenizer = PretrainedBertTokenizer(model_local_dir.as_posix())

    return {
        "model": model,
        "vocabulary": vocabulary,
        "tokenizer": tokenizer,
    }
def click_waba_intent_button(repo_id: str, text: str):
    """Classify ``text`` with the model published at ``repo_id``.

    The model bundle is looked up in ``model_cache``; on a miss it is loaded
    with ``load_waba_intent_model`` and stored under ``repo_id``.

    Args:
        repo_id: Hugging Face repo id of the model to use.
        text: input text to classify.

    Returns:
        ``(label_str, prob)``: the arg-max label from the ``"labels"``
        vocabulary namespace and its probability rounded to 4 decimals.
    """
    model_group = model_cache.get(repo_id)
    if model_group is None:
        model_group = load_waba_intent_model(repo_id)
        model_cache.set(key=repo_id, value=model_group)

    model = model_group["model"]
    vocabulary = model_group["vocabulary"]
    tokenizer = model_group["tokenizer"]

    # Token strings and token ids live in separately named variables instead
    # of re-annotating one name with two different types.
    tokens: List[str] = tokenizer.tokenize(text)
    token_ids: List[int] = [
        vocabulary.get_token_index(token, namespace="tokens") for token in tokens
    ]

    # Inputs shorter than 5 ids (including empty text) are padded up to 5.
    # NOTE(review): presumably the model needs a minimum sequence length
    # (e.g. a convolution window) — confirm against the training config.
    # Long inputs are not truncated here.
    if len(token_ids) < 5:
        token_ids = vocabulary.pad_or_truncate_ids_by_max_length(token_ids, max_length=5)

    # Build the (1, seq_len) index tensor as int64 explicitly. Going through
    # ``np.array`` used NumPy's platform default integer, which is int32 on
    # Windows with NumPy < 2, so the dtype fed to the model varied by OS.
    batch_tokens = torch.tensor([token_ids], dtype=torch.long)

    # Pure inference: do not record an autograd graph.
    with torch.no_grad():
        outputs = model(batch_tokens)

    probs = outputs["probs"][0]  # scores for the single example in the batch
    argmax = int(torch.argmax(probs, dim=-1))

    label_str = vocabulary.get_token_from_index(argmax, namespace="labels")
    prob = round(float(probs[argmax]), 4)
    return label_str, prob
def main():
    """Build the Gradio demo for the WABA intent classifier and serve it.

    Reads the example inputs (JSON) and an extra Markdown description from the
    paths returned by ``get_args``. Wires both the "predict" button and the
    examples to ``click_waba_intent_button``, then launches the app on port 7860.
    """
    args = get_args()

    brief_description = """
## Text Classification
"""

    # examples
    with open(args.waba_intent_examples_file, "r", encoding="utf-8") as f:
        # NOTE(review): each entry is presumably a [repo_id, text] pair that
        # matches ``predict_inputs`` below — confirm the file's schema.
        waba_intent_examples = json.load(f)
    with open(args.waba_intent_md_file, "r", encoding="utf-8") as f:
        waba_intent_md = f.read()

    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("waba_intent"):
                        with gr.Row():
                            # Input column.
                            with gr.Column(scale=1):
                                waba_intent_repo_id = gr.Dropdown(
                                    choices=["nxcloud/waba_intent_en"],
                                    value="nxcloud/waba_intent_en",
                                    label="repo_id",
                                )
                                waba_intent_text = gr.Textbox(label="text", max_lines=5)
                                waba_intent_button = gr.Button("predict", variant="primary")
                            # Output column.
                            with gr.Column(scale=1):
                                waba_intent_label = gr.Textbox(label="label")
                                waba_intent_prob = gr.Textbox(label="prob")

                        # The examples and the button share one input/output
                        # wiring, so it is defined once.
                        predict_inputs = [waba_intent_repo_id, waba_intent_text]
                        predict_outputs = [waba_intent_label, waba_intent_prob]

                        # examples
                        gr.Examples(
                            examples=waba_intent_examples,
                            inputs=predict_inputs,
                            outputs=predict_outputs,
                            fn=click_waba_intent_button,
                        )
                        # md
                        gr.Markdown(value=waba_intent_md)

                        # click event
                        waba_intent_button.click(
                            fn=click_waba_intent_button,
                            inputs=predict_inputs,
                            outputs=predict_outputs,
                        )

    blocks.queue().launch(
        # The previous ``False if Windows else False`` had identical branches:
        # a public share link is never requested on any platform.
        share=False,
        # Localhost only on Windows; all interfaces elsewhere (presumably so a
        # container / Space host can reach the server).
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860,
    )
# Launch the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()