# AudioDubbAi/src/core/translator.py
"""
Translator - Text translation using NLLB-200 Distilled
Handles multilingual text translation with batch processing
"""
import logging
import torch
from typing import List, Dict, Any
from src.models.model_manager import ModelManager
logger = logging.getLogger(__name__)
class Translator:
    """Translate text between languages using NLLB-200 Distilled.

    Wraps the model/tokenizer pair provided by :class:`ModelManager` and
    exposes single-text and batched translation. Languages may be given as
    human-readable names (see :attr:`LANGUAGE_CODES`) or directly as
    NLLB-200 / FLORES-200 codes such as ``"eng_Latn"``.
    """

    # Human-readable language name -> NLLB-200 (FLORES-200) code,
    # formatted as "<iso639-3>_<Script>".
    LANGUAGE_CODES = {
        "english": "eng_Latn",
        "hindi": "hin_Deva",
        "bengali": "ben_Beng",
        "tamil": "tam_Taml",
        "telugu": "tel_Telu",
        "marathi": "mar_Deva",
        "gujarati": "guj_Gujr",
        "kannada": "kan_Knda",
        "malayalam": "mal_Mlym",
        "punjabi": "pan_Guru",
        "urdu": "urd_Arab",
        "odia": "ory_Orya",
        "assamese": "asm_Beng",
        "nepali": "npi_Deva",
        "sinhala": "sin_Sinh",
        "arabic": "arb_Arab",
        "french": "fra_Latn",
        "spanish": "spa_Latn",
        "german": "deu_Latn",
        "portuguese": "por_Latn",
        "russian": "rus_Cyrl",
        "chinese": "zho_Hans",
        "japanese": "jpn_Jpan",
        "korean": "kor_Hang",
    }

    def __init__(self):
        # ModelManager owns loading/caching of the NLLB model, tokenizer
        # and target device.
        self.model_manager = ModelManager()

    def translate(
        self,
        text: str,
        source_language: str,
        target_language: str
    ) -> Dict[str, Any]:
        """
        Translate text from source to target language.

        Args:
            text: Text to translate
            source_language: Source language name or NLLB code
            target_language: Target language name or NLLB code

        Returns:
            Dict with 'translated_text', 'source_language', 'target_language',
            'source_language_name', 'target_language_name'
        """
        logger.info(f"Translating from {source_language} to {target_language}")

        # Resolve names (or pass-through codes) to NLLB codes
        src_code = self._get_nllb_code(source_language)
        tgt_code = self._get_nllb_code(target_language)
        logger.info(f"Using NLLB codes: {src_code} -> {tgt_code}")

        model, tokenizer = self.model_manager.get_nllb_model()
        device = self.model_manager.get_device()

        # BUGFIX: src_lang must be set BEFORE encoding -- the NLLB tokenizer
        # prepends the source-language token while tokenizing, so setting it
        # afterwards (as the old code did) silently left the default source
        # language in the encoded input.
        tokenizer.src_lang = src_code

        # Prepare input and move tensors to the model's device
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate translation; the forced BOS token steers decoding into
        # the target language.
        logger.info("Generating translation...")
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                # BUGFIX: NLLB tokenizers have no get_lang_id() (that is the
                # M2M100 tokenizer API, and the old call raised
                # AttributeError); look the language token id up directly.
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                max_new_tokens=512,
                num_beams=5,
                early_stopping=True,
            )

        # Decode the single generated sequence
        translated_text = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True
        )[0]
        logger.info("Translation complete")

        return {
            "translated_text": translated_text,
            "source_language": src_code,
            "target_language": tgt_code,
            "source_language_name": source_language,
            "target_language_name": target_language,
        }

    def batch_translate(
        self,
        texts: List[str],
        source_language: str,
        target_language: str,
        batch_size: int = 4
    ) -> List[str]:
        """
        Translate multiple texts, processed in groups of ``batch_size``.

        NOTE: texts are currently translated one at a time within each
        group (no true batched generate call); ``batch_size`` only affects
        logging granularity. TODO: feed each group to the tokenizer/model
        as a real batch for throughput.

        Args:
            texts: List of texts to translate
            source_language: Source language
            target_language: Target language
            batch_size: Batch size for processing

        Returns:
            List of translated texts, in the same order as ``texts``
        """
        logger.info(f"Batch translating {len(texts)} texts")
        results = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            logger.info(f"Processing batch {i//batch_size + 1}")
            for text in batch:
                result = self.translate(text, source_language, target_language)
                results.append(result["translated_text"])
        return results

    def _get_nllb_code(self, language: str) -> str:
        """
        Resolve a language name or NLLB code to an NLLB-200 code.

        Unknown languages fall back to English with a warning.
        """
        lang_lower = language.lower()
        # Known human-readable name (case-insensitive)
        if lang_lower in self.LANGUAGE_CODES:
            return self.LANGUAGE_CODES[lang_lower]
        # BUGFIX: if the caller already passed an NLLB code (e.g. "eng_Latn"),
        # return it with its original casing -- the old code returned the
        # lowercased string, corrupting the script suffix ("eng_latn" is not
        # a valid NLLB code).
        if "_" in language:
            return language
        # Fallback to English
        logger.warning(f"Language '{language}' not found, using English")
        return self.LANGUAGE_CODES["english"]

    @staticmethod
    def get_supported_languages() -> List[str]:
        """Get list of supported language names."""
        return list(Translator.LANGUAGE_CODES.keys())