# New_folder_2/google_translate.py
# Uploaded by zxc4wewewe ("Upload 12 files", commit 621ec47, verified)
# torch is deliberately not imported at module level (to avoid initializing
# CUDA on import); the methods that need it import it locally.
from typing import List, Dict, Union, Optional
import logging
from PIL import Image
import requests
import os
import tempfile
# Logging setup: a module-scoped logger, plus a root handler at INFO level so
# the info/error messages emitted below are visible when run as a script.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class GoogleTranslateGemma:
    """
    Google Translate Gemma model wrapper for text and image translation.

    This class provides an interface to the Google TranslateGemma model for:
    - Text translation between languages
    - Text extraction and translation from images

    The processor and model are loaded eagerly in ``__init__`` (through
    ``transformers`` with ``device_map="auto"``), so constructing an instance
    loads the model weights.
    """

    def __init__(self, model_id: str = "google/translategemma-12b-it"):
        """
        Initialize the Google Translate Gemma model.

        Args:
            model_id (str): The model identifier from Hugging Face

        Raises:
            Exception: Whatever ``transformers`` raises if the processor or the
                model cannot be loaded (the error is logged, then re-raised).
        """
        self.model_id = model_id
        self.model = None
        self.processor = None
        self.device = None  # Set from the loaded model in _load_model().
        self._load_model()

    def _load_model(self) -> None:
        """Load the processor and model and record the device the model is on."""
        try:
            # Imported lazily so that importing this module does not pull in
            # transformers (and, through it, torch/CUDA).
            from transformers import AutoModelForImageTextToText, AutoProcessor

            logger.info("Loading model: %s", self.model_id)
            self.processor = AutoProcessor.from_pretrained(self.model_id)
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_id,
                device_map="auto",
            )
            self.device = self.model.device
            logger.info("Model loaded successfully on device: %s", self.device)
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise

    def translate_text(
        self,
        text: str,
        source_lang: str,
        target_lang: str,
        max_new_tokens: int = 200,
    ) -> str:
        """
        Translate text from source language to target language.

        Args:
            text (str): The text to translate
            source_lang (str): Source language code (e.g., 'cs' for Czech)
            target_lang (str): Target language code (e.g., 'de-DE' for German)
            max_new_tokens (int): Maximum number of tokens to generate

        Returns:
            str: The translated text

        Raises:
            Exception: Any processor/generation error (logged, then re-raised).
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        "text": text,
                    }
                ],
            }
        ]
        return self._translate_with_messages(messages, max_new_tokens)

    def translate_image(
        self,
        image_input: Union[str, os.PathLike, Image.Image],
        source_lang: str,
        target_lang: str,
        max_new_tokens: int = 200,
    ) -> str:
        """
        Extract text from an image and translate it to the target language.

        Args:
            image_input: One of
                - a PIL ``Image.Image``;
                - a local file path (``str`` that exists on disk, or any
                  ``os.PathLike`` such as ``pathlib.Path``);
                - any other ``str``, which is treated as an image URL.
            source_lang (str): Source language code (e.g., 'cs' for Czech)
            target_lang (str): Target language code (e.g., 'de-DE' for German)
            max_new_tokens (int): Maximum number of tokens to generate

        Returns:
            str: The extracted and translated text

        Raises:
            TypeError: If ``image_input`` is none of the supported types.
            OSError: If a local image file cannot be opened/decoded by PIL.
            Exception: Any processor/generation error (logged, then re-raised).
        """
        if isinstance(image_input, Image.Image):
            image_entry = {"image": image_input}
        elif isinstance(image_input, os.PathLike) or (
            isinstance(image_input, str) and os.path.exists(image_input)
        ):
            image_entry = {"image": self._open_local_image(image_input)}
        elif isinstance(image_input, str):
            # Not a local file: pass it through as a ``url`` entry for the
            # processor's chat template to load.
            image_entry = {"url": image_input}
        else:
            raise TypeError(
                "image_input must be a PIL.Image.Image, a local file path, or a "
                f"URL string, got {type(image_input).__name__}"
            )

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source_lang_code": source_lang,
                        "target_lang_code": target_lang,
                        **image_entry,
                    },
                ],
            }
        ]
        return self._translate_with_messages(messages, max_new_tokens)

    @staticmethod
    def _open_local_image(path: Union[str, os.PathLike]) -> Image.Image:
        """Open an image file and return an in-memory copy with the file closed."""
        # Image.open() is lazy and keeps the file handle open; copy() forces the
        # pixel data to load, so the with-block can release the handle safely.
        with Image.open(path) as img:
            return img.copy()

    def _translate_with_messages(self, messages: List[Dict], max_new_tokens: int = 200) -> str:
        """
        Run the chat-formatted ``messages`` through the model and decode the reply.

        Args:
            messages (List[Dict]): Formatted messages for the model
            max_new_tokens (int): Maximum number of tokens to generate

        Returns:
            str: The translated text (prompt tokens and special tokens removed)

        Raises:
            Exception: Any processor/generation error (logged, then re-raised).
        """
        try:
            import torch

            # Cast floating-point inputs (e.g. image pixels) to the dtype the
            # model was actually loaded in; a hard-coded bfloat16 mismatches a
            # model loaded in another dtype. Falls back to the original
            # bfloat16 if the model exposes no dtype.
            # NOTE(review): assumes BatchFeature.to() leaves integer tensors
            # (input_ids) uncast -- confirm for the installed transformers.
            target_dtype = getattr(self.model, "dtype", None) or torch.bfloat16
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.device, dtype=target_dtype)
            # Prompt length, used to strip the echoed prompt from the output.
            input_len = len(inputs["input_ids"][0])

            with torch.inference_mode():
                generation = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

            new_tokens = generation[0][input_len:]
            return self.processor.decode(new_tokens, skip_special_tokens=True)
        except Exception as e:
            logger.error("Translation failed: %s", e)
            raise
# Example usage and testing functions
def test_text_translation(translator: Optional["GoogleTranslateGemma"] = None) -> None:
    """Run a sample Czech -> German text translation and print the result.

    Args:
        translator: An already-loaded translator to reuse. When omitted, a new
            ``GoogleTranslateGemma`` with the default model is created, which
            loads the model weights.

    Any failure, including model loading when ``translator`` is omitted, is
    printed rather than raised so that further demos can still run.
    """
    print("Testing text translation...")
    # Example: Czech to German
    source_text = "V nejhorším případě i k prasknutí čočky."
    source_lang = "cs"
    target_lang = "de-DE"
    try:
        if translator is None:
            translator = GoogleTranslateGemma()
        translated = translator.translate_text(
            text=source_text,
            source_lang=source_lang,
            target_lang=target_lang,
        )
        print(f"Source ({source_lang}): {source_text}")
        print(f"Target ({target_lang}): {translated}")
        print("-" * 50)
    except Exception as e:
        print(f"Text translation test failed: {e}")


def test_image_translation(translator: Optional["GoogleTranslateGemma"] = None) -> None:
    """Translate the text on a sample Czech traffic-sign image into German.

    Args:
        translator: An already-loaded translator to reuse. When omitted, a new
            ``GoogleTranslateGemma`` with the default model is created, which
            loads the model weights.

    Any failure, including model loading when ``translator`` is omitted, is
    printed rather than raised so that further demos can still run.
    """
    print("Testing image translation...")
    # Example: Czech traffic sign to German
    image_url = "https://c7.alamy.com/comp/2YAX36N/traffic-signs-in-czech-republic-pedestrian-zone-2YAX36N.jpg"
    source_lang = "cs"
    target_lang = "de-DE"
    try:
        if translator is None:
            translator = GoogleTranslateGemma()
        # translate_image's parameter is ``image_input``; passing
        # ``image_url=`` raised TypeError, so this demo always failed.
        translated = translator.translate_image(
            image_input=image_url,
            source_lang=source_lang,
            target_lang=target_lang,
        )
        print(f"Image URL: {image_url}")
        print(f"Source ({source_lang}): [Text extracted from image]")
        print(f"Target ({target_lang}): {translated}")
        print("-" * 50)
    except Exception as e:
        print(f"Image translation test failed: {e}")


def main() -> None:
    """Run the example translations, sharing one loaded model between them."""
    print("Google Translate Gemma Module")
    print("=" * 50)
    # Load the model once; otherwise each demo would construct (and load)
    # its own GoogleTranslateGemma instance.
    translator = GoogleTranslateGemma()
    # Run tests
    test_text_translation(translator)
    test_image_translation(translator)
    print("Example completed!")


if __name__ == "__main__":
    main()