File size: 15,401 Bytes
2b83054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import json
import copy
import logging
import time
from typing import List, Dict, Optional, Any, Union
from itertools import chain
from tqdm import tqdm
from deep_translator import GoogleTranslator
from dotenv import load_dotenv

# Add this at the top of your translate.py file
# Load environment variables from .env file
load_dotenv()


# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Import Groq
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import LLMChain

# Language code constants - simplified for now
ISO_LANGUAGE_CODES = {
    "en": "english",
    "es": "spanish",
    "fr": "french",
    "de": "german",
    "it": "italian",
    "pt": "portuguese",
    "ru": "russian",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
    "hi": "hindi",
    "ar": "arabic",
}

def fix_language_code(language_code: Optional[str]) -> str:
    """Convert language code to format compatible with translator."""
    if not language_code:
        return "auto"
    
    # Clean up language code (remove region specifiers)
    language_code = language_code.lower().split('-')[0]
    
    # Return the cleaned code if it's in our list, otherwise default to auto
    return language_code if language_code in ISO_LANGUAGE_CODES else "auto"

def translate_iterative(segments: List[Dict[str, Any]], 
                        target_lang: str, 
                        source_lang: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Translate text segments individually to the specified language.

    Args:
        segments: List of dictionaries with 'text' key containing the text to translate
        target_lang: Target language code
        source_lang: Source language code (defaults to auto-detect)

    Returns:
        List of segments with translated text
    """
    segments_copy = copy.deepcopy(segments)
    source = fix_language_code(source_lang)
    target = fix_language_code(target_lang)
    
    logger.info(f"Translating {len(segments)} segments from {source} to {target} (iterative)")
    translator = GoogleTranslator(source=source, target=target)

    for i, segment in enumerate(tqdm(segments_copy, desc="Translating")):
        text = segment["text"].strip()
        try:
            translated_text = translator.translate(text)
            segments_copy[i]["text"] = translated_text
        except Exception as error:
            logger.error(f"Error translating segment {i}: {error}")
            # Keep original text if translation fails
            segments_copy[i]["text"] = text

    return segments_copy

def verify_translation(original_segments: List[Dict[str, Any]],
                      segments_copy: List[Dict[str, Any]],
                      translated_lines: List[str],
                      target_lang: str,
                      source_lang: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Verify translation integrity and assign translated text to segments.
    Falls back to iterative translation if segment counts don't match.
    """
    if len(original_segments) == len(translated_lines):
        for i in range(len(segments_copy)):
            segments_copy[i]["text"] = translated_lines[i].replace("\t", " ").replace("\n", " ").strip()
        return segments_copy
    else:
        logger.error(
            f"Translation failed: segment count mismatch. Original: {len(original_segments)}, "
            f"Translated: {len(translated_lines)}. Switching to iterative translation."
        )
        return translate_iterative(original_segments, target_lang, source_lang)

def translate_batch(segments: List[Dict[str, Any]], 
                   target_lang: str, 
                   chunk_size: int = 4000,
                   source_lang: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Translate a batch of text segments in chunks to respect API limits.

    Args:
        segments: List of dictionaries with 'text' key
        target_lang: Target language code
        chunk_size: Maximum character count per chunk (default: 4000)
        source_lang: Source language code (defaults to auto-detect)

    Returns:
        List of segments with translated text
    """
    segments_copy = copy.deepcopy(segments)
    source = fix_language_code(source_lang)
    target = fix_language_code(target_lang)
    
    logger.info(f"Translating {len(segments)} segments from {source} to {target} (batch)")

    # Extract text from segments
    text_lines = [segment["text"].strip() for segment in segments]
    
    # Create chunks respecting character limit
    text_chunks = []
    current_chunk = ""
    chunk_segments = []
    segment_tracking = []
    
    for line in text_lines:
        line = " " if not line else line
        if (len(current_chunk) + len(line) + 7) <= chunk_size:  # 7 for separator
            if current_chunk:
                current_chunk += " ||||| "
            current_chunk += line
            chunk_segments.append(line)
        else:
            text_chunks.append(current_chunk)
            segment_tracking.append(chunk_segments)
            current_chunk = line
            chunk_segments = [line]
    
    if current_chunk:
        text_chunks.append(current_chunk)
        segment_tracking.append(chunk_segments)

    # Translate chunks
    translator = GoogleTranslator(source=source, target=target)
    translated_segments = []
    progress_bar = tqdm(total=len(segments), desc="Translating")
    
    try:
        for chunk_text, chunk_segments in zip(text_chunks, segment_tracking):
            translated_chunk = translator.translate(chunk_text.strip())
            split_translations = translated_chunk.split("|||||")
            
            # Verify chunk integrity
            if len(split_translations) == len(chunk_segments):
                progress_bar.update(len(split_translations))
                translated_segments.extend([t.strip() for t in split_translations])
            else:
                logger.warning(
                    f"Chunk translation mismatch. Expected {len(chunk_segments)}, "
                    f"got {len(split_translations)}. Translating segment by segment."
                )
                for segment in chunk_segments:
                    translated_text = translator.translate(segment.strip())
                    translated_segments.append(translated_text.strip())
                    progress_bar.update(1)
        
        progress_bar.close()
        
        # Verify and return
        return verify_translation(segments, segments_copy, translated_segments, target_lang, source_lang)
    
    except Exception as error:
        progress_bar.close()
        logger.error(f"Batch translation failed: {error}")
        return translate_iterative(segments, target_lang, source_lang)

def translate_with_groq(segments: List[Dict[str, Any]],
                       target_lang: str,
                       model_name: str = "llama-3.3-70b-versatile",
                       source_lang: Optional[str] = None,
                       batch_size: int = 10) -> List[Dict[str, Any]]:
    """
    Translate text segments using Groq API.
    
    Args:
        segments: List of dictionaries with 'text' key
        target_lang: Target language code
        model_name: Groq model to use (default: "llama-3.3-70b-versatile")
        source_lang: Source language code (optional)
        batch_size: Number of segments to process in each API call
        
    Returns:
        List of segments with translated text
    """
    segments_copy = copy.deepcopy(segments)
    
    # Get language names instead of codes for clarity in prompting
    target_language = ISO_LANGUAGE_CODES.get(fix_language_code(target_lang), "the target language")
    source_language = "auto-detected language"
    if source_lang:
        source_language = ISO_LANGUAGE_CODES.get(fix_language_code(source_lang), "the source language")
    
    logger.info(f"Translating {len(segments)} segments from {source_language} to {target_language} using Groq")
    
    # Set up Groq LLM
    llm = ChatGroq(model_name=model_name, temperature=0.2)
    
    # Process segments in batches
    translated_segments = []
    total_batches = (len(segments) + batch_size - 1) // batch_size
    
    for batch_idx in tqdm(range(total_batches), desc="Translating batches"):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(segments))
        batch = segments[start_idx:end_idx]
        
        # Extract text from segments
        batch_texts = [segment["text"].strip() for segment in batch]
        
        # Create numbered text array for the prompt
        numbered_texts = [f"{i+1}. {text}" for i, text in enumerate(batch_texts)]
        batch_content = "\n".join(numbered_texts)
        
        # Create a prompt template for translation
        template = """
        You are a professional translator. Translate the following text segments from {source_language} to {target_language}.
        
        IMPORTANT INSTRUCTIONS:
        1. Preserve the meaning, tone, and style of the original text
        2. Only respond with JSON in the exact format shown below
        3. Each numbered segment should be translated separately
        4. Maintain the original numbering in your response
        5. Translated Segments should be short and concise
        6. the translated segment should be of similar size as input.
        
        Text to translate:
        {text_segments}
        
        The response should be ONLY a JSON array with this exact structure:
        [
          "translated segment 1",
          "translated segment 2",
          ...
        ]
        """
        
        prompt = ChatPromptTemplate.from_messages([("system", template)])
        
        try:
            # Create a chain and execute
            chain = LLMChain(llm=llm, prompt=prompt)
            response = chain.run(
                source_language=source_language,
                target_language=target_language,
                text_segments=batch_content
            )
            
            # Parse the response
            # First try to find JSON in the response using regex
            import re
            json_match = re.search(r'\[.*\]', response.strip(), re.DOTALL)
            
            if json_match:
                try:
                    translated_texts = json.loads(json_match.group(0))
                except:
                    # If regex json extraction fails, try direct parsing
                    translated_texts = json.loads(response.strip())
            else:
                # If no JSON array found, try to parse directly
                translated_texts = json.loads(response.strip())
            
            # Verify correct count
            if len(translated_texts) != len(batch):
                logger.warning(
                    f"Translation count mismatch. Expected {len(batch)}, "
                    f"got {len(translated_texts)}. Falling back to Google Translate for this batch."
                )
                # Fall back to Google for this batch
                fallback_translations = translate_iterative(batch, target_lang, source_lang)
                translated_texts = [segment["text"] for segment in fallback_translations]
            
            # Add translations to the result
            translated_segments.extend(translated_texts)
            
            # Avoid hitting rate limits
            time.sleep(0.5)
            
        except Exception as error:
            logger.error(f"Groq translation error for batch {batch_idx+1}/{total_batches}: {error}")
            logger.warning("Falling back to Google Translate for this batch")
            
            # Fall back to Google for this batch
            fallback_translations = translate_iterative(batch, target_lang, source_lang)
            batch_translations = [segment["text"] for segment in fallback_translations]
            translated_segments.extend(batch_translations)
    
    # Verify and update segments
    return verify_translation(segments, segments_copy, translated_segments, target_lang, source_lang)

def translate_text(segments: List[Dict[str, Any]],
                  target_lang: str,
                  translation_method: str = "batch",
                  chunk_size: int = 4000,
                  source_lang: Optional[str] = None,
                  groq_model: str = "llama-3.3-70b-versatile",
                  groq_batch_size: int = 10) -> List[Dict[str, Any]]:
    """
    Main translation function that handles different translation methods.
    
    Args:
        segments: List of dictionaries with 'text' key
        target_lang: Target language code
        translation_method: "batch", "iterative", or "groq" (default: "batch")
        chunk_size: Maximum character count per chunk for batch translation
        source_lang: Source language code (defaults to auto-detect)
        groq_model: Model name for Groq translation
        groq_batch_size: Batch size for Groq translation
        
    Returns:
        List of segments with translated text
    """
    if not segments:
        logger.warning("No segments to translate")
        return segments
    
    if translation_method == "batch":
        return translate_batch(segments, target_lang, chunk_size, source_lang)
    elif translation_method == "iterative":
        return translate_iterative(segments, target_lang, source_lang)
    elif translation_method == "groq":
        return translate_with_groq(
            segments, 
            target_lang, 
            model_name=groq_model,
            source_lang=source_lang,
            batch_size=groq_batch_size
        )
    else:
        logger.error(f"Unknown translation method: {translation_method}")
        return translate_batch(segments, target_lang, chunk_size, source_lang)
    
def generate_srt_subtitles(segments, output_file="output.srt"):
    """
    Generate an SRT subtitle file from translated segments.
    
    Args:
        segments: List of dictionaries with 'start', 'end', and 'text' keys
        output_file: Path to the output SRT file
        
    Returns:
        Path to the created SRT file
    """
    logger.info(f"Generating SRT subtitle file: {output_file}")
    
    # Format time as HH:MM:SS,mmm
    def format_time(seconds):
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = seconds % 60
        milliseconds = int((seconds - int(seconds)) * 1000)
        return f"{hours:02d}:{minutes:02d}:{int(seconds):02d},{milliseconds:03d}"
    
    with open(output_file, "w", encoding="utf-8") as f:
        for i, segment in enumerate(segments, 1):
            # Extract timing information
            start_time = segment.get("start", 0)
            end_time = segment.get("end", 0)
            text = segment.get("text", "").strip()
            
            # Skip empty segments
            if not text:
                continue
                
            # Write subtitle entry
            f.write(f"{i}\n")
            f.write(f"{format_time(start_time)} --> {format_time(end_time)}\n")
            f.write(f"{text}\n\n")
    
    logger.info(f"SRT subtitle file created successfully: {output_file}")
    return output_file