# SIMBOTI-Live / carebridge_client.py
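"""CareBridge translation client for SIMBOTI-Live.

Wraps a Gemma-based translation model to translate text, images, audio and
video between English and the top 10 NHS languages, with gTTS playback so
translations can be spoken back to the patient. Built for Hugging Face
Spaces (ZeroGPU) but also runs locally.
"""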
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import librosa
from gtts import gTTS
import tempfile
# Hugging Face Spaces GPU decorator (optional for local development)
try:
import spaces
except ImportError:
# Fallback: no-op decorator for local testing
class spaces:
@staticmethod
def GPU(*args, **kwargs):
def decorator(func):
return func
return decorator
class CareBridgeTranslator:
def __init__(self, model_id="google/translategemma-4b-it", device=None):
"""
Initialize the CareBridge Translator with lazy loading for ZeroGPU compatibility.
"""
self.model_id = model_id
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.model = None
self.processor = None
        print("[SIMBOTI] Translator initialized. Model will load on first use.")
        # English plus the top 10 NHS languages (ISO 639-1 codes)
self.LANG_MAP = {
"English": "en",
"Polish": "pl",
"Romanian": "ro",
"Punjabi": "pa",
"Urdu": "ur",
"Portuguese": "pt",
"Spanish": "es",
"Arabic": "ar",
"Bengali": "bn",
"Gujarati": "gu",
"Italian": "it"
}
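    # On ZeroGPU Spaces, CUDA is only guaranteed inside @spaces.GPU-decorated
    # calls, so the model is loaded lazily from _run_inference rather than in
    # __init__.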
def _load_model(self):
if self.model is None:
print(f"[SIMBOTI] Loading model {self.model_id}...")
self.processor = AutoProcessor.from_pretrained(self.model_id)
self.model = AutoModelForImageTextToText.from_pretrained(
self.model_id,
device_map=self.device,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
)
print("[SIMBOTI] Model loaded successfully.")
def translate_text(self, text, source_lang_name, target_lang_name):
"""
Translate text ensuring patient data stays local.
"""
src_code = self.LANG_MAP.get(source_lang_name)
tgt_code = self.LANG_MAP.get(target_lang_name)
if not src_code or not tgt_code:
return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"
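        # Build a chat message in the model's translation format: each content
        # item carries explicit source/target language codes alongside the payload.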
message = {
"role": "user",
"content": [{
"type": "text",
"source_lang_code": src_code,
"target_lang_code": tgt_code,
"text": text
}]
}
return self._run_inference([message])
def translate_image(self, image_path, source_lang_name, target_lang_name):
"""
Extract and translate text from an image (e.g. instruction leaflet).
"""
src_code = self.LANG_MAP.get(source_lang_name)
tgt_code = self.LANG_MAP.get(target_lang_name)
if not src_code or not tgt_code:
            return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"
        # Load image and normalise to RGB (the processor expects 3-channel input)
        if isinstance(image_path, str):
            image = Image.open(image_path).convert("RGB")
        else:
            image = image_path  # Assume a PIL Image object
message = {
"role": "user",
"content": [{
"type": "image",
"source_lang_code": src_code,
"target_lang_code": tgt_code,
"image": image
}]
}
return self._run_inference([message])
def translate_audio(self, audio_path, source_lang_name, target_lang_name):
"""
Speech-to-Text Translation using Gemma 3 native audio support.
"""
src_code = self.LANG_MAP.get(source_lang_name)
tgt_code = self.LANG_MAP.get(target_lang_name)
if not src_code or not tgt_code:
            return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"
        # Load audio with librosa, resampling to the 16 kHz the model expects
        audio, sr = librosa.load(audio_path, sr=16000)
message = {
"role": "user",
"content": [{
"type": "audio",
"source_lang_code": src_code,
"target_lang_code": tgt_code,
"audio": audio
}]
}
return self._run_inference([message])
def translate_video(self, video_path, source_lang_name, target_lang_name):
"""
Video OCR/Translation using Gemma 3 native video support.
"""
src_code = self.LANG_MAP.get(source_lang_name)
tgt_code = self.LANG_MAP.get(target_lang_name)
if not src_code or not tgt_code:
            return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"
message = {
"role": "user",
"content": [{
"type": "video",
"source_lang_code": src_code,
"target_lang_code": tgt_code,
"video": video_path
}]
}
return self._run_inference([message])
def speak_text(self, text, lang_name):
"""
Generate audio from translated text for the patient.
"""
lang_code = self.LANG_MAP.get(lang_name, "en")
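        # NOTE: gTTS calls Google's remote TTS service, so unlike the local
        # model inference this step sends the translated text off-device.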
        try:
            tts = gTTS(text=text, lang=lang_code)
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
            temp_file.close()  # Release the handle so gTTS can write to the path (required on Windows)
            tts.save(temp_file.name)
            return temp_file.name
except Exception as e:
print(f"TTS Error: {e}")
return None
@spaces.GPU()
def _run_inference(self, messages):
self._load_model()
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)
        # Cast floating-point inputs (e.g. pixel values) to the model dtype so
        # fp32 tensors don't hit fp16 weights on CUDA; integer ids are left as-is.
        if self.device == "cuda":
            inputs = inputs.to(torch.float16)
        # Greedy decoding for deterministic output in a medical context
with torch.no_grad():
outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=False)
# Decode response (Skipping input tokens)
input_len = inputs["input_ids"].shape[-1]
decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
return decoded.strip()
# Simple verification test when run directly
if __name__ == "__main__":
translator = CareBridgeTranslator()
print("Test 1 (Text):", translator.translate_text("Where does it hurt?", "English", "Polish"))
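    # Further smoke tests (file paths below are placeholders; substitute real files):
    # print("Test 2 (Image):", translator.translate_image("leaflet.png", "English", "Urdu"))
    # print("Test 3 (Audio):", translator.translate_audio("question.wav", "Polish", "English"))
    # print("Test 4 (TTS):", translator.speak_text("Gdzie boli?", "Polish"))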