File size: 7,864 Bytes
d413c3a
 
e86fdec
 
 
 
 
 
 
d413c3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e86fdec
 
 
d413c3a
 
 
 
02dd2f8
 
 
 
 
d413c3a
e86fdec
d413c3a
 
 
21682e0
 
 
e86fdec
02dd2f8
 
 
 
e86fdec
 
 
 
d413c3a
 
e86fdec
d413c3a
 
 
e86fdec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d413c3a
e86fdec
d413c3a
 
 
 
 
 
 
 
 
e86fdec
d413c3a
e86fdec
d413c3a
 
 
 
 
 
 
 
 
 
 
 
 
 
21682e0
d413c3a
e86fdec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Department 3 - Translator
Primary  : NLLB-200-distilled-1.3B (Meta)
Fallback : deep-translator (Google Translate)

✅ UPGRADED:
  - Text chunking for long transcripts (fixes repetition bug)
  - Splits by sentence, translates in 400-token chunks
  - Rejoins cleanly into full translation
"""

import time
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Maps the short language codes accepted by Translator.translate() to the
# language tags handed to the NLLB model as its source / target language.
NLLB_CODES = {
    "en": "eng_Latn",
    "te": "tel_Telu",
    "hi": "hin_Deva",
    "ta": "tam_Taml",
    "kn": "kan_Knda",
    "es": "spa_Latn",
    "fr": "fra_Latn",
    "de": "deu_Latn",
    "ja": "jpn_Jpan",
    "zh": "zho_Hans",
    "ar": "arb_Arab",
    "pt": "por_Latn",
    "ru": "rus_Cyrl",
}

# Hugging Face model id loaded for the primary (NLLB) translation path.
MODEL_ID      = "facebook/nllb-200-distilled-1.3B"
# Length cap passed as max_length to the tokenizer, the pipeline and generate().
MAX_LENGTH    = 512
# Word budget per chunk passed to Translator._split_into_chunks().
CHUNK_WORDS   = 80  # ~400 tokens, safe for NLLB


class Translator:
    """Translate text with NLLB-200 first and Google Translate as fallback.

    The NLLB model is loaded lazily, on the first ``translate()`` call that
    can use it. Input text is split into chunks of at most ``CHUNK_WORDS``
    words, each chunk is translated on its own, and the results are joined
    with single spaces.
    """

    # The Google backend of deep-translator lists Simplified Chinese as
    # "zh-CN", while this module uses the bare "zh" key (see NLLB_CODES).
    # NOTE(review): confirm against the installed deep-translator version.
    _GOOGLE_CODE_ALIASES = {"zh": "zh-CN"}

    def __init__(self):
        """Create an idle translator; no model is loaded here."""
        self._pipeline    = None   # transformers translation pipeline, if loaded
        self._tokenizer   = None   # tokenizer for the manual fallback path
        self._model       = None   # seq2seq model for the manual fallback path
        self._nllb_loaded = False  # True once a load has been attempted
        print("[Translator] Ready (NLLB loads on first use)")

    # ── Public ───────────────────────────────────────────────────────
    def translate(self, text: str, src_lang: str, tgt_lang: str):
        """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

        Args:
            text: Text to translate. Empty / whitespace-only text is skipped.
            src_lang: Short source language code (a key of ``NLLB_CODES``);
                other values such as ``"auto"`` are passed through to the
                Google fallback unchanged.
            tgt_lang: Short target language code.

        Returns:
            A ``(translated_text, engine_label)`` tuple. ``engine_label`` is
            one of the ``"skipped (...)"`` markers, an NLLB or Google label
            including the chunk count, or ``"error"`` when both engines fail.
        """
        if not text or not text.strip():
            return "", "skipped (empty)"
        if src_lang == tgt_lang:
            return text, "skipped (same language)"

        # Split long text into chunks
        chunks = self._split_into_chunks(text, CHUNK_WORDS)
        print(f"[Translator] Translating {len(chunks)} chunks ({len(text)} chars)")

        if tgt_lang in NLLB_CODES:
            # Load NLLB on first use; the flag is set even when loading
            # fails so a broken install is not retried on every call.
            if not self._nllb_loaded:
                self._init_nllb()
                self._nllb_loaded = True

            if self._pipeline is not None or self._model is not None:
                try:
                    return self._translate_chunks_nllb(chunks, src_lang, tgt_lang)
                except Exception as e:
                    logger.warning(
                        "[Translator] NLLB failed (%s), trying Google...", e)
        else:
            # Without an NLLB tag for the target there is nothing sensible
            # to ask the model for, so skip loading it altogether.
            logger.info(
                "[Translator] No NLLB code for target '%s', using Google",
                tgt_lang)

        return self._translate_chunks_google(chunks, src_lang, tgt_lang)

    # ── Chunking ─────────────────────────────────────────────────────
    def _split_into_chunks(self, text: str, max_words: int):
        """Split ``text`` into sentence-aware chunks of at most ``max_words`` words.

        Sentences (split after ``.``, ``!`` or ``?`` followed by whitespace)
        are packed greedily into chunks. A single sentence longer than
        ``max_words`` -- typical for unpunctuated transcripts -- is cut on
        word boundaries so that no chunk ever exceeds the budget.

        Args:
            text: Text to split.
            max_words: Maximum number of whitespace-separated words per
                chunk; values below 1 are treated as 1.

        Returns:
            A list of non-empty chunk strings; empty for blank input.
        """
        import re

        stripped = text.strip() if text else ""
        if not stripped:
            return []
        max_words = max(1, int(max_words))

        # Split by sentence endings
        sentences = re.split(r'(?<=[.!?])\s+', stripped)

        chunks  = []
        current = []
        count   = 0

        for sentence in sentences:
            words = sentence.split()
            if not words:
                continue

            if len(words) > max_words:
                # Oversized sentence: emit what we have, then cut it on
                # word boundaries into max_words-sized pieces.
                if current:
                    chunks.append(" ".join(current))
                pieces = [
                    " ".join(words[start:start + max_words])
                    for start in range(0, len(words), max_words)
                ]
                chunks.extend(pieces[:-1])
                # Keep the tail open so following sentences can share it.
                current = [pieces[-1]]
                count   = len(pieces[-1].split())
                continue

            if count + len(words) > max_words and current:
                chunks.append(" ".join(current))
                current = []
                count   = 0
            current.append(sentence)
            count += len(words)

        if current:
            chunks.append(" ".join(current))

        return chunks

    # ── NLLB chunked translation ──────────────────────────────────────
    def _translate_chunks_nllb(self, chunks, src_lang, tgt_lang):
        """Translate ``chunks`` with the loaded NLLB pipeline or model.

        A chunk that fails is kept untranslated so one bad chunk does not
        discard the rest of the text.

        Args:
            chunks: Chunk strings as produced by ``_split_into_chunks``.
            src_lang: Short source code; unknown codes are treated as English.
            tgt_lang: Short target code; must be a key of ``NLLB_CODES``.

        Returns:
            ``(translated_text, engine_label)``.

        Raises:
            ValueError: ``tgt_lang`` has no NLLB code.
            RuntimeError: every non-empty chunk failed, so the caller can
                fall back to another engine instead of returning the input.
        """
        if tgt_lang not in NLLB_CODES:
            raise ValueError(f"target language '{tgt_lang}' has no NLLB code")

        t0       = time.time()
        tgt_code = NLLB_CODES[tgt_lang]
        src_code = NLLB_CODES.get(src_lang)
        if src_code is None:
            logger.warning(
                "[Translator] No NLLB code for source '%s', assuming eng_Latn",
                src_lang)
            src_code = "eng_Latn"

        results    = []
        attempted  = 0
        failed     = 0
        last_error = None

        for i, chunk in enumerate(chunks):
            if not chunk.strip():
                continue
            attempted += 1
            try:
                if self._pipeline is not None:
                    results.append(
                        self._nllb_pipeline_chunk(chunk, src_code, tgt_code))
                else:
                    results.append(
                        self._nllb_manual_chunk(chunk, src_code, tgt_code))
            except Exception as e:
                failed += 1
                last_error = e
                logger.warning("[Translator] Chunk %d failed: %s", i + 1, e)
                results.append(chunk)  # fallback: keep original

        if attempted and failed == attempted:
            # Nothing was translated at all; surface that to translate()
            # rather than returning the source text labelled as NLLB output.
            raise RuntimeError(
                f"all {attempted} chunks failed, last error: {last_error}")

        translated = " ".join(results)
        elapsed    = time.time() - t0
        logger.info(
            "[Translator] NLLB done in %.2fs: %s->%s",
            elapsed, src_code, tgt_code)
        print(f"[Translator] βœ… Done in {elapsed:.2f}s ({len(chunks)} chunks)")
        return translated, f"NLLB-200-distilled-1.3B ({len(chunks)} chunks)"

    def _nllb_pipeline_chunk(self, chunk, src_code, tgt_code):
        """Translate one chunk through the transformers pipeline."""
        result = self._pipeline(
            chunk,
            src_lang=src_code,
            tgt_lang=tgt_code,
            max_length=MAX_LENGTH,
        )
        return result[0]["translation_text"]

    def _nllb_manual_chunk(self, chunk, src_code, tgt_code):
        """Translate one chunk with the manually loaded tokenizer and model."""
        import torch

        # The tokenizer tags the input with its src_lang attribute; set it
        # so the requested source language is actually used.
        self._tokenizer.src_lang = src_code
        inputs = self._tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        # Mirrors _init_nllb_manual, which moves the model to CUDA under
        # the same condition.
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        tgt_lang_id = self._tokenizer.convert_tokens_to_ids(tgt_code)
        with torch.no_grad():
            output_ids = self._model.generate(
                **inputs,
                forced_bos_token_id=tgt_lang_id,
                max_length=MAX_LENGTH,
                num_beams=4,
                early_stopping=True,
            )
        return self._tokenizer.batch_decode(
            output_ids, skip_special_tokens=True)[0]

    # ── Google chunked translation ────────────────────────────────────
    def _translate_chunks_google(self, chunks, src_lang, tgt_lang):
        """Translate ``chunks`` with deep-translator's Google backend.

        Args:
            chunks: Chunk strings as produced by ``_split_into_chunks``.
            src_lang: Short source code, or ``"auto"``.
            tgt_lang: Short target code.

        Returns:
            ``(translated_text, engine_label)``; on any failure the text is
            a bracketed error message and the label is ``"error"``.
        """
        t0 = time.time()
        try:
            from deep_translator import GoogleTranslator

            aliases = self._GOOGLE_CODE_ALIASES
            # One translator instance serves every chunk.
            translator = GoogleTranslator(
                source=aliases.get(src_lang, src_lang),
                target=aliases.get(tgt_lang, tgt_lang),
            )
            results = []
            for chunk in chunks:
                if not chunk.strip():
                    continue
                translated = translator.translate(chunk)
                # Guard the join below against a non-string result by
                # keeping the original chunk in that case.
                results.append(
                    translated if isinstance(translated, str) else chunk)
            full = " ".join(results)
            logger.info(
                "[Translator] Google done in %.2fs", time.time() - t0)
            return full, f"Google Translate ({len(chunks)} chunks)"
        except Exception as e:
            logger.error("[Translator] Google fallback failed: %s", e)
            return f"[Translation failed: {str(e)}]", "error"

    # ── NLLB init ────────────────────────────────────────────────────
    def _init_nllb(self):
        """Load NLLB as a transformers pipeline, else try the manual loader."""
        try:
            from transformers import pipeline as hf_pipeline
            self._pipeline = hf_pipeline(
                "translation",
                model=MODEL_ID,
                device_map="auto",
                max_length=MAX_LENGTH,
            )
            print(f"[Translator] βœ… {MODEL_ID} loaded")
        except Exception as e:
            logger.warning(
                "[Translator] Pipeline init failed: %s, trying manual...", e)
            self._init_nllb_manual()

    def _init_nllb_manual(self):
        """Load tokenizer and model directly; leave ``_model`` None on failure."""
        try:
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
            import torch
            self._tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            self._model = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_ID,
                # Half precision only when a CUDA device is available.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            if torch.cuda.is_available():
                self._model = self._model.cuda()
            self._model.eval()
            print(f"[Translator] βœ… {MODEL_ID} loaded manually")
        except Exception as e:
            logger.error("[Translator] NLLB manual load failed: %s", e)
            self._model = None