NurseCitizenDeveloper committed on
Commit
6c3dfac
·
verified ·
1 Parent(s): 9e0d577

Update carebridge_client.py

Browse files
Files changed (1) hide show
  1. carebridge_client.py +31 -28
carebridge_client.py CHANGED
@@ -8,63 +8,66 @@ from gtts import gTTS
8
  import tempfile
9
 
10
  class CareBridgeTranslator:
11
- def __init__(self, model_id="google/translategemma-4b-it", device=None):
12
- self.model_id = model_id
13
- if device is None:
14
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
- else:
16
- self.device = device
17
- print(f"[SIMBOTI] Loading model {model_id}...")
18
- self.processor = AutoProcessor.from_pretrained(model_id)
19
- self.model = AutoModelForImageTextToText.from_pretrained(model_id, device_map=self.device, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32)
20
- print("[SIMBOTI] Model loaded successfully.")
21
- self.LANG_MAP = {"English": "en", "Polish": "pl", "Romanian": "ro", "Punjabi": "pa", "Urdu": "ur", "Portuguese": "pt", "Spanish": "es", "Arabic": "ar", "Bengali": "bn", "Gujarati": "gu", "Italian": "it"}
22
 
23
- def translate_text(self, text, source_lang_name, target_lang_name):
24
- src_code = self.LANG_MAP.get(source_lang_name)
25
- tgt_code = self.LANG_MAP.get(target_lang_name)
26
- if not src_code or not tgt_code:
27
- return f"Error: Language not supported."
28
- message = {"role": "user", "content": [{"type": "text", "source_lang_code": src_code, "target_lang_code": tgt_code, "text": text}]}
 
 
 
 
 
 
 
29
  return self._run_inference([message])
30
 
31
  def translate_image(self, image_path, source_lang_name, target_lang_name):
32
- src_code = self.LANG_MAP.get(source_lang_name)
33
  tgt_code = self.LANG_MAP.get(target_lang_name)
34
  if not src_code or not tgt_code:
35
- return "Error: Language not supported."
36
  if isinstance(image_path, str):
37
- image = Image.open(image_path)
38
- else:
39
  image = image_path
40
  message = {"role": "user", "content": [{"type": "image", "source_lang_code": src_code, "target_lang_code": tgt_code, "image": image}]}
41
  return self._run_inference([message])
42
 
43
  def translate_audio(self, audio_path, source_lang_name, target_lang_name):
44
- src_code = self.LANG_MAP.get(source_lang_name)
45
  tgt_code = self.LANG_MAP.get(target_lang_name)
46
  if not src_code or not tgt_code:
47
- return "Error: Language not supported."
48
  audio, sr = librosa.load(audio_path, sr=16000)
49
  message = {"role": "user", "content": [{"type": "audio", "source_lang_code": src_code, "target_lang_code": tgt_code, "audio": audio}]}
50
  return self._run_inference([message])
51
 
52
  def speak_text(self, text, lang_name):
53
- lang_code = self.LANG_MAP.get(lang_name, "en")
54
  try:
55
- tts = gTTS(text=text, lang=lang_code)
56
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
57
  tts.save(temp_file.name)
58
  return temp_file.name
59
- except Exception as e:
60
  print(f"TTS Error: {e}")
61
  return None
62
 
63
  @spaces.GPU()
64
  def _run_inference(self, messages):
65
- inputs = self.processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.device)
 
66
  with torch.no_grad():
67
- outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=False)
68
  input_len = inputs["input_ids"].shape[-1]
69
  decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
70
  return decoded.strip()
 
8
  import tempfile
9
 
10
  class CareBridgeTranslator:
11
+ def __init__(self, model_id="google/translategemma-4b-it"):
12
+ self.model_id = model_id
13
+ self.model = None
14
+ self.processor = None
15
+ self.LANG_MAP = {"English": "en", "Polish": "pl", "Romanian": "ro", "Punjabi": "pa", "Urdu": "ur", "Portuguese": "pt", "Spanish": "es", "Arabic": "ar", "Bengali": "bn", "Gujarati": "gu", "Italian": "it"}
16
+ print("[SIMBOTI] Translator initialized. Model will load on first use.")
 
 
 
 
 
17
 
18
+ def _load_model(self):
19
+ if self.model is None:
20
+ print(f"[SIMBOTI] Loading model {self.model_id}...")
21
+ self.processor = AutoProcessor.from_pretrained(self.model_id)
22
+ self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, device_map="cuda", torch_dtype=torch.float16)
23
+ print("[SIMBOTI] Model loaded successfully.")
24
+
25
+ def translate_text(self, text, source_lang_name, target_lang_name):
26
+ src_code = self.LANG_MAP.get(source_lang_name)
27
+ tgt_code = self.LANG_MAP.get(target_lang_name)
28
+ if not src_code or not tgt_code:
29
+ return "Error: Language not supported."
30
+ message = {"role": "user", "content": [{"type": "text", "source_lang_code": src_code, "target_lang_code": tgt_code, "text": text}]}
31
  return self._run_inference([message])
32
 
33
  def translate_image(self, image_path, source_lang_name, target_lang_name):
34
+ src_code = self.LANG_MAP.get(source_lang_name)
35
  tgt_code = self.LANG_MAP.get(target_lang_name)
36
  if not src_code or not tgt_code:
37
+ return "Error: Language not supported."
38
  if isinstance(image_path, str):
39
+ image = Image.open(image_path)
40
+ else:
41
  image = image_path
42
  message = {"role": "user", "content": [{"type": "image", "source_lang_code": src_code, "target_lang_code": tgt_code, "image": image}]}
43
  return self._run_inference([message])
44
 
45
  def translate_audio(self, audio_path, source_lang_name, target_lang_name):
46
+ src_code = self.LANG_MAP.get(source_lang_name)
47
  tgt_code = self.LANG_MAP.get(target_lang_name)
48
  if not src_code or not tgt_code:
49
+ return "Error: Language not supported."
50
  audio, sr = librosa.load(audio_path, sr=16000)
51
  message = {"role": "user", "content": [{"type": "audio", "source_lang_code": src_code, "target_lang_code": tgt_code, "audio": audio}]}
52
  return self._run_inference([message])
53
 
54
  def speak_text(self, text, lang_name):
55
+ lang_code = self.LANG_MAP.get(lang_name, "en")
56
  try:
57
+ tts = gTTS(text=text, lang=lang_code)
58
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
59
  tts.save(temp_file.name)
60
  return temp_file.name
61
+ except Exception as e:
62
  print(f"TTS Error: {e}")
63
  return None
64
 
65
  @spaces.GPU()
66
  def _run_inference(self, messages):
67
+ self._load_model()
68
+ inputs = self.processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to("cuda")
69
  with torch.no_grad():
70
+ outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=False)
71
  input_len = inputs["input_ids"].shape[-1]
72
  decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
73
  return decoded.strip()