Spaces:
Running
on
Zero
Running
on
Zero
| import gc | |
| from functools import partial | |
| import gradio as gr | |
| import torch | |
| from langdetect import detect, LangDetectException | |
| from transformers import MarianMTModel, MarianTokenizer | |
| from utils import get_pytorch_device, spaces_gpu, get_torch_dtype | |
# Source-language code -> Helsinki-NLP MarianMT model that translates it to English.
# Every entry follows the "Helsinki-NLP/opus-mt-<code>-en" naming scheme, so the
# table is derived from the ordered list of supported codes. Languages that are
# not listed here are handled by the caller-supplied fallback (multilingual) model.
LANGUAGE_TO_MODEL_MAP = {
    source_code: f"Helsinki-NLP/opus-mt-{source_code}-en"
    for source_code in (
        "fr", "de", "es", "it", "pt", "ru", "zh",
        "ja", "ko", "ar", "nl", "pl", "tr", "vi",
        "hi", "cs", "sv", "fi", "uk", "ro", "th",
    )
}
def detect_language(text: str) -> str:
    """Detect the language of the input text using the langdetect library.

    langdetect is a Python port of Google's language-detection library.

    Args:
        text: Input text to detect the language of.

    Returns:
        The primary language subtag of the detected language (e.g., "en", "fr",
        "de", "ko", "ja", "zh"). Region-qualified results such as "zh-cn" or
        "zh-tw" are reduced to their primary subtag ("zh"). Returns "en" when
        langdetect cannot determine a language (for example, empty or
        too-short text); no exception is propagated for that case.
    """
    try:
        detected = detect(text)
    except LangDetectException:
        # Detection failed: treat the text as English so callers leave it as is.
        return "en"
    # langdetect reports Chinese as "zh-cn" / "zh-tw"; keep only the primary
    # subtag so the result matches two-letter lookup keys such as "zh".
    return detected.split("-", 1)[0]
def get_translation_model(language_code: str, fallback_model: str) -> "str | None":
    """Get the appropriate translation model for a given language code.

    Args:
        language_code: Language code of the source text (e.g., "fr", "de", "en").
            Region-qualified codes such as "zh-cn" are also accepted.
        fallback_model: Fallback model to use if no specific model is available.

    Returns:
        None when ``language_code`` is "en" (no translation needed). Otherwise
        the model ID mapped to the code, or to its primary subtag when the
        full code is not in the mapping, or ``fallback_model`` when neither is.
    """
    if language_code == "en":
        return None  # Already in English

    exact_match = LANGUAGE_TO_MODEL_MAP.get(language_code)
    if exact_match is not None:
        return exact_match

    # Region-qualified codes ("zh-cn", "pt-br") have no entry of their own;
    # retry with the primary subtag before giving up on a specific model.
    primary_subtag = language_code.split("-", 1)[0].lower()
    return LANGUAGE_TO_MODEL_MAP.get(primary_subtag, fallback_model)
def translate_to_english(fallback_translation_model: str, text: str) -> str:
    """Translate text to English using automatic language detection.

    Detects the source language with ``detect_language``, picks a model with
    ``get_translation_model`` and runs a local MarianMT model to translate.

    Args:
        fallback_translation_model: Fallback translation model to use if no
            language-specific model is available.
        text: Input text to translate to English.

    Returns:
        The translated English text, or the original text unchanged when it is
        detected as English (or no translation model applies).

    Note:
        - Model weights are loaded with ``use_safetensors=True``.
        - Device and dtype come from ``get_pytorch_device()`` and
          ``get_torch_dtype()``.
        - The tokenizer is called with ``truncation=True``, so input beyond
          the tokenizer's maximum length is not translated.
        - Model, tokenizer and tensors are released even when loading or
          generation raises.
    """
    detected_lang = detect_language(text)
    if detected_lang == "en":
        return text

    translation_model = get_translation_model(detected_lang, fallback_translation_model)
    if translation_model is None:
        # The selector reports "no translation needed"; never hand None to
        # from_pretrained.
        return text

    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    # Pre-bind so the cleanup in `finally` is valid wherever a failure occurs.
    tokenizer = model = inputs = translated = None
    try:
        tokenizer = MarianTokenizer.from_pretrained(translation_model)
        model = MarianMTModel.from_pretrained(
            translation_model,
            use_safetensors=True,
            dtype=dtype,
        ).to(pytorch_device)

        inputs = tokenizer(
            [text], return_tensors="pt", padding=True, truncation=True
        ).to(pytorch_device)

        # Inference only: torch.no_grad() avoids storing gradients, which
        # reduces memory consumption during generation.
        with torch.no_grad():
            translated = model.generate(**inputs)

        return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
    finally:
        # Drop the references, collect garbage, and only then release cached
        # CUDA blocks so memory freed by the collector is returned as well.
        del model, tokenizer, inputs, translated
        gc.collect()
        # str() makes the check work for "cuda", "cuda:0" and torch.device objects.
        if str(pytorch_device).startswith("cuda"):
            torch.cuda.empty_cache()
def create_translation_tab(fallback_translation_model: str):
    """Build the "translate to English" widgets and wire up the button.

    Creates, in order, a Markdown description, an input textbox, a
    "Translate" button and a read-only output textbox. Clicking the button
    calls ``translate_to_english`` with the fallback model pre-bound and the
    input textbox content as the text.

    Args:
        fallback_translation_model: Fallback translation model to use if no
            language-specific model is available.
    """
    # Bind the fallback model now; Gradio supplies only the textbox value.
    on_translate = partial(translate_to_english, fallback_translation_model)

    gr.Markdown("Translate text to English. The source language will be automatically detected.")
    source_box = gr.Textbox(label="Input Text", lines=5)
    run_button = gr.Button("Translate")
    result_box = gr.Textbox(label="Translated Text", lines=5, interactive=False)

    run_button.click(fn=on_translate, inputs=source_box, outputs=result_box)