NurseCitizenDeveloper committed on
Commit
8a50211
·
verified ·
1 Parent(s): c856552

Update carebridge_client.py

Browse files
Files changed (1) hide show
  1. carebridge_client.py +126 -12
carebridge_client.py CHANGED
@@ -8,50 +8,148 @@ from gtts import gTTS
8
  import tempfile
9
 
10
  class CareBridgeTranslator:
11
- def __init__(self, model_id="google/translategemma-4b-it"):
 
 
 
12
  self.model_id = model_id
 
 
 
 
 
13
  self.model = None
14
  self.processor = None
15
- self.LANG_MAP = {"English": "en", "Polish": "pl", "Romanian": "ro", "Punjabi": "pa", "Urdu": "ur", "Portuguese": "pt", "Spanish": "es", "Arabic": "ar", "Bengali": "bn", "Gujarati": "gu", "Italian": "it"}
16
- print("[SIMBOTI] Translator initialized. Model will load on first use.")
17
 
18
  def _load_model(self):
19
  if self.model is None:
20
  print(f"[SIMBOTI] Loading model {self.model_id}...")
21
  self.processor = AutoProcessor.from_pretrained(self.model_id)
22
- self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, device_map="cuda", torch_dtype=torch.float16)
 
 
 
 
23
  print("[SIMBOTI] Model loaded successfully.")
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def translate_text(self, text, source_lang_name, target_lang_name):
 
 
 
26
  src_code = self.LANG_MAP.get(source_lang_name)
27
  tgt_code = self.LANG_MAP.get(target_lang_name)
 
28
  if not src_code or not tgt_code:
29
- return "Error: Language not supported."
30
- message = {"role": "user", "content": [{"type": "text", "source_lang_code": src_code, "target_lang_code": tgt_code, "text": text}]}
 
 
 
 
 
 
 
 
 
 
31
  return self._run_inference([message])
32
 
33
  def translate_image(self, image_path, source_lang_name, target_lang_name):
 
 
 
34
  src_code = self.LANG_MAP.get(source_lang_name)
35
  tgt_code = self.LANG_MAP.get(target_lang_name)
 
36
  if not src_code or not tgt_code:
37
- return "Error: Language not supported."
 
 
38
  if isinstance(image_path, str):
39
  image = Image.open(image_path)
40
  else:
41
- image = image_path
42
- message = {"role": "user", "content": [{"type": "image", "source_lang_code": src_code, "target_lang_code": tgt_code, "image": image}]}
 
 
 
 
 
 
 
 
 
 
43
  return self._run_inference([message])
44
 
45
  def translate_audio(self, audio_path, source_lang_name, target_lang_name):
 
 
 
46
  src_code = self.LANG_MAP.get(source_lang_name)
47
  tgt_code = self.LANG_MAP.get(target_lang_name)
 
48
  if not src_code or not tgt_code:
49
  return "Error: Language not supported."
 
 
50
  audio, sr = librosa.load(audio_path, sr=16000)
51
- message = {"role": "user", "content": [{"type": "audio", "source_lang_code": src_code, "target_lang_code": tgt_code, "audio": audio}]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  return self._run_inference([message])
53
 
54
  def speak_text(self, text, lang_name):
 
 
 
55
  lang_code = self.LANG_MAP.get(lang_name, "en")
56
  try:
57
  tts = gTTS(text=text, lang=lang_code)
@@ -65,9 +163,25 @@ class CareBridgeTranslator:
65
  @spaces.GPU()
66
  def _run_inference(self, messages):
67
  self._load_model()
68
- inputs = self.processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to("cuda")
 
 
 
 
 
 
 
 
 
69
  with torch.no_grad():
70
  outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=False)
 
 
71
  input_len = inputs["input_ids"].shape[-1]
72
  decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
73
- return decoded.strip()
 
 
 
 
 
 
8
  import tempfile
9
 
10
  class CareBridgeTranslator:
11
+ def __init__(self, model_id="google/translategemma-4b-it", device=None):
12
+ """
13
+ Initialize the CareBridge Translator with lazy loading for ZeroGPU compatibility.
14
+ """
15
  self.model_id = model_id
16
+ if device is None:
17
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ else:
19
+ self.device = device
20
+
21
  self.model = None
22
  self.processor = None
23
+ print(f"[SIMBOTI] Translator initialized. Model will load on first use.")
 
24
 
25
  def _load_model(self):
26
  if self.model is None:
27
  print(f"[SIMBOTI] Loading model {self.model_id}...")
28
  self.processor = AutoProcessor.from_pretrained(self.model_id)
29
+ self.model = AutoModelForImageTextToText.from_pretrained(
30
+ self.model_id,
31
+ device_map=self.device,
32
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
33
+ )
34
  print("[SIMBOTI] Model loaded successfully.")
35
 
36
+ # Top 10 NHS Languages Mapping (ISO 639-1)
37
+ self.LANG_MAP = {
38
+
39
+ "English": "en",
40
+ "Polish": "pl",
41
+ "Romanian": "ro",
42
+ "Punjabi": "pa",
43
+ "Urdu": "ur",
44
+ "Portuguese": "pt",
45
+ "Spanish": "es",
46
+ "Arabic": "ar",
47
+ "Bengali": "bn",
48
+ "Gujarati": "gu",
49
+ "Italian": "it"
50
+ }
51
+
52
  def translate_text(self, text, source_lang_name, target_lang_name):
53
+ """
54
+ Translate text ensuring patient data stays local.
55
+ """
56
  src_code = self.LANG_MAP.get(source_lang_name)
57
  tgt_code = self.LANG_MAP.get(target_lang_name)
58
+
59
  if not src_code or not tgt_code:
60
+ return f"Error: Language not supported. Available: {list(self.LANG_MAP.keys())}"
61
+
62
+ message = {
63
+ "role": "user",
64
+ "content": [{
65
+ "type": "text",
66
+ "source_lang_code": src_code,
67
+ "target_lang_code": tgt_code,
68
+ "text": text
69
+ }]
70
+ }
71
+
72
  return self._run_inference([message])
73
 
74
  def translate_image(self, image_path, source_lang_name, target_lang_name):
75
+ """
76
+ Extract and translate text from an image (e.g. instruction leaflet).
77
+ """
78
  src_code = self.LANG_MAP.get(source_lang_name)
79
  tgt_code = self.LANG_MAP.get(target_lang_name)
80
+
81
  if not src_code or not tgt_code:
82
+ return f"Error: Language not supported."
83
+
84
+ # Load image
85
  if isinstance(image_path, str):
86
  image = Image.open(image_path)
87
  else:
88
+ image = image_path # Assume PIL object
89
+
90
+ message = {
91
+ "role": "user",
92
+ "content": [{
93
+ "type": "image",
94
+ "source_lang_code": src_code,
95
+ "target_lang_code": tgt_code,
96
+ "image": image
97
+ }]
98
+ }
99
+
100
  return self._run_inference([message])
101
 
102
  def translate_audio(self, audio_path, source_lang_name, target_lang_name):
103
+ """
104
+ Speech-to-Text Translation using Gemma 3 native audio support.
105
+ """
106
  src_code = self.LANG_MAP.get(source_lang_name)
107
  tgt_code = self.LANG_MAP.get(target_lang_name)
108
+
109
  if not src_code or not tgt_code:
110
  return "Error: Language not supported."
111
+
112
+ # Load audio using librosa (Gemma 3 expects 16kHz usually)
113
  audio, sr = librosa.load(audio_path, sr=16000)
114
+
115
+ message = {
116
+ "role": "user",
117
+ "content": [{
118
+ "type": "audio",
119
+ "source_lang_code": src_code,
120
+ "target_lang_code": tgt_code,
121
+ "audio": audio
122
+ }]
123
+ }
124
+
125
+ return self._run_inference([message])
126
+
127
+ def translate_video(self, video_path, source_lang_name, target_lang_name):
128
+ """
129
+ Video OCR/Translation using Gemma 3 native video support.
130
+ """
131
+ src_code = self.LANG_MAP.get(source_lang_name)
132
+ tgt_code = self.LANG_MAP.get(target_lang_name)
133
+
134
+ if not src_code or not tgt_code:
135
+ return "Error: Language not supported."
136
+
137
+ message = {
138
+ "role": "user",
139
+ "content": [{
140
+ "type": "video",
141
+ "source_lang_code": src_code,
142
+ "target_lang_code": tgt_code,
143
+ "video": video_path
144
+ }]
145
+ }
146
+
147
  return self._run_inference([message])
148
 
149
  def speak_text(self, text, lang_name):
150
+ """
151
+ Generate audio from translated text for the patient.
152
+ """
153
  lang_code = self.LANG_MAP.get(lang_name, "en")
154
  try:
155
  tts = gTTS(text=text, lang=lang_code)
 
163
  @spaces.GPU()
164
  def _run_inference(self, messages):
165
  self._load_model()
166
+
167
+ inputs = self.processor.apply_chat_template(
168
+ messages,
169
+ tokenize=True,
170
+ add_generation_prompt=True,
171
+ return_dict=True,
172
+ return_tensors="pt"
173
+ ).to(self.device)
174
+
175
+ # Generate (Greedy for stability in medical context)
176
  with torch.no_grad():
177
  outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=False)
178
+
179
+ # Decode response (Skipping input tokens)
180
  input_len = inputs["input_ids"].shape[-1]
181
  decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
182
+ return decoded.strip()
183
+
184
# Smoke test when the file is executed directly: translates one English
# sentence into Polish, which triggers the model load on first use.
if __name__ == "__main__":
    demo_translator = CareBridgeTranslator()
    polish_text = demo_translator.translate_text("Where does it hurt?", "English", "Polish")
    print("Test 1 (Text):", polish_text)