| |
| |
| |
| |
| |
|
|
|
|
| __author__ = "lovemefan" |
| __copyright__ = "Copyright (C) 2023 lovemefan" |
| __license__ = "MIT" |
| __version__ = "v0.0.1" |
|
|
| import logging |
| import threading |
|
|
| from cttpunctuator.src.punctuator import (CT_Transformer, |
| CT_Transformer_VadRealtime) |
|
|
# Root-logger configuration applied when this module is imported: INFO level,
# with each record showing timestamp, level, source location and message.
_LOG_FORMAT = (
    "[%(asctime)s %(levelname)s] "
    "[%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s"
)
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-wide re-entrant lock; CttPunctuator holds it while creating the
# shared model instances.
lock = threading.RLock()
|
|
|
|
class CttPunctuator:
    """Punctuator facade over lazily created, class-wide shared models.

    One offline model (``CT_Transformer``) and one online model
    (``CT_Transformer_VadRealtime``) are created at most once per process and
    stored on the class; every instance reuses the model matching its mode.
    """

    # Shared model instances; populated on first construction in each mode.
    _offline_model = None
    _online_model = None

    def __init__(self, online: bool = False):
        """
        punctuator with singleton pattern
        :param online: when True use the shared online model
            (``CT_Transformer_VadRealtime``); otherwise use the shared
            offline model (``CT_Transformer``).
        """
        self.online = online

        if online:
            # Check once without the lock so the already-initialised path is
            # cheap, then re-check under the lock so that only one thread
            # constructs the shared model.
            if CttPunctuator._online_model is None:
                with lock:
                    if CttPunctuator._online_model is None:
                        logging.info("Initializing punctuator model with online mode.")
                        CttPunctuator._online_model = CT_Transformer_VadRealtime()
                        logging.info("Online model initialized.")
            self.model = CttPunctuator._online_model
            # Default parameters handed to the online model. This is set on
            # every online instance, not only on the one that happened to
            # create the shared model; otherwise later instances would raise
            # AttributeError in punctuate().
            self.param_dict = {"cache": []}
        else:
            if CttPunctuator._offline_model is None:
                with lock:
                    if CttPunctuator._offline_model is None:
                        logging.info("Initializing punctuator model with offline mode.")
                        CttPunctuator._offline_model = CT_Transformer()
                        logging.info("Offline model initialized.")
            self.model = CttPunctuator._offline_model

        logging.info("Model initialized.")

    def punctuate(self, text: str, param_dict=None):
        """
        run the underlying model on ``text``
        :param text: input text passed to the model.
        :param param_dict: online mode only; parameters forwarded to the
            online model. When omitted (or empty) the instance's own
            ``self.param_dict`` is used. Ignored in offline mode.
        :return: whatever the underlying model call returns.
        """
        if self.online:
            # Forward the resolved dict; the caller's argument must not be
            # discarded in favour of self.param_dict.
            return self.model(text, param_dict or self.param_dict)
        return self.model(text)
|
|