# Defer torch import to avoid CUDA initialization issues
# torch will be imported when needed in the _load_model method
from typing import List, Dict, Union, Optional
import logging
from PIL import Image
import requests
import os
import tempfile

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GoogleTranslateGemma:
    """
    Google Translate Gemma model wrapper for text and image translation.

    This class provides an interface to the Google TranslateGemma model for:
    - Text translation between languages
    - Text extraction and translation from images
    """

    def __init__(self, model_id: str = "google/translategemma-12b-it"):
        """
        Initialize the Google Translate Gemma model.

        Args:
            model_id (str): The model identifier from Hugging Face

        Raises:
            Exception: Propagated from model/processor loading failures.
        """
        self.model_id = model_id
        self.model = None
        self.processor = None
        self.device = None  # Will be set when torch is imported
        self._load_model()

    def _load_model(self):
        """Load the processor and model from Hugging Face.

        torch/transformers are imported here (not at module top) so that
        merely importing this module does not trigger CUDA initialization.
        """
        try:
            # Import torch here to avoid CUDA initialization issues
            import torch
            from transformers import AutoModelForImageTextToText, AutoProcessor

            logger.info(f"Loading model: {self.model_id}")
            self.processor = AutoProcessor.from_pretrained(self.model_id)
            # device_map="auto" lets accelerate place the weights; we then
            # record the resolved device for moving inputs in generation.
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_id,
                device_map="auto"
            )
            self.device = self.model.device
            logger.info(f"Model loaded successfully on device: {self.device}")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def translate_text(
        self,
        text: str,
        source_lang: str,
        target_lang: str,
        max_new_tokens: int = 200
    ) -> str:
        """
        Translate text from source language to target language.

        Args:
            text (str): The text to translate
            source_lang (str): Source language code (e.g., 'cs' for Czech)
            target_lang (str): Target language code (e.g., 'de-DE' for German)
            max_new_tokens (int): Maximum number of tokens to generate

        Returns:
            str: The translated text

        Raises:
            Exception: Propagated from the underlying generation call.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        "text": text,
                    }
                ],
            }
        ]
        # Generation logic is shared with image translation.
        return self._translate_with_messages(messages, max_new_tokens)

    def translate_image(
        self,
        image_input: Union[str, Image.Image],
        source_lang: str,
        target_lang: str,
        max_new_tokens: int = 200
    ) -> str:
        """
        Extract text from an image and translate it to the target language.

        Args:
            image_input (Union[str, Image.Image]): Local file path, URL, or
                PIL Image object containing text
            source_lang (str): Source language code (e.g., 'cs' for Czech)
            target_lang (str): Target language code (e.g., 'de-DE' for German)
            max_new_tokens (int): Maximum number of tokens to generate

        Returns:
            str: The extracted and translated text

        Raises:
            Exception: Propagated from image loading or generation.
        """
        # Build the single content entry; only the payload key differs by
        # input kind ("image" for local/PIL inputs, "url" for remote ones).
        content = {
            "type": "image",
            "source_lang_code": source_lang,
            "target_lang_code": target_lang,
        }
        if isinstance(image_input, str) and os.path.exists(image_input):
            # Local file path: load eagerly so errors surface here.
            content["image"] = Image.open(image_input)
        elif isinstance(image_input, Image.Image):
            # Already a PIL Image object.
            content["image"] = image_input
        else:
            # Anything else is treated as a URL; the processor fetches it.
            content["url"] = image_input

        messages = [{"role": "user", "content": [content]}]
        return self._translate_with_messages(messages, max_new_tokens)

    def _translate_with_messages(self, messages: List[Dict], max_new_tokens: int = 200) -> str:
        """
        Helper method to translate using messages with direct model.

        Args:
            messages (List[Dict]): Formatted messages for the model
            max_new_tokens (int): Maximum number of tokens to generate

        Returns:
            str: The translated text

        Raises:
            Exception: Propagated after logging on any generation failure.
        """
        try:
            # Import torch here if not already imported
            import torch

            # Tokenize the chat template and move tensors to the model
            # device. NOTE(review): the dtype arg is expected to cast only
            # floating-point tensors (pixel values), not input_ids — confirm
            # against the installed transformers version.
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.device, dtype=torch.bfloat16)

            # Remember prompt length so we can strip it from the output.
            input_len = len(inputs['input_ids'][0])

            with torch.inference_mode():
                generation = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
                # Keep only the newly generated tokens.
                generation = generation[0][input_len:]

            decoded = self.processor.decode(generation, skip_special_tokens=True)
            return decoded
        except Exception as e:
            logger.error(f"Translation failed: {str(e)}")
            raise


# Example usage and testing functions
def test_text_translation():
    """Test text translation functionality."""
    print("Testing text translation...")

    translator = GoogleTranslateGemma()

    # Example: Czech to German
    source_text = "V nejhorším případě i k prasknutí čočky."
    source_lang = "cs"
    target_lang = "de-DE"

    try:
        translated = translator.translate_text(
            text=source_text,
            source_lang=source_lang,
            target_lang=target_lang
        )
        print(f"Source ({source_lang}): {source_text}")
        print(f"Target ({target_lang}): {translated}")
        print("-" * 50)
    except Exception as e:
        print(f"Text translation test failed: {str(e)}")


def test_image_translation():
    """Test image translation functionality."""
    print("Testing image translation...")

    translator = GoogleTranslateGemma()

    # Example: Czech traffic sign to German
    image_url = "https://c7.alamy.com/comp/2YAX36N/traffic-signs-in-czech-republic-pedestrian-zone-2YAX36N.jpg"
    source_lang = "cs"
    target_lang = "de-DE"

    try:
        # BUGFIX: the parameter is named image_input, not image_url;
        # the old keyword raised TypeError (swallowed by this except).
        translated = translator.translate_image(
            image_input=image_url,
            source_lang=source_lang,
            target_lang=target_lang
        )
        print(f"Image URL: {image_url}")
        print(f"Source ({source_lang}): [Text extracted from image]")
        print(f"Target ({target_lang}): {translated}")
        print("-" * 50)
    except Exception as e:
        print(f"Image translation test failed: {str(e)}")


def main():
    """Main function to run example translations."""
    print("Google Translate Gemma Module")
    print("=" * 50)

    # Run tests
    test_text_translation()
    test_image_translation()

    print("Example completed!")


if __name__ == "__main__":
    main()