Spaces:
Sleeping
Sleeping
| """ | |
| Translator - Text translation using NLLB-200 Distilled | |
| Handles multilingual text translation with batch processing | |
| """ | |
| import logging | |
| import torch | |
| from typing import List, Dict, Any | |
| from src.models.model_manager import ModelManager | |
# Per-module logger: records are named after this module's import path.
logger = logging.getLogger(name=__name__)
class Translator:
    """Translate text between languages with an NLLB-200 model.

    The model, tokenizer and device are obtained from ``ModelManager``.
    Languages may be given either as a name listed in ``LANGUAGE_CODES``
    (matched case-insensitively, surrounding whitespace ignored) or as a raw
    NLLB-200 code such as ``"fra_Latn"``; anything else falls back to English
    with a warning.
    """

    # Human-readable language name -> NLLB-200 code (``lang_Script``).
    LANGUAGE_CODES = {
        "english": "eng_Latn",
        "hindi": "hin_Deva",
        "bengali": "ben_Beng",
        "tamil": "tam_Taml",
        "telugu": "tel_Telu",
        "marathi": "mar_Deva",
        "gujarati": "guj_Gujr",
        "kannada": "kan_Knda",
        "malayalam": "mal_Mlym",
        "punjabi": "pan_Guru",
        "urdu": "urd_Arab",
        "odia": "ory_Orya",
        "assamese": "asm_Beng",
        "nepali": "npi_Deva",
        "sinhala": "sin_Sinh",
        "arabic": "arb_Arab",
        "french": "fra_Latn",
        "spanish": "spa_Latn",
        "german": "deu_Latn",
        "portuguese": "por_Latn",
        "russian": "rus_Cyrl",
        "chinese": "zho_Hans",
        "japanese": "jpn_Jpan",
        "korean": "kor_Hang",
    }

    def __init__(self):
        # Model, tokenizer and device lookup are delegated to ModelManager.
        self.model_manager = ModelManager()

    def translate(
        self,
        text: str,
        source_language: str,
        target_language: str
    ) -> Dict[str, Any]:
        """Translate ``text`` from ``source_language`` to ``target_language``.

        Args:
            text: Text to translate. Input longer than 512 tokens is
                truncated. Blank/whitespace-only text yields ``""`` without
                running the model.
            source_language: Source language name or NLLB code.
            target_language: Target language name or NLLB code.

        Returns:
            Dict with keys ``translated_text``; ``source_language`` and
            ``target_language`` (the resolved NLLB codes); and
            ``source_language_name`` / ``target_language_name`` (the
            arguments exactly as passed in).

        Raises:
            ValueError: If a resolved NLLB code is not a token in the
                tokenizer's vocabulary (checked only when there is non-blank
                text to translate).
        """
        logger.info("Translating from %s to %s", source_language, target_language)

        src_code = self._get_nllb_code(source_language)
        tgt_code = self._get_nllb_code(target_language)
        logger.info("Using NLLB codes: %s -> %s", src_code, tgt_code)

        translated_text = self._translate_batch([text], src_code, tgt_code)[0]

        logger.info("Translation complete")
        return {
            "translated_text": translated_text,
            "source_language": src_code,
            "target_language": tgt_code,
            "source_language_name": source_language,
            "target_language_name": target_language,
        }

    def batch_translate(
        self,
        texts: List[str],
        source_language: str,
        target_language: str,
        batch_size: int = 4
    ) -> List[str]:
        """Translate many texts, running the model once per batch.

        Args:
            texts: Texts to translate.
            source_language: Source language name or NLLB code.
            target_language: Target language name or NLLB code.
            batch_size: Number of texts passed to a single ``generate`` call.

        Returns:
            Translated texts, one per input and in input order.

        Raises:
            ValueError: If ``batch_size`` is less than 1, or a resolved NLLB
                code is unknown to the tokenizer.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Materialize so len() and slicing work for any iterable input.
        texts = list(texts)
        if not texts:
            return []

        logger.info("Batch translating %d texts", len(texts))
        # Resolve the language codes once for the whole run.
        src_code = self._get_nllb_code(source_language)
        tgt_code = self._get_nllb_code(target_language)

        results: List[str] = []
        for start in range(0, len(texts), batch_size):
            logger.info("Processing batch %d", start // batch_size + 1)
            batch = texts[start:start + batch_size]
            results.extend(self._translate_batch(batch, src_code, tgt_code))
        return results

    def _translate_batch(
        self,
        texts: List[str],
        src_code: str,
        tgt_code: str
    ) -> List[str]:
        """Translate ``texts`` (with already-resolved NLLB codes) in one ``generate`` call.

        Returns one string per input, in input order; blank inputs map to "".
        """
        translations = [""] * len(texts)
        # Blank entries have nothing to translate. Remember the positions of
        # the rest so the outputs can be written back in input order.
        pending = [i for i, t in enumerate(texts) if t and t.strip()]
        if not pending:
            return translations

        model, tokenizer = self.model_manager.get_nllb_model()
        device = self.model_manager.get_device()

        # Validate both codes before touching the tokenizer state.
        self._lang_token_id(tokenizer, src_code)
        forced_bos_token_id = self._lang_token_id(tokenizer, tgt_code)

        # The source language must be set BEFORE tokenizing: the NLLB
        # tokenizer inserts the src_lang language token while encoding.
        # NOTE(review): this mutates a tokenizer that ModelManager may share
        # between callers; concurrent calls with different source languages
        # could interleave. Confirm whether that can happen here.
        tokenizer.src_lang = src_code
        inputs = tokenizer(
            [texts[i] for i in pending],
            return_tensors="pt",
            padding=True,  # batch entries can differ in length
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        logger.info("Generating translation...")
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                # Force the first decoded token to be the target-language tag.
                forced_bos_token_id=forced_bos_token_id,
                max_new_tokens=512,
                num_beams=5,
                early_stopping=True,
            )

        decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        for index, translation in zip(pending, decoded):
            translations[index] = translation
        return translations

    @staticmethod
    def _lang_token_id(tokenizer, code: str) -> int:
        """Return the vocabulary id of the language token ``code``.

        Raises:
            ValueError: If ``code`` maps to no token or to the unknown token.
        """
        token_id = tokenizer.convert_tokens_to_ids(code)
        if token_id is None or token_id == tokenizer.unk_token_id:
            raise ValueError(f"Unsupported NLLB language code: {code!r}")
        return token_id

    def _get_nllb_code(self, language: str) -> str:
        """Resolve a language name or NLLB-200 code to an NLLB-200 code.

        Names are matched case-insensitively (ignoring surrounding
        whitespace) against ``LANGUAGE_CODES``. Input containing an underscore
        is treated as an NLLB code. It is normalized to the mixed-case
        ``xxx_Yyyy`` form used in ``LANGUAGE_CODES`` (e.g. ``"eng_Latn"``)
        rather than lowercased. Anything else falls back to English with a
        warning.
        """
        key = language.strip().lower()

        # Known language name.
        if key in self.LANGUAGE_CODES:
            return self.LANGUAGE_CODES[key]

        # Looks like an explicit NLLB code: lowercase language part,
        # title-case script part.
        if "_" in key:
            lang, _, script = key.partition("_")
            return f"{lang}_{script.capitalize()}"

        logger.warning("Language '%s' not found, using English", language)
        return self.LANGUAGE_CODES["english"]
def get_supported_languages() -> List[str]:
    """Return the language names understood by ``Translator``.

    The names come back in the order they are declared in
    ``Translator.LANGUAGE_CODES``.
    """
    names = [*Translator.LANGUAGE_CODES]
    return names