# Spaces: Paused
# File size: 5,317 Bytes
# fad5c32
"""
Translator - Text translation using NLLB-200 Distilled
Handles multilingual text translation with batch processing
"""
import logging
import torch
from typing import List, Dict, Any
from src.models.model_manager import ModelManager
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class Translator:
    """Translate text between languages with an NLLB-200 model.

    The model, tokenizer and device are obtained from ``ModelManager``.
    This class maps human-readable language names to NLLB-200 codes and
    drives tokenization, generation and decoding.
    """

    # Lower-case language name -> NLLB-200 language code.
    LANGUAGE_CODES: Dict[str, str] = {
        "english": "eng_Latn",
        "hindi": "hin_Deva",
        "bengali": "ben_Beng",
        "tamil": "tam_Taml",
        "telugu": "tel_Telu",
        "marathi": "mar_Deva",
        "gujarati": "guj_Gujr",
        "kannada": "kan_Knda",
        "malayalam": "mal_Mlym",
        "punjabi": "pan_Guru",
        "urdu": "urd_Arab",
        "odia": "ory_Orya",
        "assamese": "asm_Beng",
        "nepali": "npi_Deva",
        "sinhala": "sin_Sinh",
        "arabic": "arb_Arab",
        "french": "fra_Latn",
        "spanish": "spa_Latn",
        "german": "deu_Latn",
        "portuguese": "por_Latn",
        "russian": "rus_Cyrl",
        "chinese": "zho_Hans",
        "japanese": "jpn_Jpan",
        "korean": "kor_Hang",
    }

    # Lower-cased code -> canonical code, so a known code supplied in any
    # casing can be restored to the exact spelling used in LANGUAGE_CODES.
    _CANONICAL_CODES: Dict[str, str] = {
        code.lower(): code for code in LANGUAGE_CODES.values()
    }

    def __init__(self):
        """Create the model manager used to fetch the model, tokenizer and device."""
        self.model_manager = ModelManager()

    def translate(
        self,
        text: str,
        source_language: str,
        target_language: str
    ) -> Dict[str, Any]:
        """Translate a single text from the source to the target language.

        Args:
            text: Text to translate.
            source_language: Source language name or NLLB code.
            target_language: Target language name or NLLB code.

        Returns:
            Dict with the keys ``translated_text``, ``source_language`` and
            ``target_language`` (resolved NLLB codes), and
            ``source_language_name`` and ``target_language_name`` (the
            values passed in by the caller).
        """
        logger.info("Translating from %s to %s", source_language, target_language)

        src_code = self._get_nllb_code(source_language)
        tgt_code = self._get_nllb_code(target_language)
        logger.info("Using NLLB codes: %s -> %s", src_code, tgt_code)

        translated_text = self._translate_batch([text], src_code, tgt_code)[0]
        logger.info("Translation complete")

        return {
            "translated_text": translated_text,
            "source_language": src_code,
            "target_language": tgt_code,
            "source_language_name": source_language,
            "target_language_name": target_language,
        }

    def batch_translate(
        self,
        texts: List[str],
        source_language: str,
        target_language: str,
        batch_size: int = 4
    ) -> List[str]:
        """Translate multiple texts, running the model once per batch.

        Args:
            texts: List of texts to translate.
            source_language: Source language name or NLLB code.
            target_language: Target language name or NLLB code.
            batch_size: Number of texts sent through the model per call.

        Returns:
            List of translated texts, in the same order as ``texts``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        logger.info("Batch translating %d texts", len(texts))
        if not texts:
            return []

        # Resolve the codes once; they are identical for every batch.
        src_code = self._get_nllb_code(source_language)
        tgt_code = self._get_nllb_code(target_language)
        logger.info("Using NLLB codes: %s -> %s", src_code, tgt_code)

        results: List[str] = []
        starts = range(0, len(texts), batch_size)
        for batch_number, start in enumerate(starts, start=1):
            logger.info("Processing batch %d", batch_number)
            batch = list(texts[start:start + batch_size])
            results.extend(self._translate_batch(batch, src_code, tgt_code))
        return results

    def _translate_batch(
        self,
        texts: List[str],
        src_code: str,
        tgt_code: str
    ) -> List[str]:
        """Tokenize, generate and decode one batch of texts.

        Args:
            texts: Texts to translate together in one model call.
            src_code: NLLB code of the source language.
            tgt_code: NLLB code of the target language.

        Returns:
            Decoded translations, one per input text, in input order.
        """
        model, tokenizer = self.model_manager.get_nllb_model()
        device = self.model_manager.get_device()

        # The source language has to be set BEFORE tokenizing: the Hugging
        # Face NLLB tokenizers add the source-language tag while encoding,
        # so assigning it afterwards leaves the previous tag in the input.
        # NOTE(review): this mutates the tokenizer returned by the model
        # manager; confirm whether that object is shared between callers.
        tokenizer.src_lang = src_code

        encoded = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        logger.info("Generating translation...")
        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=self._target_lang_token_id(tokenizer, tgt_code),
                max_new_tokens=512,
                num_beams=5,
                early_stopping=True,
            )

        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    @staticmethod
    def _target_lang_token_id(tokenizer: Any, tgt_code: str) -> int:
        """Return the token id that forces generation into ``tgt_code``.

        ``get_lang_id`` is used when the tokenizer provides it; otherwise
        the language code is looked up as a regular vocabulary token with
        ``convert_tokens_to_ids``.

        Raises:
            ValueError: If the code resolves to no id or to the tokenizer's
                unknown-token id.
        """
        get_lang_id = getattr(tokenizer, "get_lang_id", None)
        if callable(get_lang_id):
            token_id = get_lang_id(tgt_code)
        else:
            token_id = tokenizer.convert_tokens_to_ids(tgt_code)

        unk_id = getattr(tokenizer, "unk_token_id", None)
        if token_id is None or (unk_id is not None and token_id == unk_id):
            raise ValueError(
                f"Target language code '{tgt_code}' is not known to the tokenizer"
            )
        return token_id

    def _get_nllb_code(self, language: str) -> str:
        """Convert a language name or code to an NLLB-200 code.

        Names are matched case-insensitively against ``LANGUAGE_CODES``
        after stripping surrounding whitespace. A value containing an
        underscore is treated as a code: known codes are returned in their
        canonical casing and unknown ones are returned as given, because
        lower-casing would produce a different string than the caller's
        code. Anything else falls back to English with a warning.
        """
        cleaned = language.strip()
        key = cleaned.lower()

        if key in self.LANGUAGE_CODES:
            return self.LANGUAGE_CODES[key]

        if "_" in cleaned:
            return self._CANONICAL_CODES.get(key, cleaned)

        logger.warning("Language '%s' not found, using English", language)
        return self.LANGUAGE_CODES["english"]

    @staticmethod
    def get_supported_languages() -> List[str]:
        """Return the supported language names (keys of ``LANGUAGE_CODES``)."""
        return list(Translator.LANGUAGE_CODES)
|