Spaces:
Runtime error
Runtime error
| from fastapi import APIRouter | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| from config import TEST_MODE, device, dtype, log | |
| from fairseq2.data.text.text_tokenizer import TextTokenEncoder | |
| from seamless_communication.inference import Translator | |
| import spacy | |
| import re | |
| from datetime import datetime | |
# Router object for this module's endpoints.
# NOTE(review): no handler visible in this file is registered on this router
# (no `@router...` decorator appears) — confirm registration happens elsewhere.
router = APIRouter()
# Request payload for a text-to-text translation call (see `t2tt`).
class TranslateInput(BaseModel):
    # Texts to translate.
    inputs: list[str]
    # Key used to look up the translation function in `t2tt_mapping`.
    model: str
    # Source language code, forwarded to the selected translation function.
    src_lang: str
    # Target language code, forwarded to the selected translation function.
    dst_lang: str
# Response payload for a translation call: echoes the language pair and
# carries either the translations or an error message.
class TranslateOutput(BaseModel):
    # Source language code echoed back from the request.
    src_lang: str
    # Target language code echoed back from the request.
    dst_lang: str
    # Translated texts; left as None when no translation was produced.
    translations: Optional[list[str]] = None
    # Human-readable failure description; None when no error was reported.
    error: Optional[str] = None
def t2tt(inputs: TranslateInput) -> TranslateOutput:
    """Translate the request's texts with the model named in the request.

    Looks up the translation function registered for ``inputs.model`` in
    ``t2tt_mapping``, runs it, logs the call, and records the time the model
    was last used in ``loaded_models_last_updated``.

    Args:
        inputs: The translation request (texts, model key, language pair).

    Returns:
        A ``TranslateOutput`` built from the translation function's result.
        If the model key is unknown, or anything raises while translating or
        logging, a ``TranslateOutput`` with only ``error`` set is returned
        instead of propagating the exception.
    """
    start_time = datetime.now()

    fn = t2tt_mapping.get(inputs.model)
    if not fn:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=f'No translation model found for {inputs.model}',
        )

    try:
        # Pass only the arguments a translation function accepts. Spreading
        # the whole request (``**inputs.dict()``) also passed ``model=...``,
        # which ``seamless_t2tt`` does not take, so every call failed with a
        # TypeError.
        translations = fn(
            inputs=inputs.inputs,
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
        )

        log({
            # NOTE(review): this label looks copy-pasted from a sentence
            # embeddings endpoint; kept unchanged because log consumers may
            # key on it — confirm the intended task name.
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": translations,
            "parameters": {
                "src_lang": inputs.src_lang,
                "dst_lang": inputs.dst_lang,
            },
        })

        # Remember when this model was last used.
        loaded_models_last_updated[inputs.model] = datetime.now()
        return TranslateOutput(**translations)
    except Exception as e:
        # Endpoint boundary: report the failure in the response body rather
        # than letting the exception escape.
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=str(e),
        )
# spaCy pipelines used for sentence splitting, loaded once at import time:
# `cmn_nlp` is used when the source language is 'cmn', `xx_nlp` otherwise.
cmn_nlp = spacy.load("zh_core_web_sm")
xx_nlp = spacy.load("xx_sent_ud_sm")

# Matches "<unk>" or "⁇", each with one optional preceding whitespace
# character; used to strip these markers from translated text.
unk_re = re.compile(r"\s?<unk>|\s?⁇")
def seamless_t2tt(inputs: list[str], src_lang: str, dst_lang: str = 'eng'):
    """Translate each text in ``inputs`` with SeamlessM4T v2 (task "T2TT").

    The translator is created on first use and cached in ``loaded_models``.
    Each text is split into sentences, the sentences are translated as one
    batch, and the results are joined with single spaces.

    Args:
        inputs: Texts to translate; one output string is produced per text.
        src_lang: Source language code. 'cmn' selects the Chinese sentence
            splitter; any other value uses the multilingual one.
        dst_lang: Target language code.

    Returns:
        A dict with keys ``src_lang``, ``dst_lang``, ``translations`` (list of
        strings, or None on failure / in test mode) and ``error`` (None on
        success, otherwise a message describing the failure).

    Raises:
        Whatever model loading or encoder creation raises; only failures
        during the translation step itself are converted into ``error``.
    """
    if TEST_MODE:
        return {
            "src_lang": src_lang,
            "dst_lang": dst_lang,
            "translations": None,
            "error": None
        }

    # Load model (cached across calls under its public model name).
    model_key = 'facebook/seamless-m4t-v2-large'
    translator = loaded_models.get(model_key)
    if translator is None:
        translator = Translator(
            model_name_or_card="seamlessM4T_v2_large",
            vocoder_name_or_card="vocoder_v2",
            device=device,
            dtype=dtype,
            apply_mintox=False,
        )
        loaded_models[model_key] = translator

    def sent_tokenize(text, lang) -> list[str]:
        # Pick the sentence splitter for the language, then stringify spans.
        nlp = cmn_nlp if lang == 'cmn' else xx_nlp
        return [str(t) for t in nlp(text).sents]

    def tokenize_and_translate(token_encoder: TextTokenEncoder, text: str, src_lang: str, dst_lang: str) -> str:
        # Treat blank-line-separated chunks as paragraphs, flatten the line
        # breaks inside each one, and split it into non-empty sentences.
        sentences = [
            sentence
            for paragraph in text.split('\n\n')
            if paragraph
            for sentence in sent_tokenize(paragraph.replace("\n", " "), src_lang)
            if sentence
        ]
        # Nothing to translate (empty or whitespace-only text): avoid
        # collating an empty batch and return an empty translation instead.
        if not sentences:
            return ""

        # Tokenize and translate all sentences of this text as one batch.
        input_tokens = translator.collate([token_encoder(s) for s in sentences])
        text_outputs = translator.predict(
            input=input_tokens,
            task_str="T2TT",
            src_lang=src_lang,
            tgt_lang=dst_lang,
        )[0]
        # Strip unknown-token markers before joining the sentences back up.
        return " ".join(unk_re.sub("", str(t)) for t in text_outputs)

    translations = None
    error = None
    token_encoder = translator.text_tokenizer.create_encoder(
        task="translation", lang=src_lang, mode="source", device=translator.device
    )
    try:
        translations = [tokenize_and_translate(token_encoder, text, src_lang, dst_lang) for text in inputs]
    except Exception as e:
        print(f"Error translating text: {e}")
        # Surface the cause to the caller instead of only printing it.
        error = f"Failed to translate text: {e}"

    return {
        "src_lang": src_lang,
        "dst_lang": dst_lang,
        "translations": translations,
        # Tracked explicitly so an empty (but successful) batch is not
        # reported as a failure just because `[]` is falsy.
        "error": error
    }
# NOTE(review): the original comment here was cut off ("Polling every X
# minutes to"); presumably a periodic job uses the timestamps below to unload
# idle models — no such job is visible in this file, confirm.

# Cache of instantiated models, keyed by model name.
loaded_models = {}

# Time each model was last used, keyed by model name.
loaded_models_last_updated = {}

# Dispatch table from model name to the function that translates with it.
t2tt_mapping = {
    'facebook/seamless-m4t-v2-large': seamless_t2tt,
}