|
|
import os
|
|
|
import torch
|
|
|
from transformers import AutoModelForImageTextToText, AutoProcessor
|
|
|
from PIL import Image
|
|
|
import librosa
|
|
|
from gtts import gTTS
|
|
|
import tempfile
|
|
|
|
|
|
|
|
|
try:
    import spaces
except ImportError:
    # Fallback when the Hugging Face `spaces` package is not installed
    # (e.g. local development): a stand-in whose `GPU` decorator returns the
    # decorated function unchanged, so code decorated with `spaces.GPU`
    # still runs.
    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both the bare form ``@spaces.GPU`` and the called form
            ``@spaces.GPU(...)``; in either case the function is returned as-is.
            """
            # Bare form: Python passes the decorated function as the sole
            # positional argument. Without this check the function would be
            # replaced by the inner `decorator` closure below.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(func):
                return func

            return decorator
|
|
|
|
|
|
class CareBridgeTranslator:
    """Translate text, images, audio and video with a chat-template model.

    The processor and model are loaded lazily (inside ``_run_inference``) so
    the object can be constructed without loading weights, for ZeroGPU
    compatibility.
    """

    def __init__(self, model_id="google/translategemma-4b-it", device=None):
        """
        Initialize the CareBridge Translator with lazy loading for ZeroGPU compatibility.

        Args:
            model_id: Hugging Face model id passed to ``from_pretrained``.
            device: Device for the model (e.g. "cuda", "cuda:0", "cpu").
                Defaults to "cuda" when available, otherwise "cpu".
        """
        self.model_id = model_id
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Populated on first use by _load_model().
        self.model = None
        self.processor = None
        print("[SIMBOTI] Translator initialized. Model will load on first use.")

        # Display name -> language code sent to the model as
        # source_lang_code / target_lang_code.
        self.LANG_MAP = {
            "English": "en",
            "Polish": "pl",
            "Romanian": "ro",
            "Punjabi": "pa",
            "Urdu": "ur",
            "Portuguese": "pt",
            "Spanish": "es",
            "Arabic": "ar",
            "Bengali": "bn",
            "Gujarati": "gu",
            "Italian": "it",
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _select_dtype(self):
        """Return the weight dtype to load the model with on ``self.device``."""
        # str() so torch.device objects and indexed devices like "cuda:0"
        # are recognized as GPU devices too.
        if not str(self.device).startswith("cuda"):
            return torch.float32
        # Gemma-family checkpoints are published in bfloat16; float16 has a much
        # narrower exponent range, so prefer bf16 whenever the GPU supports it.
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16

    def _load_model(self):
        """Load the processor and model on first use; no-op once loaded."""
        if self.model is not None:
            return
        print(f"[SIMBOTI] Loading model {self.model_id}...")
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            self.model_id,
            device_map=self.device,
            torch_dtype=self._select_dtype(),
        )
        print("[SIMBOTI] Model loaded successfully.")

    def _resolve_lang_codes(self, source_lang_name, target_lang_name):
        """Map display names to language codes.

        Returns:
            ``(src_code, tgt_code)``, or ``None`` if either name is not in
            ``LANG_MAP``.
        """
        src_code = self.LANG_MAP.get(source_lang_name)
        tgt_code = self.LANG_MAP.get(target_lang_name)
        if not src_code or not tgt_code:
            return None
        return src_code, tgt_code

    def _unsupported_language_error(self):
        """Return the error string shown when a language name is unknown."""
        return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"

    @staticmethod
    def _build_message(content_type, src_code, tgt_code, payload):
        """Build one user chat turn carrying a single typed content item.

        The payload is stored under a key equal to ``content_type``
        (e.g. ``"text": ...`` or ``"image": ...``).
        """
        return {
            "role": "user",
            "content": [{
                "type": content_type,
                "source_lang_code": src_code,
                "target_lang_code": tgt_code,
                content_type: payload,
            }],
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def translate_text(self, text, source_lang_name, target_lang_name):
        """
        Translate text ensuring patient data stays local.

        Returns the translated string, or an "Error: ..." string if either
        language name is unsupported.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return self._unsupported_language_error()
        return self._run_inference([self._build_message("text", *codes, text)])

    def translate_image(self, image_path, source_lang_name, target_lang_name):
        """
        Extract and translate text from an image (e.g. instruction leaflet).

        Args:
            image_path: A filesystem path (str or os.PathLike), or an
                already-loaded image object that is passed through unchanged.

        Returns the translated string, or an "Error: ..." string if either
        language name is unsupported.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return self._unsupported_language_error()

        if isinstance(image_path, (str, os.PathLike)):
            # Image.open is lazy; copy the pixels inside a `with` so the file
            # handle is released deterministically rather than left to GC.
            with Image.open(image_path) as img:
                image = img.copy()
        else:
            image = image_path

        return self._run_inference([self._build_message("image", *codes, image)])

    def translate_audio(self, audio_path, source_lang_name, target_lang_name):
        """
        Speech-to-text translation: the audio is passed to the model as an
        "audio" content item.

        NOTE(review): Gemma 3-based checkpoints are documented for text and
        image input; confirm the configured model/processor accepts audio.

        Returns the translated string, or an "Error: ..." string if either
        language name is unsupported.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return self._unsupported_language_error()

        # Resampled to 16 kHz; the returned sample rate is therefore not needed.
        audio, _ = librosa.load(audio_path, sr=16000)
        return self._run_inference([self._build_message("audio", *codes, audio)])

    def translate_video(self, video_path, source_lang_name, target_lang_name):
        """
        Video OCR/translation: the path is passed to the model as a "video"
        content item.

        NOTE(review): confirm the configured model/processor accepts video input.

        Returns the translated string, or an "Error: ..." string if either
        language name is unsupported.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return self._unsupported_language_error()
        return self._run_inference([self._build_message("video", *codes, video_path)])

    def speak_text(self, text, lang_name):
        """
        Generate audio from translated text for the patient.

        Unknown language names fall back to English ("en").

        NOTE(review): gTTS calls Google's online TTS service, so this text
        leaves the machine. Confirm that is acceptable for patient data.

        Returns:
            Path of a new .mp3 file (the caller owns/deletes it), or None if
            synthesis failed.
        """
        lang_code = self.LANG_MAP.get(lang_name, "en")
        out_path = None
        try:
            tts = gTTS(text=text, lang=lang_code)
            # Only reserve a unique filename; close our handle immediately
            # because save() is given the path and opens the file itself.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
                out_path = tmp.name
            tts.save(out_path)
            return out_path
        except Exception as e:
            # Best-effort: TTS failure should not break translation output.
            print(f"TTS Error: {e}")
            if out_path is not None:
                # Don't leave an empty or partial mp3 behind.
                try:
                    os.remove(out_path)
                except OSError:
                    pass
            return None

    @spaces.GPU()
    def _run_inference(self, messages, max_new_tokens=512):
        """Run greedy generation on chat `messages` and return the decoded reply.

        Args:
            messages: Chat turns accepted by ``processor.apply_chat_template``.
            max_new_tokens: Upper bound on generated tokens.
        """
        self._load_model()

        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, do_sample=False
            )

        # The output sequence starts with the prompt tokens; decode only the
        # newly generated continuation.
        input_len = inputs["input_ids"].shape[-1]
        decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
        return decoded.strip()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the translator (weights load lazily on the
    # first call) and translate one English sentence into Polish.
    demo = CareBridgeTranslator()
    result = demo.translate_text("Where does it hurt?", "English", "Polish")
    print("Test 1 (Text):", result)
|
|
|
|