File size: 8,671 Bytes
4d48d5a
fb3dfc3
d0d0352
 
 
 
 
4d48d5a
 
 
 
 
eb52047
4d48d5a
 
eb52047
4d48d5a
eb52047
4d48d5a
 
eb52047
4d48d5a
eb52047
4d48d5a
eb52047
 
 
fb3dfc3
 
 
eb52047
 
fb3dfc3
 
eb52047
 
fb3dfc3
 
 
 
4d48d5a
 
 
 
 
 
 
 
 
eb52047
 
 
 
0bf2d2c
 
 
 
 
 
 
 
 
 
 
 
 
eb52047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d48d5a
 
eb52047
 
 
 
 
 
 
 
 
fb3dfc3
eb52047
4d48d5a
eb52047
 
4d48d5a
 
eb52047
4d48d5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb3dfc3
 
4d48d5a
eb52047
 
4d48d5a
eb52047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d48d5a
eb52047
4d48d5a
 
 
 
 
eb52047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d48d5a
eb52047
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import logging
import os
import re
from typing import Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

class TranslationModel:
    """
    Lazy-loading translation manager optimized for CPU inference.

    Prefers per-language-pair Helsinki-NLP OPUS-MT models (small, fast,
    often higher quality for supported pairs) and falls back to the
    distilled NLLB-200 model for pairs with no direct OPUS-MT model.
    All models are downloaded/loaded on first use and cached in memory.
    """

    def __init__(self, model_cache_dir: str = ".cache/models"):
        """
        Initialize the translation model manager.

        Args:
            model_cache_dir: Directory to cache downloaded models
        """
        self.model_cache_dir = model_cache_dir
        self.device = self._get_device()
        self.opus_mt_models = {}  # "<src>-<tgt>" key -> (model, tokenizer)
        self.fallback_model = None
        self.fallback_tokenizer = None
        self.initialized = False
        self.initialization_error = None

        try:
            # Creating the cache directory is the only eager work; models
            # themselves load on demand. Kept inside the try so a failure
            # (e.g. permissions) is recorded in initialization_error instead
            # of escaping the constructor — that attribute exists for
            # exactly this purpose.
            os.makedirs(model_cache_dir, exist_ok=True)
            logger.info("TranslationModel initialized - models will be loaded on demand")
            self.initialized = True
        except Exception as e:
            self.initialization_error = str(e)
            logger.error("Failed to initialize translation model: %s", e)

    def _get_device(self):
        """Return the best available torch device (CUDA if present, else CPU)."""
        if torch.cuda.is_available():
            logger.info("Using CUDA GPU for translation")
            return torch.device("cuda")
        logger.info("Using CPU for translation")
        return torch.device("cpu")

    def _get_opus_mt_model_name(self, source_lang_code: str, target_lang_code: str) -> Optional[str]:
        """Build the Hugging Face repo name for the direct OPUS-MT pair model.

        Note: the original identity mapping table (every code mapped to
        itself, with the code itself as the .get default) was dead code and
        has been removed; behavior is unchanged.
        """
        return f"Helsinki-NLP/opus-mt-{source_lang_code}-{target_lang_code}"

    def _load_opus_mt_model(self, source_lang_code: str, target_lang_code: str):
        """Load (or fetch from cache) the OPUS-MT model for a language pair.

        Returns:
            (model, tokenizer) tuple on success, or None if no model exists
            for the pair / loading failed — callers treat None as "use the
            fallback model".
        """
        model_name = self._get_opus_mt_model_name(source_lang_code, target_lang_code)

        # Reuse an already-loaded model for this pair.
        key = f"{source_lang_code}-{target_lang_code}"
        if key in self.opus_mt_models:
            return self.opus_mt_models[key]

        try:
            logger.info("Loading OPUS-MT model: %s", model_name)

            # fp16 only on GPU; CPU inference stays in float32 (half
            # precision is unsupported/slow for most ops on CPU).
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                cache_dir=self.model_cache_dir,
                low_cpu_mem_usage=True
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)

            model.to(self.device)
            logger.info("OPUS-MT model loaded successfully: %s", model_name)

            # Cache for subsequent translations of the same pair.
            self.opus_mt_models[key] = (model, tokenizer)
            return model, tokenizer

        except Exception as e:
            # Expected for unsupported pairs (no such repo) — warn and let
            # the caller fall back to NLLB.
            logger.warning("Could not load OPUS-MT model %s: %s", model_name, e)
            return None

    def _load_fallback_model(self):
        """Load the fallback NLLB-200 model for language pairs without OPUS-MT models.

        Idempotent: returns immediately if the model is already loaded.

        Raises:
            Exception: re-raised from transformers if the download/load fails.
        """
        if self.fallback_model is not None:
            return

        try:
            # Small distilled variant for efficiency on CPU.
            model_name = "facebook/nllb-200-distilled-600M"
            logger.info("Loading fallback model: %s", model_name)

            self.fallback_model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                cache_dir=self.model_cache_dir,
                low_cpu_mem_usage=True
            )
            self.fallback_tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)

            self.fallback_model.to(self.device)
            logger.info("Fallback model loaded successfully: %s", model_name)

        except Exception as e:
            logger.error("Error loading fallback model: %s", e)
            raise

    def translate(self, text: str, source_lang_code: str, target_lang_code: str) -> str:
        """
        Translate text from source language to target language.

        Args:
            text: Text to translate
            source_lang_code: Source language code
            target_lang_code: Target language code

        Returns:
            Translated text

        Raises:
            ValueError: if the manager failed to initialize.
            Exception: model loading/generation errors are logged and re-raised.
        """
        try:
            if not self.initialized:
                raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")

            # Nothing to translate — avoid feeding an empty batch to a model.
            if not text or not text.strip():
                return ""

            # Prefer the OPUS-MT pair model (faster and often better quality).
            opus_mt_result = self._load_opus_mt_model(source_lang_code, target_lang_code)

            if opus_mt_result:
                model, tokenizer = opus_mt_result

                inputs = tokenizer(text, return_tensors="pt", padding=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = model.generate(**inputs, max_length=512, num_beams=4, early_stopping=True)

                translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
                logger.info("Translation completed using OPUS-MT model")

            else:
                # Fall back to NLLB model
                logger.info(f"No OPUS-MT model available for {source_lang_code}-{target_lang_code}, using fallback model")
                self._load_fallback_model()

                tokenizer = self.fallback_tokenizer
                model = self.fallback_model

                inputs = tokenizer(text, return_tensors="pt", padding=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                # NLLB language codes are like "eng_Latn", "fra_Latn", etc.
                nllb_source = _get_nllb_code(source_lang_code)
                nllb_target = _get_nllb_code(target_lang_code)

                # Force the decoder to start with the target-language token.
                # transformers >= 4.34 removed `lang_code_to_id` from NLLB
                # tokenizers; language codes are ordinary special tokens in
                # the vocab, so convert_tokens_to_ids works on all versions.
                lang_map = getattr(tokenizer, "lang_code_to_id", None)
                if lang_map is not None:
                    forced_bos_token_id = lang_map[nllb_target]
                else:
                    forced_bos_token_id = tokenizer.convert_tokens_to_ids(nllb_target)

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        forced_bos_token_id=forced_bos_token_id,
                        max_length=512,
                        num_beams=4,
                        early_stopping=True
                    )

                translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
                logger.info("Translation completed using fallback NLLB model")

            # Collapse whitespace runs and trim the decoded output.
            return re.sub(r'\s+', ' ', translated_text).strip()

        except Exception as e:
            logger.error("Translation error: %s", e)
            raise

def _get_nllb_code(lang_code: str) -> str:
    """Convert ISO language code to NLLB language code format."""
    # Mapping for common languages
    nllb_mapping = {
        'en': 'eng_Latn',
        'fr': 'fra_Latn',
        'es': 'spa_Latn',
        'de': 'deu_Latn',
        'it': 'ita_Latn',
        'pt': 'por_Latn',
        'nl': 'nld_Latn',
        'ru': 'rus_Cyrl',
        'zh': 'zho_Hans',
        'ar': 'ara_Arab',
        'hi': 'hin_Deva',
        'ja': 'jpn_Jpan',
        'ko': 'kor_Hang',
    }
    
    return nllb_mapping.get(lang_code, f"{lang_code}_Latn")