import os
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import spaces
import librosa
from gtts import gTTS
import tempfile


class CareBridgeTranslator:
    """Translate text, images, audio and video between a fixed set of languages.

    Wraps a Hugging Face image-text-to-text model (TranslateGemma by default),
    loaded lazily on the first inference call, plus gTTS for speech output.
    Translation methods return the decoded model output as a string, or an
    "Error: ..." string when a language name is not in ``LANG_MAP``.
    """

    def __init__(self, model_id="google/translategemma-4b-it", device=None, max_new_tokens=512):
        """
        Initialize the CareBridge Translator with lazy loading for ZeroGPU compatibility.

        Args:
            model_id: Hugging Face model id passed to ``from_pretrained``.
            device: Device string for ``device_map``; defaults to "cuda" when
                available, otherwise "cpu".
            max_new_tokens: Upper bound on generated tokens per translation.
        """
        self.model_id = model_id
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.max_new_tokens = max_new_tokens
        self.model = None
        self.processor = None
        print("[SIMBOTI] Translator initialized. Model will load on first use.")

        # English plus the top 10 NHS languages (ISO 639-1 codes).
        self.LANG_MAP = {
            "English": "en",
            "Polish": "pl",
            "Romanian": "ro",
            "Punjabi": "pa",
            "Urdu": "ur",
            "Portuguese": "pt",
            "Spanish": "es",
            "Arabic": "ar",
            "Bengali": "bn",
            "Gujarati": "gu",
            "Italian": "it",
        }

    # ------------------------------------------------------------------ helpers

    def _is_cuda(self):
        """Return True when the configured device is a CUDA device ("cuda" or "cuda:N")."""
        return str(self.device).startswith("cuda")

    def _select_dtype(self):
        """Pick the weight dtype: float32 on CPU, bfloat16 on CUDA when supported, else float16."""
        if not self._is_cuda():
            return torch.float32
        # Gemma 3-family checkpoints are published in bfloat16; float16 is reported
        # to overflow for this family, so it is only used as a fallback.
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16

    def _resolve_lang_codes(self, source_lang_name, target_lang_name):
        """Map display names to ISO codes; return ``(src, tgt)`` or ``None`` if either is unknown."""
        src_code = self.LANG_MAP.get(source_lang_name)
        tgt_code = self.LANG_MAP.get(target_lang_name)
        if not src_code or not tgt_code:
            return None
        return src_code, tgt_code

    @staticmethod
    def _build_message(content_type, payload, src_code, tgt_code):
        """Build one user chat message whose single content item carries the payload
        under a key equal to ``content_type`` ("text", "image", "audio", "video")."""
        return {
            "role": "user",
            "content": [{
                "type": content_type,
                "source_lang_code": src_code,
                "target_lang_code": tgt_code,
                content_type: payload,
            }],
        }

    def _load_model(self):
        """Load the processor and model on first use; subsequent calls are no-ops."""
        if self.model is not None:
            return
        print(f"[SIMBOTI] Loading model {self.model_id}...")
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            self.model_id,
            device_map=self.device,
            torch_dtype=self._select_dtype(),
        )
        print("[SIMBOTI] Model loaded successfully.")

    # ------------------------------------------------------------ public API

    def translate_text(self, text, source_lang_name, target_lang_name):
        """
        Translate text ensuring patient data stays local.

        Returns the translation, "" for empty/whitespace-only input, or an
        error string listing the supported languages.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"
        # Nothing to translate: skip a (GPU) inference round-trip entirely.
        if not text or not text.strip():
            return ""
        return self._run_inference([self._build_message("text", text, *codes)])

    def translate_image(self, image_path, source_lang_name, target_lang_name):
        """
        Extract and translate text from an image (e.g. instruction leaflet).

        ``image_path`` may be a filesystem path (str or os.PathLike) or an
        already-loaded image object.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."

        if isinstance(image_path, (str, os.PathLike)):
            # Decode eagerly and close the file handle; normalise to RGB so
            # palette/alpha/greyscale images reach the processor as 3-channel RGB.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        else:
            image = image_path  # Assume PIL object (or whatever the processor accepts)
            if isinstance(image, Image.Image) and image.mode != "RGB":
                image = image.convert("RGB")

        return self._run_inference([self._build_message("image", image, *codes)])

    def translate_audio(self, audio_path, source_lang_name, target_lang_name):
        """
        Speech-to-text translation: the audio is resampled to 16 kHz mono with
        librosa and passed to the model as an "audio" content item.
        """
        # NOTE(review): the TranslateGemma chat template is documented for text and
        # image content; confirm the processor actually accepts "audio" items.
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."
        audio, _ = librosa.load(audio_path, sr=16000)
        return self._run_inference([self._build_message("audio", audio, *codes)])

    def translate_video(self, video_path, source_lang_name, target_lang_name):
        """
        Video OCR/translation: passes the video path to the model as a "video" content item.
        """
        # NOTE(review): as with audio, "video" content support in the TranslateGemma
        # processor/template is unconfirmed — verify before relying on it.
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."
        return self._run_inference([self._build_message("video", video_path, *codes)])

    def speak_text(self, text, lang_name):
        """
        Generate an MP3 of ``text`` for the patient.

        Unknown language names fall back to English. Returns the path of a
        temporary .mp3 file (the caller owns and should delete it), or None
        if synthesis failed.
        """
        # NOTE(review): gTTS calls Google's online text-to-speech service, so the
        # translated text leaves this machine here — confirm this is acceptable
        # for patient data.
        lang_code = self.LANG_MAP.get(lang_name, "en")
        path = None
        try:
            tts = gTTS(text=text, lang=lang_code)
            fd, path = tempfile.mkstemp(suffix=".mp3")
            os.close(fd)  # gTTS reopens the path itself; don't keep a second handle open
            tts.save(path)
            return path
        except Exception as e:
            # Best-effort: TTS failure must not break the translation flow.
            print(f"TTS Error: {e}")
            if path is not None:
                # Don't leave an empty/partial .mp3 behind.
                try:
                    os.remove(path)
                except OSError:
                    pass
            return None

    # NOTE(review): ZeroGPU guidance is to load models at module level, outside
    # @spaces.GPU; state assigned inside the decorated call may not persist
    # between calls — confirm the model is not reloaded on every request.
    @spaces.GPU()
    def _run_inference(self, messages):
        """Run greedy generation on chat ``messages`` and return the decoded reply."""
        self._load_model()
        # Move inputs to the model's device AND cast floating tensors (e.g.
        # pixel_values) to the model dtype; integer ids are only moved.
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=self.model.dtype)

        # Greedy decoding for stability in a medical context.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs, max_new_tokens=self.max_new_tokens, do_sample=False
            )

        # Decode only the newly generated tokens (skip the prompt).
        input_len = inputs["input_ids"].shape[-1]
        decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
        return decoded.strip()


# Simple Verification Test if run directly
if __name__ == "__main__":
    translator = CareBridgeTranslator()
    print("Test 1 (Text):", translator.translate_text("Where does it hurt?", "English", "Polish"))