| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import logging |
| |
|
| | from fairseq2.assets import asset_store, download_manager |
| | from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import ( |
| | SeamlessQualityScorer, |
| | ) |
| | from seamless_communication.streaming.agents.seamless_s2st import SeamlessS2STAgent |
| | from seamless_communication.streaming.agents.seamless_streaming_s2st import ( |
| | SeamlessStreamingS2STAgent, |
| | ) |
| | from seamless_communication.streaming.agents.seamless_streaming_s2t import ( |
| | SeamlessStreamingS2TAgent, |
| | ) |
| | from simuleval.cli import evaluate |
| |
|
# Configure root logging once at import time so every module in this process
# emits timestamped records like:
#   "2024-01-01 12:00:00,000 INFO -- some.module: message"
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

# Module-level logger, following the standard `logging` convention.
logger = logging.getLogger(__name__)
| |
|
| |
|
def main() -> None:
    """Run streaming evaluation of Seamless UnitY models via SimulEval.

    Parses the task selector (``--task``) and the ``--expressive`` flag,
    picks the matching streaming agent class, assembles model/evaluation
    configuration, and delegates to ``simuleval.cli.evaluate``.  All
    agent-specific CLI options are parsed later by SimulEval through the
    same parser (hence ``add_help=False`` and ``parse_known_args``).

    Returns:
        None.  Exits via ``parser.error`` (SystemExit) on an invalid
        option combination.
    """
    parser = argparse.ArgumentParser(
        add_help=False,
        description="Streaming evaluation of Seamless UnitY models",
        conflict_handler="resolve",
    )

    parser.add_argument(
        "--task",
        choices=["s2st", "s2tt", "asr"],
        required=True,
        type=str,
        # FIX: the previous help text ("Target language to translate/
        # transcribe into.") was copy-pasted from a language option and
        # did not describe this flag.
        help="Evaluation task: speech-to-speech (s2st), "
        "speech-to-text (s2tt), or speech recognition (asr).",
    )
    parser.add_argument(
        "--expressive",
        action="store_true",
        default=False,
        help="Expressive streaming S2ST inference",
    )

    # First pass: we only need --task / --expressive here; unknown options
    # belong to the agent and are handled by SimulEval.
    args, _ = parser.parse_known_args()

    # Decoding/streaming settings shared by every agent.
    model_configs = dict(
        source_segment_size=320,
        device="cuda:0",
        dtype="fp16",
        min_starting_wait_w2vbert=192,
        decision_threshold=0.5,
        no_early_stop=True,
        max_len_a=0,
        max_len_b=100,
    )

    eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
    if args.task == "s2st":
        model_configs["min_unit_chunk_size"] = 50
        eval_configs["latency_metrics"] = "StartOffset EndOffset"

        agent_class = (
            SeamlessS2STAgent if args.expressive else SeamlessStreamingS2STAgent
        )
    else:
        # argparse `choices` + `required=True` guarantee the task is one of
        # the three values, so this branch covers "s2tt" and "asr".
        if args.expressive:
            # FIX: previously an `assert`, which is stripped under -O and
            # raised a raw AssertionError; report it as a CLI usage error.
            parser.error("S2TT inference cannot be expressive.")
        agent_class = SeamlessStreamingS2TAgent
        parser.add_argument(
            "--unity-model-name",
            type=str,
            help="Unity model name.",
            default="seamless_streaming_unity",
        )
        args, _ = parser.parse_known_args()
        # Text latency (AL / LAAL) is measured in SPM tokens, which requires
        # the model's tokenizer to be available locally.
        asset_card = asset_store.retrieve_card(name=args.unity_model_name)
        tokenizer_uri = asset_card.field("tokenizer").as_uri()
        tokenizer_path = download_manager.download_tokenizer(
            tokenizer_uri, asset_card.name, force=False, progress=True
        )
        eval_configs["latency_metrics"] = "AL LAAL"
        eval_configs["eval_latency_unit"] = "spm"
        eval_configs["eval_latency_spm_model"] = tokenizer_path

    base_config = dict(
        dataloader="fairseq2_s2tt",
        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
    )

    # Later dicts win on key collisions: model settings override the base
    # config, and evaluation settings override both.
    evaluate(agent_class, {**base_config, **model_configs, **eval_configs}, parser)


if __name__ == "__main__":
    main()
| |
|