File size: 9,854 Bytes
621ec47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40f1270
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289

# torch is deliberately not imported at module level (to avoid CUDA initialization issues
# at import time); it is imported lazily inside the methods that need it, such as _load_model.
from typing import List, Dict, Union, Optional
import logging
from PIL import Image
import requests
import os
import tempfile

# Configure logging
# NOTE(review): basicConfig runs at import time, so merely importing this module
# installs an INFO-level handler on the root logger (if it has none yet) for the
# whole importing application; consider moving it under the __main__ guard.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # Module logger used by GoogleTranslateGemma.


class GoogleTranslateGemma:
    """
    Google Translate Gemma model wrapper for text and image translation.

    This class provides an interface to the Google TranslateGemma model for:
    - Text translation between languages
    - Text extraction and translation from images

    The processor and model are loaded eagerly when the instance is created.
    """

    def __init__(self, model_id: str = "google/translategemma-12b-it"):
        """
        Initialize the Google Translate Gemma model.

        Args:
            model_id (str): The model identifier from Hugging Face.

        Raises:
            Exception: Whatever ``_load_model`` raises if the processor or the
                model cannot be loaded (it is logged and re-raised).
        """
        self.model_id = model_id
        self.model = None
        self.processor = None
        self.device = None  # Set from the loaded model in _load_model.
        self._load_model()

    def _load_model(self):
        """Load the processor and model for ``self.model_id`` and record the model's device."""
        try:
            # Imported here rather than at module level to avoid CUDA
            # initialization when the module is merely imported.
            import torch  # noqa: F401
            from transformers import AutoModelForImageTextToText, AutoProcessor

            logger.info("Loading model: %s", self.model_id)
            self.processor = AutoProcessor.from_pretrained(self.model_id)
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_id,
                device_map="auto",
            )
            self.device = self.model.device
            logger.info("Model loaded successfully on device: %s", self.device)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Failed to load model: %s", e)
            raise

    def translate_text(
        self,
        text: str,
        source_lang: str,
        target_lang: str,
        max_new_tokens: int = 200,
    ) -> str:
        """
        Translate text from source language to target language.

        Args:
            text (str): The text to translate.
            source_lang (str): Source language code (e.g., 'cs' for Czech).
            target_lang (str): Target language code (e.g., 'de-DE' for German).
            max_new_tokens (int): Maximum number of tokens to generate.

        Returns:
            str: The translated text.

        Raises:
            Exception: Any error from templating or generation (logged and re-raised).
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        "text": text,
                    }
                ],
            }
        ]
        # Share the single generation pipeline with the image path.
        return self._translate_with_messages(messages, max_new_tokens)

    def translate_image(
        self,
        image_input: Union[str, Image.Image],
        source_lang: str,
        target_lang: str,
        max_new_tokens: int = 200,
    ) -> str:
        """
        Extract text from an image and translate it to the target language.

        Args:
            image_input (Union[str, Image.Image]): A path to an existing local
                image file, a PIL Image object, or (any other string) an image URL.
            source_lang (str): Source language code (e.g., 'cs' for Czech).
            target_lang (str): Target language code (e.g., 'de-DE' for German).
            max_new_tokens (int): Maximum number of tokens to generate.

        Returns:
            str: The extracted and translated text.

        Raises:
            Exception: Any error from opening the image, templating or
                generation (generation errors are logged and re-raised).
        """
        # Handle local image files
        if isinstance(image_input, str) and os.path.exists(image_input):
            # Image.open is lazy and may keep the file handle open; the context
            # manager guarantees it is closed once generation has finished.
            with Image.open(image_input) as image:
                messages = self._image_messages(source_lang, target_lang, image=image)
                return self._translate_with_messages(messages, max_new_tokens)

        # Handle PIL Image objects
        if isinstance(image_input, Image.Image):
            messages = self._image_messages(source_lang, target_lang, image=image_input)
            return self._translate_with_messages(messages, max_new_tokens)

        # Anything else is treated as a URL and passed through under the "url"
        # key; loading it is left to the processor.
        messages = self._image_messages(source_lang, target_lang, url=image_input)
        return self._translate_with_messages(messages, max_new_tokens)

    @staticmethod
    def _image_messages(source_lang: str, target_lang: str, **image_field) -> List[Dict]:
        """Build the single-turn chat payload for an image given as ``image=`` or ``url=``."""
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        **image_field,
                    },
                ],
            }
        ]

    def _translate_with_messages(self, messages: List[Dict], max_new_tokens: int = 200) -> str:
        """
        Run the chat template and generation for ``messages`` and decode the reply.

        Args:
            messages (List[Dict]): Formatted messages for the model (one conversation).
            max_new_tokens (int): Maximum number of tokens to generate.

        Returns:
            str: The decoded model output, without the prompt and special tokens.

        Raises:
            Exception: Any error from templating, generation or decoding
                (logged with traceback and re-raised).
        """
        try:
            # Imported lazily, consistent with _load_model.
            import torch

            # The dtype argument is meant for floating-point inputs (e.g. pixel
            # values); integer token ids are expected to keep their dtype.
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.device, dtype=torch.bfloat16)

            # generate() returns prompt + continuation; remember the prompt length.
            input_len = len(inputs["input_ids"][0])

            with torch.inference_mode():
                generation = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

            # Keep only the newly generated tokens of the single conversation.
            new_tokens = generation[0][input_len:]
            return self.processor.decode(new_tokens, skip_special_tokens=True)
        except Exception as e:
            logger.exception("Translation failed: %s", e)
            raise
    


# Example usage and testing functions
def test_text_translation():
    """Test text translation functionality.

    Builds a fresh GoogleTranslateGemma, translates a sample Czech sentence
    into German and prints source, target and a separator line. Errors from
    the translation call (or the prints) are caught and reported; errors while
    constructing the translator propagate to the caller.
    """
    print("Testing text translation...")

    engine = GoogleTranslateGemma()

    # Example: Czech to German
    src_code, dst_code = "cs", "de-DE"
    sentence = "V nejhorším případě i k prasknutí čočky."

    try:
        output = engine.translate_text(
            text=sentence,
            source_lang=src_code,
            target_lang=dst_code,
        )
        for line in (
            f"Source ({src_code}): {sentence}",
            f"Target ({dst_code}): {output}",
            "-" * 50,
        ):
            print(line)
    except Exception as err:
        print(f"Text translation test failed: {str(err)}")


def test_image_translation():
    """Test image translation functionality.

    Builds a fresh GoogleTranslateGemma and asks it to translate the text in a
    sample Czech traffic-sign image, given by URL, into German, then prints the
    result. Errors from the translation call are caught and reported; errors
    while constructing the translator propagate to the caller.
    """
    print("Testing image translation...")

    translator = GoogleTranslateGemma()

    # Example: Czech traffic sign to German
    image_url = "https://c7.alamy.com/comp/2YAX36N/traffic-signs-in-czech-republic-pedestrian-zone-2YAX36N.jpg"
    source_lang = "cs"
    target_lang = "de-DE"

    try:
        # translate_image's first parameter is ``image_input``; the previous
        # ``image_url=`` keyword raised TypeError, so this example always failed.
        translated = translator.translate_image(
            image_input=image_url,
            source_lang=source_lang,
            target_lang=target_lang,
        )
        print(f"Image URL: {image_url}")
        print(f"Source ({source_lang}): [Text extracted from image]")
        print(f"Target ({target_lang}): {translated}")
        print("-" * 50)
    except Exception as e:
        print(f"Image translation test failed: {str(e)}")


def main():
    """Print a banner, run each example translation in turn, then report completion."""
    banner = "Google Translate Gemma Module"
    print(banner)
    print("=" * 50)

    # Run tests
    for run_example in (test_text_translation, test_image_translation):
        run_example()

    print("Example completed!")


if __name__ == "__main__":
    main()