"""Start-up script for the dialog demo.

Fetches the NLU checkpoint and the NLG model files from the Hugging Face Hub,
builds a ``DialogPipeline`` from them, and launches the demo object returned
by ``create_demo``.
"""

import os

from huggingface_hub import hf_hub_download, snapshot_download

from pipeline import DialogPipeline
from demo_app import create_demo

# Hub repository that hosts both the NLU checkpoint and the NLG model files.
MODEL_REPO = "sakeef/bangla-dialog-models"

# Download the NLU checkpoint (a single .pt file) and the T5 model directory
# from the Hub. Downloads land in the local Hugging Face cache, so later
# startups reuse them as long as that cache directory persists.
nlu_ckpt = hf_hub_download(repo_id=MODEL_REPO, filename="nlu/best_model.pt")
# DialogPipeline is given the directory holding the checkpoint, not the file.
nlu_dir = os.path.dirname(nlu_ckpt)

# Only files under "nlg/" are fetched; snapshot_download returns the snapshot
# root, so the "nlg" sub-directory has to be appended to reach the model.
nlg_dir = snapshot_download(repo_id=MODEL_REPO, allow_patterns="nlg/*")
nlg_dir = os.path.join(nlg_dir, "nlg")

pipeline = DialogPipeline(
    nlu_model_path=nlu_dir,
    gen_model_path=nlg_dir,
    # Relative path: resolved against the process's current working directory.
    labels_dir="./labels",
    bert_model="sagorsarker/bangla-bert-base",
)

# `demo` stays a module-level name so external tooling can still find it.
demo = create_demo(pipeline)
demo.launch()