File size: 5,317 Bytes
fad5c32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""

Translator - Text translation using NLLB-200 Distilled

Handles multilingual text translation with batch processing

"""

import logging
import torch
from typing import List, Dict, Any
from src.models.model_manager import ModelManager

# Module-level logger named after this module; Translator uses it for
# progress messages and for the unknown-language warning.
logger = logging.getLogger(__name__)


class Translator:
    """Translate text between languages with an NLLB-200 model.

    The model, tokenizer and device are obtained from ``ModelManager``.
    Languages may be given either as one of the friendly names in
    ``LANGUAGE_CODES`` (case-insensitive) or directly as an NLLB-200 code
    such as ``"eng_Latn"``.
    """

    # Friendly language name (lowercase) -> NLLB-200 language code.
    LANGUAGE_CODES = {
        "english": "eng_Latn",
        "hindi": "hin_Deva",
        "bengali": "ben_Beng",
        "tamil": "tam_Taml",
        "telugu": "tel_Telu",
        "marathi": "mar_Deva",
        "gujarati": "guj_Gujr",
        "kannada": "kan_Knda",
        "malayalam": "mal_Mlym",
        "punjabi": "pan_Guru",
        "urdu": "urd_Arab",
        "odia": "ory_Orya",
        "assamese": "asm_Beng",
        "nepali": "npi_Deva",
        "sinhala": "sin_Sinh",
        "arabic": "arb_Arab",
        "french": "fra_Latn",
        "spanish": "spa_Latn",
        "german": "deu_Latn",
        "portuguese": "por_Latn",
        "russian": "rus_Cyrl",
        "chinese": "zho_Hans",
        "japanese": "jpn_Jpan",
        "korean": "kor_Hang",
    }

    def __init__(self):
        self.model_manager = ModelManager()

    def translate(
        self,
        text: str,
        source_language: str,
        target_language: str
    ) -> Dict[str, Any]:
        """Translate a single text from the source to the target language.

        Args:
            text: Text to translate. Blank text yields an empty translation
                without invoking the model.
            source_language: Source language name or NLLB code.
            target_language: Target language name or NLLB code.

        Returns:
            Dict with the keys ``translated_text``, ``source_language`` and
            ``target_language`` (the resolved NLLB codes), and
            ``source_language_name`` / ``target_language_name`` (the values
            passed in by the caller).

        Raises:
            ValueError: If the tokenizer does not know the target code.
        """
        logger.info("Translating from %s to %s", source_language, target_language)

        src_code = self._get_nllb_code(source_language)
        tgt_code = self._get_nllb_code(target_language)
        logger.info("Using NLLB codes: %s -> %s", src_code, tgt_code)

        translated_text = self._translate_texts([text], src_code, tgt_code)[0]
        logger.info("Translation complete")

        return {
            "translated_text": translated_text,
            "source_language": src_code,
            "target_language": tgt_code,
            "source_language_name": source_language,
            "target_language_name": target_language,
        }

    def batch_translate(
        self,
        texts: List[str],
        source_language: str,
        target_language: str,
        batch_size: int = 4
    ) -> List[str]:
        """Translate multiple texts, processing ``batch_size`` texts per model call.

        Args:
            texts: List of texts to translate.
            source_language: Source language name or NLLB code.
            target_language: Target language name or NLLB code.
            batch_size: Number of texts tokenized and generated together.
                Must be at least 1.

        Returns:
            List of translated texts, in the same order as ``texts``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1, or if the
                tokenizer does not know the target code.
        """
        # Validate up front: 0 would make range() fail with an opaque message
        # and a negative step would silently translate nothing.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        logger.info("Batch translating %d texts", len(texts))

        # Resolve the codes once instead of once per text.
        src_code = self._get_nllb_code(source_language)
        tgt_code = self._get_nllb_code(target_language)

        results: List[str] = []
        for start in range(0, len(texts), batch_size):
            logger.info("Processing batch %d", start // batch_size + 1)
            chunk = texts[start:start + batch_size]
            results.extend(self._translate_texts(chunk, src_code, tgt_code))

        return results

    def _translate_texts(
        self,
        texts: List[str],
        src_code: str,
        tgt_code: str
    ) -> List[str]:
        """Run one tokenizer + generate pass over ``texts`` and decode the output."""
        translations = [""] * len(texts)

        # Blank inputs carry nothing to translate; keep them as "" rather
        # than sending them through the model.
        pending = [
            (position, text)
            for position, text in enumerate(texts)
            if text and text.strip()
        ]
        if not pending:
            return translations

        model, tokenizer = self.model_manager.get_nllb_model()
        device = self.model_manager.get_device()

        # The source language must be set BEFORE tokenizing: the tokenizer
        # adds the source-language tag while encoding, so setting it
        # afterwards has no effect on this call.
        tokenizer.src_lang = src_code

        # Resolve the target-language token through the generic vocabulary
        # lookup. An unknown code maps to the unknown-token id, which would
        # otherwise be forced as the first generated token.
        target_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
        if target_token_id is None or target_token_id == tokenizer.unk_token_id:
            raise ValueError(f"Unsupported target language code: {tgt_code!r}")

        inputs = tokenizer(
            [text for _, text in pending],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        logger.info("Generating translation...")
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=target_token_id,
                max_new_tokens=512,
                num_beams=5,
                early_stopping=True,
            )

        decoded = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True
        )
        for (position, _), translated in zip(pending, decoded):
            translations[position] = translated

        return translations

    def _get_nllb_code(self, language: str) -> str:
        """Convert a language name or code to an NLLB-200 code.

        Known names are matched case-insensitively. A value containing an
        underscore is treated as an NLLB code: codes listed in
        ``LANGUAGE_CODES`` are returned in their canonical casing, any other
        code is returned as given (NLLB codes are case-sensitive, so it must
        not be lowercased). Anything else falls back to English with a
        warning.
        """
        candidate = language.strip()
        lowered = candidate.lower()

        if lowered in self.LANGUAGE_CODES:
            return self.LANGUAGE_CODES[lowered]

        if "_" in candidate:
            for code in self.LANGUAGE_CODES.values():
                if code.lower() == lowered:
                    return code
            return candidate

        logger.warning("Language '%s' not found, using English", language)
        return self.LANGUAGE_CODES["english"]

    @staticmethod
    def get_supported_languages() -> List[str]:
        """Get the list of supported language names."""
        return list(Translator.LANGUAGE_CODES)