| |
| |
| |
| |
|
|
| import logging |
| from fairseq.tasks import register_task |
| from fairseq.tasks.speech_to_text import SpeechToTextTask |
| from fairseq.tasks.translation import TranslationTask, TranslationConfig |
|
|
# Optional dependency: `examples.simultaneous_translation` is not always
# importable (see the install hint raised by `check_import`). Record whether
# the import worked so the tasks below can fail with a clear message instead
# of the module itself failing to load.
# NOTE(review): the imported name is not used in this part of the file, so the
# import is presumably needed only for its side effects — confirm.
try:
    import examples.simultaneous_translation  # noqa: F401

    import_successful = True
except Exception:
    # Catch ordinary errors only. The previous `except BaseException` also
    # swallowed KeyboardInterrupt / SystemExit raised while the import ran.
    import_successful = False


logger = logging.getLogger(__name__)
|
|
|
|
def check_import(flag: bool) -> None:
    """Raise if the optional simultaneous-translation examples are unavailable.

    Args:
        flag: outcome of attempting to import
            ``examples.simultaneous_translation`` (truthy on success). In this
            module it is the module-level ``import_successful`` flag.

    Raises:
        ImportError: if ``flag`` is falsy. The message suggests installing
            fairseq from source in editable mode.
    """
    if flag:
        return
    raise ImportError(
        "'examples.simultaneous_translation' is not correctly imported. "
        "Please consider `pip install -e $FAIRSEQ_DIR`."
    )
|
|
|
|
@register_task("simul_speech_to_text")
class SimulSpeechToTextTask(SpeechToTextTask):
    """``SpeechToTextTask`` registered under the name ``simul_speech_to_text``.

    Behaves exactly like the parent task. The only difference is that
    construction raises ImportError up front when the optional
    ``examples.simultaneous_translation`` package could not be imported.
    """

    def __init__(self, args, tgt_dict):
        # Verify the optional dependency before any parent initialisation runs.
        check_import(flag=import_successful)
        super().__init__(args, tgt_dict)
|
|
|
|
@register_task("simul_text_to_text", dataclass=TranslationConfig)
class SimulTextToTextTask(TranslationTask):
    """``TranslationTask`` registered under the name ``simul_text_to_text``.

    It is registered with ``TranslationConfig`` as its config dataclass and
    otherwise behaves like the parent task. Construction raises ImportError
    up front when the optional ``examples.simultaneous_translation`` package
    could not be imported.
    """

    def __init__(self, cfg, src_dict, tgt_dict):
        # Verify the optional dependency before any parent initialisation runs.
        check_import(flag=import_successful)
        super().__init__(cfg, src_dict, tgt_dict)
|
|