File size: 20,050 Bytes
67f25fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
"""
Translation service using IndicTrans2 by AI4Bharat
Handles language detection and translation between Indian languages
"""

import asyncio
import logging
from typing import Dict, List, Optional, Any
import torch
try:
    import fasttext
    FASTTEXT_AVAILABLE = True
except ImportError:
    FASTTEXT_AVAILABLE = False
    fasttext = None
import os
import requests
from dotenv import load_dotenv
from models import SUPPORTED_LANGUAGES

# Load environment variables once at import time so os.getenv() calls in
# TranslationService.__init__ see values from a local .env file.
# (Previously load_dotenv() was called twice back-to-back; the second call
# was a no-op duplicate and has been removed.)
load_dotenv()

logger = logging.getLogger(__name__)

# --- Model Configuration ---
# FastText language-identification model (176 languages) published by Meta;
# downloaded next to this file on first use.
FASTTEXT_MODEL_URL = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
FASTTEXT_MODEL_PATH = os.path.join(os.path.dirname(__file__), "lid.176.bin")


class TranslationService:
    """Service for handling language detection and translation using IndicTrans2"""
    
    def __init__(self):
        """Initialise lazy state; no models are loaded until load_models()."""
        # Handles for both translation directions plus the detector.
        # They stay None until load_models() populates them (or sets "mock").
        self.en_indic_model = None
        self.en_indic_tokenizer = None
        self.indic_en_model = None
        self.indic_en_tokenizer = None
        self.language_detector = None

        # Use the GPU only when both the hardware and the DEVICE env var allow it.
        if torch.cuda.is_available() and os.getenv("DEVICE", "cuda") == "cuda":
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.model_dir = os.getenv("MODEL_PATH", "models/indictrans2")
        self.model_loaded = False
        # "indictrans2" selects the real models; anything else means mock mode.
        self.model_type = os.getenv("MODEL_TYPE", "mock")

        # Probe for transformers up front so load_models() can decide between
        # the real and mock implementations without re-importing.
        self.transformers_available = False
        try:
            import transformers
            self.transformers_available = True
        except ImportError:
            logger.warning("Transformers not available, will use mock mode")

        # ISO 639-1 code -> Flores-200 code as used by IndicTrans2.
        self.lang_code_map = {
            "en": "eng_Latn",
            "hi": "hin_Deva",
            "bn": "ben_Beng",
            "gu": "guj_Gujr",
            "kn": "kan_Knda",
            "ml": "mal_Mlym",
            "mr": "mar_Deva",
            "or": "ory_Orya",
            "pa": "pan_Guru",
            "ta": "tam_Taml",
            "te": "tel_Telu",
            "ur": "urd_Arab",
            "as": "asm_Beng",
            "ne": "npi_Deva",
            "sa": "san_Deva"
        }

        # Human-readable language name -> ISO code (accepted by translate()).
        self.lang_name_to_code = {
            "English": "en",
            "Hindi": "hi",
            "Bengali": "bn",
            "Gujarati": "gu",
            "Kannada": "kn",
            "Malayalam": "ml",
            "Marathi": "mr",
            "Odia": "or",
            "Punjabi": "pa",
            "Tamil": "ta",
            "Telugu": "te",
            "Urdu": "ur",
            "Assamese": "as",
            "Nepali": "ne",
            "Sanskrit": "sa"
        }

        # Flores code -> ISO code, for mapping model output back in responses.
        self.reverse_lang_map = {flores: iso for iso, flores in self.lang_code_map.items()}
    
    async def load_models(self):
        """Idempotently load the detector and translation models.

        Honours MODEL_TYPE: real IndicTrans2 models only when requested and
        transformers is importable; any load failure degrades to mock mode.
        """
        if self.model_loaded:
            return

        logger.info(f"Starting model loading process (Mode: {self.model_type}, Device: {self.device})...")

        # Anything other than a fully-equipped "indictrans2" setup runs mocked.
        if self.model_type != "indictrans2" or not self.transformers_available:
            self._use_mock_implementation()
            return

        try:
            await self._load_language_detector()
            await self._load_indictrans2_model()
        except Exception as e:
            logger.error(f"❌ Failed to load real models: {str(e)}")
            logger.warning("Falling back to mock implementation.")
            self._use_mock_implementation()
        else:
            self.model_loaded = True
            logger.info("✅ Real IndicTrans2 models loaded successfully!")
            
    def _use_mock_implementation(self):
        """Point every model handle at the "mock" sentinel and mark loading done."""
        logger.info("Using mock implementation for development.")
        mock_handles = (
            "language_detector",
            "en_indic_model",
            "en_indic_tokenizer",
            "indic_en_model",
            "indic_en_tokenizer",
        )
        for handle in mock_handles:
            setattr(self, handle, "mock")
        self.model_loaded = True

    async def _download_fasttext_model(self):
        """Download the FastText LID model to FASTTEXT_MODEL_PATH if absent.

        Streams the file to a temporary ".part" path and renames it into
        place, so an interrupted download can never leave a truncated model
        that a later run would mistake for a complete one. The request is
        bounded by a timeout so a stalled server cannot hang the service
        forever (the original call had no timeout at all).

        Raises:
            Exception: re-raised after logging when the download fails.
        """
        if os.path.exists(FASTTEXT_MODEL_PATH):
            return
        logger.info(f"Downloading FastText language detection model from {FASTTEXT_MODEL_URL}...")
        tmp_path = FASTTEXT_MODEL_PATH + ".part"
        try:
            # NOTE(review): requests is blocking, so this holds the event loop
            # for the whole download; consider asyncio.to_thread if this can
            # run while serving traffic.
            response = requests.get(FASTTEXT_MODEL_URL, stream=True, timeout=(10, 60))
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
            os.replace(tmp_path, FASTTEXT_MODEL_PATH)  # atomic on POSIX
            logger.info(f"✅ FastText model downloaded to {FASTTEXT_MODEL_PATH}")
        except Exception as e:
            # Remove any partial file so the next attempt starts fresh.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            logger.error(f"❌ Failed to download FastText model: {e}")
            raise
    
    async def _load_language_detector(self):
        """Set up language detection: FastText when importable, else rules."""
        # Without the fasttext package there is nothing to download or load.
        if not FASTTEXT_AVAILABLE:
            logger.warning("FastText not available, falling back to rule-based detection")
            self.language_detector = "rule_based"
            return

        await self._download_fasttext_model()
        try:
            logger.info("Loading FastText language detection model...")
            self.language_detector = fasttext.load_model(FASTTEXT_MODEL_PATH)
        except Exception as e:
            logger.error(f"❌ Failed to load FastText model: {str(e)}")
            logger.warning("Falling back to rule-based detection")
            self.language_detector = "rule_based"
        else:
            logger.info("✅ FastText model loaded.")

    def _load_hf_seq2seq(self, repo_id: str):
        """Load one (tokenizer, model) pair from the Hugging Face hub.

        The model is moved to self.device in eval mode, using fp16 on CUDA
        to halve memory and fp32 on CPU where fp16 kernels are limited.
        """
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            repo_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        )
        model.to(self.device)
        model.eval()
        return tokenizer, model

    async def _load_indictrans2_model(self):
        """Load IndicTrans2 translation models using Hugging Face transformers.

        Loads both directions (EN→Indic and Indic→EN) via the shared
        _load_hf_seq2seq helper; the two loads were previously copy-pasted.

        Raises:
            Exception: re-raised after logging when either model fails to load.
        """
        try:
            import warnings
            warnings.filterwarnings("ignore", category=UserWarning)

            # NOTE(review): self.model_dir is only logged here — models are
            # actually fetched from the Hugging Face hub, not MODEL_PATH.
            logger.info(f"Loading IndicTrans2 models from: {self.model_dir}...")

            logger.info("Loading EN→Indic model from Hugging Face...")
            try:
                self.en_indic_tokenizer, self.en_indic_model = self._load_hf_seq2seq(
                    "ai4bharat/indictrans2-en-indic-1B"
                )
                logger.info("✅ EN→Indic model loaded successfully")
            except Exception as e:
                logger.error(f"❌ Failed to load EN→Indic model: {e}")
                raise

            logger.info("Loading Indic→EN model from Hugging Face...")
            try:
                self.indic_en_tokenizer, self.indic_en_model = self._load_hf_seq2seq(
                    "ai4bharat/indictrans2-indic-en-1B"
                )
                logger.info("✅ Indic→EN model loaded successfully")
            except Exception as e:
                logger.error(f"❌ Failed to load Indic→EN model: {e}")
                raise

            logger.info("✅ IndicTrans2 models loaded successfully.")
        except Exception as e:
            logger.error(f"❌ Failed to load IndicTrans2 models: {str(e)}")
            logger.error("Make sure you have:")
            logger.error("1. Downloaded the IndicTrans2 model files")
            logger.error("2. Set the correct MODEL_PATH in .env")
            logger.error("3. Installed all required dependencies")
            raise
    
    async def detect_language(self, text: str) -> Dict[str, Any]:
        """Identify the language of *text*.

        Returns a dict with "language" (ISO code), "confidence" and
        "language_name". Uses FastText when loaded; otherwise (or on any
        FastText failure) falls back to rule-based detection.
        """
        await self.load_models()

        def _result(code: str, confidence: float) -> Dict[str, Any]:
            # Single place that shapes the response payload.
            return {
                "language": code,
                "confidence": confidence,
                "language_name": SUPPORTED_LANGUAGES.get(code, code)
            }

        rules_only = (
            self.model_type == "mock"
            or not FASTTEXT_AVAILABLE
            or self.language_detector == "rule_based"
        )
        if rules_only:
            return _result(self._rule_based_language_detection(text), 0.85)

        try:
            # FastText labels look like "__label__hi"; newlines upset predict().
            labels, scores = self.language_detector.predict(text.replace('\n', ' '), k=1)
            raw_code = labels[0].replace('__label__', '')
            confidence = float(scores[0])

            # Clamp anything outside our supported set to English.
            supported = {
                'hi', 'bn', 'gu', 'kn', 'ml', 'mr', 'or', 'pa',
                'ta', 'te', 'ur', 'as', 'ne', 'sa', 'en'
            }
            code = raw_code if raw_code in supported else 'en'
            return _result(code, confidence)

        except Exception as e:
            logger.error(f"Language detection failed: {str(e)}")
            # Lower confidence reflects the degraded rule-based fallback.
            return _result(self._rule_based_language_detection(text), 0.50)
    
    def _rule_based_language_detection(self, text: str) -> str:
        """Simple rule-based language detection as fallback"""
        text_lower = text.lower()
        
        # Check for English indicators
        english_words = ['the', 'and', 'is', 'in', 'to', 'of', 'for', 'with', 'on', 'at']
        if any(word in text_lower for word in english_words):
            return 'en'
        
        # Check for Hindi indicators (Devanagari script)
        if any('\u0900' <= char <= '\u097F' for char in text):
            return 'hi'
        
        # Check for Bengali indicators
        if any('\u0980' <= char <= '\u09FF' for char in text):
            return 'bn'
        
        # Check for Tamil indicators
        if any('\u0B80' <= char <= '\u0BFF' for char in text):
            return 'ta'
        
        # Check for Telugu indicators
        if any('\u0C00' <= char <= '\u0C7F' for char in text):
            return 'te'
        
        # Default to English
        return 'en'
    
    async def translate(self, text: str, source_lang: str, target_lang: str) -> Dict[str, Any]:
        """
        Translate text from source language to target language using IndicTrans2

        Args:
            text: Text to translate.
            source_lang: ISO code ("en", "hi", ...) or language name ("English", ...).
            target_lang: ISO code or language name.

        Returns:
            Dict with translated_text, source_language, target_language,
            model and confidence. On invalid languages or any runtime
            failure this returns the mock translation instead of raising.
        """
        await self.load_models()
        
        # Mock mode (or models that never loaded) short-circuits to canned output.
        if self.model_type == "mock" or self.en_indic_model == "mock":
            return self._mock_translate(text, source_lang, target_lang)
        
        try:
            # Validate language codes first
            # Accept either ISO codes or full language names at this point.
            valid_codes = set(self.lang_code_map.keys()) | set(self.lang_name_to_code.keys())
            
            if source_lang not in valid_codes:
                logger.error(f"Invalid source language: {source_lang}")
                return self._mock_translate(text, source_lang, target_lang)
                
            if target_lang not in valid_codes:
                logger.error(f"Invalid target language: {target_lang}")
                return self._mock_translate(text, source_lang, target_lang)
            
            # Convert language names to codes if needed
            src_lang_code = self.lang_name_to_code.get(source_lang, source_lang)
            tgt_lang_code = self.lang_name_to_code.get(target_lang, target_lang)
            
            # Validate converted codes
            # NOTE(review): largely redundant with the check above — after name
            # conversion only valid ISO codes should remain; kept as a safety net.
            if src_lang_code not in self.lang_code_map:
                logger.error(f"Invalid source language code after conversion: {src_lang_code}")
                return self._mock_translate(text, source_lang, target_lang)
                
            if tgt_lang_code not in self.lang_code_map:
                logger.error(f"Invalid target language code after conversion: {tgt_lang_code}")
                return self._mock_translate(text, source_lang, target_lang)
            
            logger.info(f"Converting {source_lang} -> {src_lang_code}, {target_lang} -> {tgt_lang_code}")
            
            # Map language codes to IndicTrans2 format
            # NOTE(review): src_code/tgt_code are only used in log messages below;
            # the tokenizer call never receives them. Confirm the remote-code
            # tokenizer infers the direction, otherwise the target language tag
            # may be missing from the model input.
            src_code = self.lang_code_map.get(src_lang_code, src_lang_code)
            tgt_code = self.lang_code_map.get(tgt_lang_code, tgt_lang_code)
            
            logger.info(f"Using IndicTrans2 codes: {src_code} -> {tgt_code}")
            
            # Choose the right model and tokenizer based on direction
            if src_lang_code == "en" and tgt_lang_code != "en":
                # English to Indic
                model = self.en_indic_model
                tokenizer = self.en_indic_tokenizer
                # Use the correct IndicTrans2 format: just the text without language prefixes
                input_text = text.strip()
                logger.info(f"EN->Indic translation: '{input_text}' using {src_code}->{tgt_code}")
            elif src_lang_code != "en" and tgt_lang_code == "en":
                # Indic to English
                model = self.indic_en_model
                tokenizer = self.indic_en_tokenizer
                # Use the correct IndicTrans2 format: just the text without language prefixes
                input_text = text.strip()
                logger.info(f"Indic->EN translation: '{input_text}' using {src_code}->{tgt_code}")
            else:
                # For Indic to Indic, use English as pivot (not ideal but works)
                if src_lang_code != "en":
                    # First translate to English
                    intermediate_result = await self.translate(text, src_lang_code, "en")
                    intermediate_text = intermediate_result["translated_text"]
                    # Then translate from English to target
                    # (recursive call; at most two hops since both legs involve English)
                    return await self.translate(intermediate_text, "en", tgt_lang_code)
                else:
                    # Same language, return as is
                    return {
                        "translated_text": text,
                        "source_language": source_lang,
                        "target_language": target_lang,
                        "model": "IndicTrans2 (No translation needed)",
                        "confidence": 1.0
                    }
            
            # Tokenize and translate with basic format
            try:
                inputs = tokenizer(
                    input_text, 
                    return_tensors="pt", 
                    padding=True, 
                    truncation=True, 
                    max_length=512
                )
                # Move every tensor to the same device as the model.
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                
                # Deterministic beam search; no gradients needed for inference.
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs, 
                        max_length=512, 
                        num_beams=5, 
                        do_sample=False
                    )
            except Exception as tokenizer_error:
                logger.error(f"Tokenization/Generation error: {str(tokenizer_error)}")
                return self._mock_translate(text, source_lang, target_lang)
            
            translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # NOTE(review): confidence is a fixed heuristic value, not a model score.
            return {
                "translated_text": translated_text,
                "source_language": source_lang,
                "target_language": target_lang,
                "model": "IndicTrans2",
                "confidence": 0.92
            }
            
        except Exception as e:
            logger.error(f"Translation failed: {str(e)}")
            # Fallback to mock translation
            return self._mock_translate(text, source_lang, target_lang)
    
    def _mock_translate(self, text: str, source_lang: str, target_lang: str) -> Dict[str, Any]:
        """Mock translation for development and fallback"""
        mock_translations = {
            ("en", "hi"): "नमस्ते, यह एक परीक्षण अनुवाद है।",
            ("hi", "en"): "Hello, this is a test translation.",
            ("en", "bn"): "হ্যালো, এটি একটি পরীক্ষা অনুবাদ।",
            ("bn", "en"): "Hello, this is a test translation.",
            ("en", "ta"): "வணக்கம், இது ஒரு சோதனை மொழிபெயர்ப்பு.",
            ("ta", "en"): "Hello, this is a test translation."
        }
        
        translated_text = mock_translations.get(
            (source_lang, target_lang), 
            f"[MOCK] Translated from {source_lang} to {target_lang}: {text}"
        )
        
        return {
            "translated_text": translated_text,
            "source_language": source_lang,
            "target_language": target_lang,
            "model": "Mock (Development)",
            "confidence": 0.75
        }

    async def batch_translate(self, texts: List[str], source_lang: str, target_lang: str) -> List[Dict[str, Any]]:
        """Translate several texts sequentially, tagging each result with its input.

        Falls back to mock translations when the models are mocked or when
        any unexpected error interrupts the batch.
        """
        await self.load_models()

        if self.model_type == "mock" or self.en_indic_model == "mock":
            return [self._mock_translate(t, source_lang, target_lang) for t in texts]

        try:
            translated = []
            # Sequential awaits: each item reuses the already-loaded models.
            for original in texts:
                item = await self.translate(original, source_lang, target_lang)
                item["original_text"] = original
                translated.append(item)
            return translated

        except Exception as e:
            logger.error(f"Batch translation failed: {str(e)}")
            # Fallback to individual mock translations
            return [self._mock_translate(t, source_lang, target_lang) for t in texts]

    def get_supported_languages(self) -> Dict[str, str]:
        """Get supported languages mapping"""
        # Returns the shared module-level dict imported from models
        # (same object every call); callers should treat it as read-only.
        return SUPPORTED_LANGUAGES
    
    def get_language_codes(self) -> List[str]:
        """Get list of supported language codes"""
        return list(self.lang_code_map.keys())
    
    def validate_language_code(self, lang_code: str) -> bool:
        """Validate if a language code is supported"""
        valid_codes = set(self.lang_code_map.keys()) | set(self.lang_name_to_code.keys())
        return lang_code in valid_codes

    def is_translation_supported(self, source_lang: str, target_lang: str) -> bool:
        """A pair is translatable only when both ends are supported languages."""
        return all(code in SUPPORTED_LANGUAGES for code in (source_lang, target_lang))

# Global service instance
# Module-level singleton: constructed once at import so all requests share
# the same (lazily loaded) models and tokenizers.
translation_service = TranslationService()

async def get_translation_service() -> TranslationService:
    """Dependency injection for FastAPI"""
    # Always hands back the module-level singleton above.
    return translation_service