ai-building-blocks / translation.py
LiKenun's picture
Switch to use GPU instead of inference client
5c395b2
import gc
from functools import partial
import gradio as gr
import torch
from langdetect import detect, LangDetectException
from transformers import MarianMTModel, MarianTokenizer
from utils import get_pytorch_device, spaces_gpu, get_torch_dtype
# Source languages that have a dedicated Helsinki-NLP "opus-mt-<code>-en" MarianMT
# model in this table. Languages not listed here are translated with the
# caller-supplied fallback model (intended to be a multilingual model).
_OPUS_MT_SOURCE_LANGUAGES = (
    "fr", "de", "es", "it", "pt", "ru", "zh",
    "ja", "ko", "ar", "nl", "pl", "tr", "vi",
    "hi", "cs", "sv", "fi", "uk", "ro", "th",
)

# Language code -> Hugging Face model ID, e.g. "fr" -> "Helsinki-NLP/opus-mt-fr-en".
LANGUAGE_TO_MODEL_MAP = {
    code: f"Helsinki-NLP/opus-mt-{code}-en" for code in _OPUS_MT_SOURCE_LANGUAGES
}
def detect_language(text: str) -> str:
    """Detect the language of the input text using the langdetect library.

    langdetect is a Python port of Google's language-detection library with
    built-in profiles for 55 languages. Its detection algorithm is randomized,
    so the detector seed is pinned before every call to make repeated
    detections of the same text return the same answer.

    Args:
        text: Input text to detect the language of.

    Returns:
        The language code reported by langdetect (e.g. "en", "fr", "de", "ko",
        "ja"; note that Chinese is reported as "zh-cn" or "zh-tw"). If detection
        fails (e.g. empty text or text with no detectable features), "en" is
        returned instead of raising, so callers treat the text as English.
    """
    # Local import keeps this fix self-contained within the function.
    from langdetect import DetectorFactory

    # langdetect samples randomly while detecting; without a fixed seed the same
    # (especially short or ambiguous) text can be assigned different languages
    # on different calls. Setting the class-level seed is idempotent.
    DetectorFactory.seed = 0
    try:
        return detect(text)
    except LangDetectException:
        # Best-effort fallback: report English so the text is passed through untranslated.
        return "en"
def get_translation_model(language_code: str, fallback_model: str) -> "str | None":
    """Get the appropriate translation model for a given language code.

    Codes qualified with a region or script (such as langdetect's "zh-cn" and
    "zh-tw", or "pt_BR") fall back to their base language ("zh", "pt") when the
    full code has no entry of its own. Matching is case-insensitive.

    Args:
        language_code: Language code, e.g. "fr", "de", "en" or "zh-cn".
        fallback_model: Model ID to use if no language-specific model is mapped.

    Returns:
        None if the language is English (no translation needed); otherwise the
        model ID mapped in LANGUAGE_TO_MODEL_MAP, or ``fallback_model`` if the
        language is not in the mapping.
    """
    # Normalize case and separator so "zh-CN", "zh_TW" and "EN" resolve consistently.
    normalized = language_code.strip().lower().replace("_", "-")
    # The map is keyed by bare language codes, so also derive the base language
    # (langdetect reports Chinese as "zh-cn"/"zh-tw", which would otherwise miss "zh").
    base_code = normalized.split("-", 1)[0]
    if base_code == "en":
        return None  # Already in English
    if normalized in LANGUAGE_TO_MODEL_MAP:
        return LANGUAGE_TO_MODEL_MAP[normalized]
    return LANGUAGE_TO_MODEL_MAP.get(base_code, fallback_model)
def _release_device_memory(device) -> None:
    """Best-effort release of memory held after an inference call.

    Runs the Python garbage collector, then asks the matching torch backend to
    return its cached allocator blocks to the device.

    Args:
        device: The device used for inference; anything ``torch.device`` accepts
            (e.g. "cuda", "cuda:0", "mps", "xpu", "cpu", or a ``torch.device``).
    """
    # Collect first: the allocator cache can only shrink once the tensors that
    # occupied it are actually unreachable.
    gc.collect()
    # Using the device *type* also handles indexed devices such as "cuda:0".
    device_type = torch.device(device).type
    if device_type in ("cuda", "mps", "xpu"):
        # torch.cuda / torch.mps / torch.xpu each expose empty_cache(); look it
        # up defensively in case this torch build lacks the backend module.
        backend = getattr(torch, device_type, None)
        empty_cache = getattr(backend, "empty_cache", None)
        if empty_cache is not None:
            empty_cache()


@spaces_gpu
def translate_to_english(fallback_translation_model: str, text: str) -> str:
    """Translate text to English using automatic language detection.

    First detects the source language with ``detect_language``, then selects the
    matching translation model via ``get_translation_model`` and translates the
    text to English with a locally loaded MarianMT model.

    Args:
        fallback_translation_model: Fallback translation model to use if no
            language-specific model is available.
        text: Input text to translate to English.

    Returns:
        String containing the translated text in English, or the original text
        if it is detected as English (or its language cannot be detected).

    Note:
        - Loads weights with ``use_safetensors=True``.
        - Runs on the device returned by ``get_pytorch_device()`` with the dtype
          from ``get_torch_dtype()``.
        - Input longer than the tokenizer's limit is truncated (``truncation=True``).
        - The model is loaded per call and released afterwards; references are
          dropped and device caches cleared even if loading or inference raises.
    """
    detected_lang = detect_language(text)
    # Nothing to do when the text is already English.
    if detected_lang == "en":
        return text

    translation_model = get_translation_model(detected_lang, fallback_translation_model)
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    tokenizer = model = inputs = translated = None
    try:
        tokenizer = MarianTokenizer.from_pretrained(translation_model)
        model = MarianMTModel.from_pretrained(
            translation_model,
            use_safetensors=True,
            dtype=dtype,
        ).to(pytorch_device)

        inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True).to(pytorch_device)
        # Inference needs no gradients; torch.no_grad() avoids storing them,
        # which reduces memory use during generation.
        with torch.no_grad():
            translated = model.generate(**inputs)
        return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
    finally:
        # Drop this frame's references (on success and on failure) so the model
        # and tensors can be collected before the device caches are cleared.
        tokenizer = model = inputs = translated = None
        _release_device_memory(pytorch_device)
def create_translation_tab(fallback_translation_model: str):
    """Build the "translate to English" UI and wire it to the translator.

    Adds, in order, an instructions line, a multi-line input box, a "Translate"
    button and a read-only output box to whatever Gradio layout is active when
    this is called (no tab or block is opened here). Clicking the button calls
    ``translate_to_english`` with ``fallback_translation_model`` bound as its
    first argument and the input box's text as its second.

    Args:
        fallback_translation_model: Model ID used when the detected language has
            no dedicated language-specific model.
    """
    gr.Markdown("Translate text to English. The source language will be automatically detected.")
    source_text = gr.Textbox(label="Input Text", lines=5)
    translate_btn = gr.Button("Translate")
    english_text = gr.Textbox(label="Translated Text", lines=5, interactive=False)

    # Bind the fallback model up front so Gradio only has to supply the input text.
    on_translate = partial(translate_to_english, fallback_translation_model)
    translate_btn.click(fn=on_translate, inputs=source_text, outputs=english_text)