File size: 6,197 Bytes
9e0d577
 
 
 
 
 
 
 
 
 
8a50211
 
 
 
b3d2e15
8a50211
 
 
 
 
b3d2e15
 
8a50211
9e0d577
8a50211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d006165
 
 
 
 
 
 
 
 
 
 
b3d2e15
8a50211
 
 
b3d2e15
 
8a50211
b3d2e15
8a50211
 
 
 
 
 
 
 
 
 
 
 
9e0d577
 
 
8a50211
 
 
b3d2e15
9e0d577
8a50211
9e0d577
8a50211
 
 
9e0d577
b3d2e15
 
8a50211
 
 
 
 
 
 
 
 
 
 
 
9e0d577
 
 
8a50211
 
 
b3d2e15
9e0d577
8a50211
9e0d577
b3d2e15
8a50211
 
9e0d577
8a50211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e0d577
 
 
8a50211
 
 
b3d2e15
9e0d577
b3d2e15
9e0d577
 
 
b3d2e15
9e0d577
 
 
 
 
b3d2e15
8a50211
 
 
 
 
 
 
 
 
 
9e0d577
b3d2e15
8a50211
 
9e0d577
 
8a50211
 
 
 
 
d006165
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import spaces
import librosa
from gtts import gTTS
import tempfile

class CareBridgeTranslator:
    """Translate text, images, audio and video between supported languages.

    Inference runs on a locally loaded Hugging Face model. The processor and
    model are loaded lazily on the first inference call, so constructing this
    object is cheap (useful on ZeroGPU Spaces, where a GPU is only attached
    inside ``@spaces.GPU`` functions).

    Every ``translate_*`` method takes human-readable language names (keys of
    ``LANG_MAP``). It returns the translated string, or a string starting with
    ``"Error:"`` when either language is not supported.
    """

    def __init__(self, model_id="google/translategemma-4b-it", device=None):
        """
        Initialize the CareBridge Translator with lazy loading for ZeroGPU compatibility.

        Args:
            model_id: Hugging Face model repository id, loaded on first use.
            device: Torch device string. Defaults to "cuda" when CUDA is
                available, otherwise "cpu".
        """
        self.model_id = model_id
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Populated by _load_model() on the first inference call.
        self.model = None
        self.processor = None
        print("[SIMBOTI] Translator initialized. Model will load on first use.")

        # English plus ten patient languages (the "top 10 NHS languages"),
        # mapped to ISO 639-1 codes. The same codes go into the model prompt
        # and to gTTS in speak_text().
        self.LANG_MAP = {
            "English": "en",
            "Polish": "pl",
            "Romanian": "ro",
            "Punjabi": "pa",
            "Urdu": "ur",
            "Portuguese": "pt",
            "Spanish": "es",
            "Arabic": "ar",
            "Bengali": "bn",
            "Gujarati": "gu",
            "Italian": "it",
        }

    # ------------------------------------------------------------------ #
    # Model loading
    # ------------------------------------------------------------------ #

    def _select_dtype(self):
        """Return the torch dtype used to load the model on ``self.device``."""
        if not str(self.device).startswith("cuda"):
            return torch.float32
        # NOTE(review): Gemma 3 checkpoints are distributed in bfloat16, and
        # float16 inference has been reported to overflow (NaN / garbage
        # output). Use bf16 where the GPU supports it; fall back to fp16
        # otherwise.
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    def _load_model(self):
        """Load the processor and model once; later calls do nothing."""
        if self.model is not None:
            return
        print(f"[SIMBOTI] Loading model {self.model_id}...")
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            self.model_id,
            device_map=self.device,
            torch_dtype=self._select_dtype(),
        )
        print("[SIMBOTI] Model loaded successfully.")

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    def _resolve_lang_codes(self, source_lang_name, target_lang_name):
        """Map two language names to their ISO codes.

        Returns:
            ``(src_code, tgt_code)``, or ``None`` if either name is not a key
            of ``LANG_MAP``.
        """
        src_code = self.LANG_MAP.get(source_lang_name)
        tgt_code = self.LANG_MAP.get(target_lang_name)
        if not src_code or not tgt_code:
            return None
        return src_code, tgt_code

    @staticmethod
    def _build_message(content_type, src_code, tgt_code, payload):
        """Build a single user chat message for ``_run_inference``.

        The payload is stored under a key equal to ``content_type``
        ("text", "image", "audio" or "video"), next to the language codes.
        """
        return {
            "role": "user",
            "content": [{
                "type": content_type,
                "source_lang_code": src_code,
                "target_lang_code": tgt_code,
                content_type: payload,
            }],
        }

    @staticmethod
    def _load_image(image_source):
        """Return an RGB PIL image from a file path or an in-memory PIL image.

        Other objects (e.g. arrays) are returned unchanged.
        """
        if isinstance(image_source, (str, os.PathLike)):
            # PIL opens files lazily and keeps the handle open; convert() forces
            # the pixel data to load so the file can be closed right away.
            with Image.open(image_source) as img:
                return img.convert("RGB")
        if isinstance(image_source, Image.Image):
            # Normalise RGBA / palette / greyscale input to 3-channel RGB.
            return image_source.convert("RGB")
        return image_source

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def translate_text(self, text, source_lang_name, target_lang_name):
        """
        Translate text ensuring patient data stays local.

        Returns:
            The translated text, ``""`` for empty or whitespace-only input, or
            an error string listing the available languages.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"
        if not text or not text.strip():
            # Nothing to translate; skip the (GPU-bound) inference call.
            return ""

        src_code, tgt_code = codes
        message = self._build_message("text", src_code, tgt_code, text)
        return self._run_inference([message])

    def translate_image(self, image_path, source_lang_name, target_lang_name):
        """
        Extract and translate text from an image (e.g. instruction leaflet).

        Args:
            image_path: A file path (str or os.PathLike) or a PIL image.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."

        src_code, tgt_code = codes
        image = self._load_image(image_path)
        message = self._build_message("image", src_code, tgt_code, image)
        return self._run_inference([message])

    def translate_audio(self, audio_path, source_lang_name, target_lang_name):
        """
        Speech-to-text translation from an audio file.

        NOTE(review): the model is loaded with AutoModelForImageTextToText,
        which suggests image+text inputs only. Confirm that the checkpoint and
        its chat template accept "audio" content before relying on this path.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."

        src_code, tgt_code = codes
        # Resample to 16 kHz while loading. The returned sample rate is
        # therefore known and is discarded.
        audio, _ = librosa.load(audio_path, sr=16000)
        message = self._build_message("audio", src_code, tgt_code, audio)
        return self._run_inference([message])

    def translate_video(self, video_path, source_lang_name, target_lang_name):
        """
        Video OCR/translation; the video path is passed through to the processor.

        NOTE(review): as with audio, confirm that the loaded model and chat
        template accept "video" content.
        """
        codes = self._resolve_lang_codes(source_lang_name, target_lang_name)
        if codes is None:
            return "Error: Language not supported."

        src_code, tgt_code = codes
        message = self._build_message("video", src_code, tgt_code, video_path)
        return self._run_inference([message])

    def speak_text(self, text, lang_name):
        """
        Generate audio from translated text for the patient.

        Unknown language names fall back to English ("en").

        Returns:
            Path to a new ``.mp3`` file (the caller owns and should delete it),
            or ``None`` if speech synthesis failed.
        """
        lang_code = self.LANG_MAP.get(lang_name, "en")
        out_path = None
        try:
            tts = gTTS(text=text, lang=lang_code)
            # Create the file, then close our handle immediately. This avoids
            # leaking a descriptor, and gTTS must be able to reopen the file by
            # name (not allowed on Windows while it is still open here).
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as handle:
                out_path = handle.name
            tts.save(out_path)
            return out_path
        except Exception as e:
            # Best-effort feature: report the failure and return None rather
            # than breaking the caller.
            print(f"TTS Error: {e}")
            if out_path is not None:
                # Don't leave an empty or partial mp3 behind.
                try:
                    os.remove(out_path)
                except OSError:
                    pass
            return None

    @spaces.GPU()
    def _run_inference(self, messages, max_new_tokens=512):
        """Run greedy generation on chat ``messages`` and return the reply text.

        Args:
            messages: Chat messages for ``processor.apply_chat_template``.
            max_new_tokens: Upper bound on generated tokens (default 512).
        """
        self._load_model()

        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device)

        # Greedy decoding (no sampling) for stable, repeatable output in a
        # medical context.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, do_sample=False
            )

        # generate() returns prompt + completion; slice off the prompt tokens.
        input_len = inputs["input_ids"].shape[-1]
        decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
        return decoded.strip()

# Manual smoke test: running this module directly translates one sample phrase.
if __name__ == "__main__":
    demo_translator = CareBridgeTranslator()
    polish_reply = demo_translator.translate_text("Where does it hurt?", "English", "Polish")
    print("Test 1 (Text):", polish_reply)