Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import torch | |
| from transformers import AutoModelForImageTextToText, AutoProcessor | |
| from PIL import Image | |
| import spaces | |
| import librosa | |
| from gtts import gTTS | |
| import tempfile | |
class CareBridgeTranslator:
    """Translate text, images, audio and video between supported languages.

    Wraps a Hugging Face image-text-to-text model (TranslateGemma by default).
    The model and processor are loaded lazily on first use (``_load_model``)
    for ZeroGPU compatibility, so constructing the object is cheap.
    """

    def __init__(self, model_id="google/translategemma-4b-it", device=None):
        """Initialize the CareBridge Translator with lazy loading for ZeroGPU compatibility.

        Args:
            model_id: Hugging Face model id passed to ``from_pretrained``.
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.
        """
        self.model_id = model_id
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.model = None
        self.processor = None
        print("[SIMBOTI] Translator initialized. Model will load on first use.")
        # Display name -> ISO 639-1 code: English plus the top 10 NHS languages.
        self.LANG_MAP = {
            "English": "en",
            "Polish": "pl",
            "Romanian": "ro",
            "Punjabi": "pa",
            "Urdu": "ur",
            "Portuguese": "pt",
            "Spanish": "es",
            "Arabic": "ar",
            "Bengali": "bn",
            "Gujarati": "gu",
            "Italian": "it",
        }

    def _load_model(self):
        """Load the processor and model on the first call; later calls are no-ops."""
        if self.model is not None:
            return
        print(f"[SIMBOTI] Loading model {self.model_id}...")
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            self.model_id,
            device_map=self.device,
            # Half precision on GPU to save memory, full precision on CPU.
            # NOTE(review): Gemma 3 checkpoints are commonly run in bfloat16;
            # confirm float16 is numerically stable for this model.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )
        print("[SIMBOTI] Model loaded successfully.")

    def _lang_codes(self, source_lang_name, target_lang_name):
        """Map two display names to ``(src_code, tgt_code)``.

        Returns ``None`` if either name is not in ``LANG_MAP``.
        """
        src_code = self.LANG_MAP.get(source_lang_name)
        tgt_code = self.LANG_MAP.get(target_lang_name)
        if not src_code or not tgt_code:
            return None
        return src_code, tgt_code

    @staticmethod
    def _build_message(content_type, payload, src_code, tgt_code):
        """Build the one-turn chat payload used by every translate_* method.

        ``content_type`` ("text", "image", "audio" or "video") is used both as
        the item's ``"type"`` and as the key that holds ``payload``.
        """
        return [{
            "role": "user",
            "content": [{
                "type": content_type,
                "source_lang_code": src_code,
                "target_lang_code": tgt_code,
                content_type: payload,
            }],
        }]

    def translate_text(self, text, source_lang_name, target_lang_name):
        """Translate ``text`` using the locally loaded model.

        The text is processed by the in-process model, not by an external
        translation service.

        Returns:
            The translated string. If either language name is unsupported,
            returns an ``"Error: ..."`` string that lists the supported names.
        """
        codes = self._lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"
        return self._run_inference(self._build_message("text", text, *codes))

    def translate_image(self, image_path, source_lang_name, target_lang_name):
        """Extract and translate text from an image (e.g. instruction leaflet).

        Args:
            image_path: A file path (``str`` or ``os.PathLike``) or an
                already-loaded PIL image.

        Returns:
            The translated string, or ``"Error: Language not supported."``.
        """
        codes = self._lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."
        if isinstance(image_path, (str, os.PathLike)):
            # PIL opens files lazily. Decoding via convert() inside `with`
            # releases the file handle and also normalises palette, RGBA and
            # greyscale inputs to RGB.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        else:
            image = image_path  # Assume an already-loaded PIL image.
        return self._run_inference(self._build_message("image", image, *codes))

    def translate_audio(self, audio_path, source_lang_name, target_lang_name):
        """Translate speech in an audio file.

        The audio is loaded with librosa and resampled to 16 kHz.

        NOTE(review): confirm that the configured model's processor and chat
        template accept ``"audio"`` content items; this is not verified here.

        Returns:
            The translated string, or ``"Error: Language not supported."``.
        """
        codes = self._lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."
        # 16 kHz is assumed to be the model's expected rate; confirm against
        # the processor's feature extractor. The returned sample rate is unused.
        audio, _ = librosa.load(audio_path, sr=16000)
        return self._run_inference(self._build_message("audio", audio, *codes))

    def translate_video(self, video_path, source_lang_name, target_lang_name):
        """Translate text shown in a video; the path is passed to the processor as-is.

        NOTE(review): confirm that the configured model's processor and chat
        template accept ``"video"`` content items; this is not verified here.

        Returns:
            The translated string, or ``"Error: Language not supported."``.
        """
        codes = self._lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."
        return self._run_inference(self._build_message("video", video_path, *codes))

    def speak_text(self, text, lang_name):
        """Generate an MP3 of ``text`` spoken in ``lang_name`` for the patient.

        Unknown language names fall back to English ("en").

        NOTE(review): gTTS synthesises speech through Google's online Translate
        TTS endpoint, so ``text`` leaves this machine. Reconcile this with the
        "patient data stays local" intent of text translation.

        Returns:
            Path to a temporary ``.mp3`` file that the caller owns and should
            delete, or ``None`` on any failure (the error is printed).
        """
        lang_code = self.LANG_MAP.get(lang_name, "en")
        temp_path = None
        try:
            tts = gTTS(text=text, lang=lang_code)
            # Only reserve a unique file name. Closing the handle immediately
            # avoids leaking a descriptor and lets gTTS reopen the path for
            # writing, which an open handle prevents on Windows.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
                temp_path = temp_file.name
            tts.save(temp_path)
            return temp_path
        except Exception as e:
            print(f"TTS Error: {e}")
            if temp_path is not None:
                # Don't leave an empty or partial mp3 behind on failure.
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
            return None

    def _run_inference(self, messages, max_new_tokens=512):
        """Run greedy generation on a chat-formatted request and return the decoded reply.

        Args:
            messages: Chat messages, as produced by ``_build_message``.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The generated text, without the prompt tokens and stripped of
            surrounding whitespace.
        """
        self._load_model()
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        # Move every tensor to the model's device, and cast floating-point
        # tensors (e.g. pixel_values) to the model's dtype. Otherwise float32
        # image features would be fed to a float16 model. BatchFeature.to
        # leaves integer tensors such as input_ids uncast.
        inputs = inputs.to(self.model.device, dtype=self.model.dtype)
        # Greedy decoding for stability in a medical context.
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Decode only the newly generated tokens, skipping the prompt.
        input_len = inputs["input_ids"].shape[-1]
        decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
        return decoded.strip()
# Smoke test when executed as a script: one English -> Polish text translation.
if __name__ == "__main__":
    demo = CareBridgeTranslator()
    result = demo.translate_text("Where does it hurt?", "English", "Polish")
    print("Test 1 (Text):", result)