Spaces:
Runtime error
Runtime error
| # Defer torch import to avoid CUDA initialization issues | |
| # torch will be imported when needed in the _load_model method | |
| from typing import List, Dict, Union, Optional | |
| import logging | |
| from PIL import Image | |
| import requests | |
| import os | |
| import tempfile | |
| # Configure logging | |
| # NOTE(review): logging.basicConfig at import time configures the ROOT logger, | |
| # which can surprise applications importing this as a library — consider moving | |
| # this into main() or leaving configuration to the application. TODO confirm. | |
| logging.basicConfig(level=logging.INFO) | |
| # Module-level logger, named after this module per stdlib convention. | |
| logger = logging.getLogger(__name__) | |
class GoogleTranslateGemma:
    """
    Google Translate Gemma model wrapper for text and image translation.

    This class provides an interface to the Google TranslateGemma model for:
    - Text translation between languages
    - Text extraction and translation from images
    """

    def __init__(self, model_id: str = "google/translategemma-12b-it"):
        """
        Initialize the Google Translate Gemma model.

        Args:
            model_id (str): The model identifier from Hugging Face.
        """
        self.model_id = model_id
        self.model = None
        self.processor = None
        self.device = None  # Set by _load_model once torch is imported
        self._load_model()

    def _load_model(self):
        """Load the processor and model, letting transformers pick the device.

        Raises:
            Exception: Re-raises any loading error after logging it.
        """
        try:
            # Deferred imports: avoid CUDA initialization at module import time.
            import torch  # noqa: F401 — ensures torch is importable before model load
            from transformers import AutoModelForImageTextToText, AutoProcessor

            logger.info(f"Loading model: {self.model_id}")
            self.processor = AutoProcessor.from_pretrained(self.model_id)
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_id,
                device_map="auto",
            )
            self.device = self.model.device
            logger.info(f"Model loaded successfully on device: {self.device}")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def translate_text(
        self,
        text: str,
        source_lang: str,
        target_lang: str,
        max_new_tokens: int = 200,
    ) -> str:
        """
        Translate text from source language to target language.

        Args:
            text (str): The text to translate.
            source_lang (str): Source language code (e.g., 'cs' for Czech).
            target_lang (str): Target language code (e.g., 'de-DE' for German).
            max_new_tokens (int): Maximum number of tokens to generate.

        Returns:
            str: The translated text.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        "text": text,
                    }
                ],
            }
        ]
        # Delegate to the shared generation helper (previously this method
        # duplicated the whole tokenize/generate/decode body).
        return self._translate_with_messages(messages, max_new_tokens)

    def translate_image(
        self,
        image_input: Union[str, Image.Image],
        source_lang: str,
        target_lang: str,
        max_new_tokens: int = 200,
    ) -> str:
        """
        Extract text from an image and translate it to the target language.

        Args:
            image_input (Union[str, Image.Image]): Local file path, URL, or
                PIL Image object containing text.
            source_lang (str): Source language code (e.g., 'cs' for Czech).
            target_lang (str): Target language code (e.g., 'de-DE' for German).
            max_new_tokens (int): Maximum number of tokens to generate.

        Returns:
            str: The extracted and translated text.
        """
        # Build the single content entry; only the image-carrying key varies
        # by input kind (previously three near-identical branches).
        entry: Dict = {
            "type": "image",
            "source_lang_code": source_lang,
            "target_lang_code": target_lang,
        }
        if isinstance(image_input, Image.Image):
            # Already-loaded PIL image: pass it through directly.
            entry["image"] = image_input
        elif isinstance(image_input, str) and os.path.exists(image_input):
            # Local file path: open it as a PIL image.
            entry["image"] = Image.open(image_input)
        else:
            # Anything else is treated as a URL the processor can fetch.
            entry["url"] = image_input

        messages = [{"role": "user", "content": [entry]}]
        return self._translate_with_messages(messages, max_new_tokens)

    def _translate_with_messages(self, messages: List[Dict], max_new_tokens: int = 200) -> str:
        """
        Helper method to translate using chat-formatted messages.

        Args:
            messages (List[Dict]): Formatted messages for the model.
            max_new_tokens (int): Maximum number of tokens to generate.

        Returns:
            str: The translated text.

        Raises:
            Exception: Re-raises any generation error after logging it.
        """
        try:
            # Deferred import keeps torch out of module import time.
            import torch

            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.device, dtype=torch.bfloat16)
            # Length of the prompt, used to strip it from the generation below.
            input_len = len(inputs["input_ids"][0])

            with torch.inference_mode():
                generation = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
                # Keep only the newly generated tokens (drop the echoed prompt).
                generation = generation[0][input_len:]

            decoded = self.processor.decode(generation, skip_special_tokens=True)
            return decoded
        except Exception as e:
            logger.error(f"Translation failed: {str(e)}")
            raise
| # Example usage and testing functions | |
def test_text_translation():
    """Smoke-test text translation with a Czech -> German example."""
    print("Testing text translation...")
    translator = GoogleTranslateGemma()

    # Example sentence and language pair: Czech to German.
    source_text = "V nejhorším případě i k prasknutí čočky."
    source_lang, target_lang = "cs", "de-DE"

    try:
        translated = translator.translate_text(
            text=source_text,
            source_lang=source_lang,
            target_lang=target_lang,
        )
    except Exception as e:
        print(f"Text translation test failed: {str(e)}")
    else:
        print(f"Source ({source_lang}): {source_text}")
        print(f"Target ({target_lang}): {translated}")
        print("-" * 50)
def test_image_translation():
    """Smoke-test image translation: Czech traffic sign -> German."""
    print("Testing image translation...")
    translator = GoogleTranslateGemma()

    # Example: Czech traffic sign to German
    image_url = "https://c7.alamy.com/comp/2YAX36N/traffic-signs-in-czech-republic-pedestrian-zone-2YAX36N.jpg"
    source_lang = "cs"
    target_lang = "de-DE"

    try:
        # BUG FIX: the method's parameter is named `image_input`, not
        # `image_url`; the original keyword raised TypeError on every call,
        # so this test always reported a failure without reaching the model.
        translated = translator.translate_image(
            image_input=image_url,
            source_lang=source_lang,
            target_lang=target_lang,
        )
        print(f"Image URL: {image_url}")
        print(f"Source ({source_lang}): [Text extracted from image]")
        print(f"Target ({target_lang}): {translated}")
        print("-" * 50)
    except Exception as e:
        print(f"Image translation test failed: {str(e)}")
def main():
    """Entry point: run the example translation tests in order."""
    print("Google Translate Gemma Module")
    print("=" * 50)

    # Execute each smoke test; failures are printed inside the tests.
    for test_fn in (test_text_translation, test_image_translation):
        test_fn()

    print("Example completed!")


if __name__ == "__main__":
    main()