File size: 777 Bytes
4286de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os

from huggingface_hub import hf_hub_download, snapshot_download

from pipeline import DialogPipeline
from demo_app import create_demo

# Hub repository that hosts both the NLU checkpoint and the NLG model files.
MODEL_REPO = "sakeef/bangla-dialog-models"

# Model artifacts are pulled from the Hub when this script starts.
# huggingface_hub keeps downloaded files in its local cache, so later startups
# that still see that cache reuse the files instead of downloading them again.

# NLU: fetch the single checkpoint file, then hand the pipeline the directory
# that contains it rather than the file path itself.
nlu_ckpt = hf_hub_download(filename="nlu/best_model.pt", repo_id=MODEL_REPO)
nlu_dir = os.path.split(nlu_ckpt)[0]

# NLG: fetch only the files matching "nlg/*" from the repo snapshot and point
# at the "nlg" sub-directory inside the snapshot root.
nlg_dir = os.path.join(
    snapshot_download(allow_patterns="nlg/*", repo_id=MODEL_REPO),
    "nlg",
)

# Assemble the dialog pipeline from the downloaded model directories, the
# local "./labels" directory, and the named BERT model.
pipeline = DialogPipeline(
    bert_model="sagorsarker/bangla-bert-base",
    labels_dir="./labels",
    gen_model_path=nlg_dir,
    nlu_model_path=nlu_dir,
)

# Build the demo UI around the pipeline and start serving it.
demo = create_demo(pipeline)
demo.launch()