| | |
| |
|
| | import os |
| | from typing import List, Optional |
| | from src.simuleval_transcoder import SimulevalTranscoder |
| | import json |
| | import logging |
| |
|
| | logger = logging.getLogger("socketio_server_pubsub") |
| |
|
| | |
| | M4T_P0_LANGS = [ |
| | "eng", |
| | "arb", "ben", "cat", "ces", "cmn", "cym", "dan", |
| | "deu", "est", "fin", "fra", "hin", "ind", "ita", |
| | "jpn", "kor", "mlt", "nld", "pes", "pol", "por", |
| | "ron", "rus", "slk", "spa", "swe", "swh", "tel", |
| | "tgl", "tha", "tur", "ukr", "urd", "uzn", "vie", |
| | ] |
| | |
| |
|
| |
|
| | class NoAvailableAgentException(Exception): |
| | pass |
| |
|
| |
|
| | class AgentWithInfo: |
| | def __init__( |
| | self, |
| | agent, |
| | name: str, |
| | modalities: List[str], |
| | target_langs: List[str], |
| | |
| | dynamic_params: List[str] = [], |
| | description="", |
| | has_expressive: Optional[bool] = None, |
| | ): |
| | self.agent = agent |
| | self.has_expressive = has_expressive |
| | self.name = name |
| | self.description = description |
| | self.modalities = modalities |
| | self.target_langs = target_langs |
| | self.dynamic_params = dynamic_params |
| |
|
| | def get_capabilities_for_json(self): |
| | return { |
| | "name": self.name, |
| | "description": self.description, |
| | "modalities": self.modalities, |
| | "targetLangs": self.target_langs, |
| | "dynamicParams": self.dynamic_params, |
| | } |
| |
|
| | @classmethod |
| | def load_from_json(cls, config: str): |
| | """ |
| | Takes in JSON array of models to load in, e.g. |
| | [{"name": "s2s_m4t_emma-unity2_multidomain_v0.1", "description": "M4T model that supports simultaneous S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]}, |
| | {"name": "s2s_m4t_expr-emma_v0.1", "description": "ES-EN expressive model that supports S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]}] |
| | """ |
| | configs = json.loads(config) |
| | agents = [] |
| | for config in configs: |
| | agent = SimulevalTranscoder.build_agent(config["name"]) |
| | agents.append( |
| | AgentWithInfo( |
| | agent=agent, |
| | name=config["name"], |
| | modalities=config["modalities"], |
| | target_langs=config["targetLangs"], |
| | ) |
| | ) |
| | return agents |
| |
|
| |
|
| | class SimulevalAgentDirectory: |
| | |
| | seamless_streaming_agent = "SeamlessStreaming" |
| | seamless_agent = "Seamless" |
| |
|
| | def __init__(self): |
| | self.agents = [] |
| | self.did_build_and_add_agents = False |
| |
|
| | def add_agent(self, agent: AgentWithInfo): |
| | self.agents.append(agent) |
| |
|
| | def build_agent_if_available(self, model_id, config_name=None): |
| | agent = None |
| | try: |
| | if config_name is not None: |
| | agent = SimulevalTranscoder.build_agent( |
| | model_id, |
| | config_name=config_name, |
| | ) |
| | else: |
| | agent = SimulevalTranscoder.build_agent( |
| | model_id, |
| | ) |
| | except Exception as e: |
| | from fairseq2.assets.error import AssetError |
| | logger.warning("Failed to build agent %s: %s" % (model_id, e)) |
| | if isinstance(e, AssetError): |
| | logger.warning( |
| | "Please download gated assets and set `gated_model_dir` in the config" |
| | ) |
| | raise e |
| |
|
| | return agent |
| |
|
| | def build_and_add_agents(self, models_override=None): |
| | if self.did_build_and_add_agents: |
| | return |
| |
|
| | if models_override is not None: |
| | agent_infos = AgentWithInfo.load_from_json(models_override) |
| | for agent_info in agent_infos: |
| | self.add_agent(agent_info) |
| | else: |
| | s2s_agent = None |
| | if os.environ.get("USE_EXPRESSIVE_MODEL", "0") == "1": |
| | logger.info("Building expressive model...") |
| | s2s_agent = self.build_agent_if_available( |
| | SimulevalAgentDirectory.seamless_agent, |
| | config_name="vad_s2st_sc_24khz_main.yaml", |
| | ) |
| | has_expressive = True |
| | else: |
| | logger.info("Building non-expressive model...") |
| | s2s_agent = self.build_agent_if_available( |
| | SimulevalAgentDirectory.seamless_streaming_agent, |
| | config_name="vad_s2st_sc_main.yaml", |
| | ) |
| | has_expressive = False |
| |
|
| | if s2s_agent: |
| | self.add_agent( |
| | AgentWithInfo( |
| | agent=s2s_agent, |
| | name=SimulevalAgentDirectory.seamless_streaming_agent, |
| | modalities=["s2t", "s2s"], |
| | target_langs=M4T_P0_LANGS, |
| | dynamic_params=["expressive"], |
| | description="multilingual expressive model that supports S2S and S2T", |
| | has_expressive=has_expressive, |
| | ) |
| | ) |
| |
|
| | if len(self.agents) == 0: |
| | logger.error( |
| | "No agents were loaded. This likely means you are missing the actual model files specified in simuleval_agent_directory." |
| | ) |
| |
|
| | self.did_build_and_add_agents = True |
| |
|
| | def get_agent(self, name): |
| | for agent in self.agents: |
| | if agent.name == name: |
| | return agent |
| | return None |
| |
|
| | def get_agent_or_throw(self, name): |
| | agent = self.get_agent(name) |
| | if agent is None: |
| | raise NoAvailableAgentException("No agent found with name= %s" % (name)) |
| | return agent |
| |
|
| | def get_agents_capabilities_list_for_json(self): |
| | return [agent.get_capabilities_for_json() for agent in self.agents] |
| |
|