| from typing import Dict, List, Any |
|
|
| from punctuators.models.punc_cap_seg_model import PunctCapSegConfigONNX, PunctCapSegModelONNX |
|
|
|
|
class PreTrainedPipeline:
    """Callable wrapper around a ``PunctCapSegModelONNX`` instance.

    The model is built once at construction time; calling the instance with
    a string runs inference on that string and returns the joined result.
    """

    def __init__(self, path: str):
        """Build the ONNX model from the files located under ``path``.

        Args:
            path: Passed as ``directory`` to ``PunctCapSegConfigONNX``; the
                tokenizer, model and config filenames below are resolved
                by that config object.
        """
        self._punctuator: PunctCapSegModelONNX = PunctCapSegModelONNX(
            PunctCapSegConfigONNX(
                directory=path,
                spe_filename="spe_unigram_64k_lowercase_47lang.model",
                model_filename="punct_cap_seg_47lang.onnx",
                config_filename="config.yaml",
            )
        )

    def __call__(self, data: str) -> List[Dict]:
        """Run the model on one input string.

        Args:
            data: The text to process.

        Returns:
            A one-element list holding a dict whose ``"generated_text"``
            value is the model's output strings for ``data`` joined together.
        """
        # The model takes a batch, so wrap the single input in a list and
        # read back the first (only) entry of the result.
        batch_predictions: List[List[str]] = self._punctuator.infer([data])
        segments = batch_predictions[0]

        # The separator is a space, a literal backslash followed by "n", and
        # a space -- not an actual newline character.
        # NOTE(review): presumably intentional for the consumer that renders
        # this text; confirm before changing it to a real newline.
        joined = " \\n ".join(segments)
        return [{"generated_text": joined}]
|
|