NurseCitizenDeveloper committed on
Commit
b3d2e15
·
verified ·
1 Parent(s): 6c3dfac

Update carebridge_client.py

Browse files
Files changed (1) hide show
  1. carebridge_client.py +30 -30
carebridge_client.py CHANGED
@@ -8,66 +8,66 @@ from gtts import gTTS
8
  import tempfile
9
 
10
  class CareBridgeTranslator:
11
- def __init__(self, model_id="google/translategemma-4b-it"):
12
- self.model_id = model_id
13
- self.model = None
14
- self.processor = None
15
- self.LANG_MAP = {"English": "en", "Polish": "pl", "Romanian": "ro", "Punjabi": "pa", "Urdu": "ur", "Portuguese": "pt", "Spanish": "es", "Arabic": "ar", "Bengali": "bn", "Gujarati": "gu", "Italian": "it"}
16
- print("[SIMBOTI] Translator initialized. Model will load on first use.")
17
 
18
- def _load_model(self):
19
- if self.model is None:
20
- print(f"[SIMBOTI] Loading model {self.model_id}...")
21
- self.processor = AutoProcessor.from_pretrained(self.model_id)
22
- self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, device_map="cuda", torch_dtype=torch.float16)
23
- print("[SIMBOTI] Model loaded successfully.")
24
 
25
- def translate_text(self, text, source_lang_name, target_lang_name):
26
- src_code = self.LANG_MAP.get(source_lang_name)
27
- tgt_code = self.LANG_MAP.get(target_lang_name)
28
- if not src_code or not tgt_code:
29
- return "Error: Language not supported."
30
- message = {"role": "user", "content": [{"type": "text", "source_lang_code": src_code, "target_lang_code": tgt_code, "text": text}]}
31
  return self._run_inference([message])
32
 
33
  def translate_image(self, image_path, source_lang_name, target_lang_name):
34
- src_code = self.LANG_MAP.get(source_lang_name)
35
  tgt_code = self.LANG_MAP.get(target_lang_name)
36
  if not src_code or not tgt_code:
37
- return "Error: Language not supported."
38
  if isinstance(image_path, str):
39
- image = Image.open(image_path)
40
- else:
41
  image = image_path
42
  message = {"role": "user", "content": [{"type": "image", "source_lang_code": src_code, "target_lang_code": tgt_code, "image": image}]}
43
  return self._run_inference([message])
44
 
45
  def translate_audio(self, audio_path, source_lang_name, target_lang_name):
46
- src_code = self.LANG_MAP.get(source_lang_name)
47
  tgt_code = self.LANG_MAP.get(target_lang_name)
48
  if not src_code or not tgt_code:
49
- return "Error: Language not supported."
50
  audio, sr = librosa.load(audio_path, sr=16000)
51
  message = {"role": "user", "content": [{"type": "audio", "source_lang_code": src_code, "target_lang_code": tgt_code, "audio": audio}]}
52
  return self._run_inference([message])
53
 
54
  def speak_text(self, text, lang_name):
55
- lang_code = self.LANG_MAP.get(lang_name, "en")
56
  try:
57
- tts = gTTS(text=text, lang=lang_code)
58
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
59
  tts.save(temp_file.name)
60
  return temp_file.name
61
- except Exception as e:
62
  print(f"TTS Error: {e}")
63
  return None
64
 
65
  @spaces.GPU()
66
  def _run_inference(self, messages):
67
- self._load_model()
68
  inputs = self.processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to("cuda")
69
  with torch.no_grad():
70
- outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=False)
71
  input_len = inputs["input_ids"].shape[-1]
72
  decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
73
- return decoded.strip()
 
8
  import tempfile
9
 
10
  class CareBridgeTranslator:
11
+ def __init__(self, model_id="google/translategemma-4b-it"):
12
+ self.model_id = model_id
13
+ self.model = None
14
+ self.processor = None
15
+ self.LANG_MAP = {"English": "en", "Polish": "pl", "Romanian": "ro", "Punjabi": "pa", "Urdu": "ur", "Portuguese": "pt", "Spanish": "es", "Arabic": "ar", "Bengali": "bn", "Gujarati": "gu", "Italian": "it"}
16
+ print("[SIMBOTI] Translator initialized. Model will load on first use.")
17
 
18
+ def _load_model(self):
19
+ if self.model is None:
20
+ print(f"[SIMBOTI] Loading model ${self.model_id}...")
21
+ self.processor = AutoProcessor.from_pretrained(self.model_id)
22
+ self.model = AutoModelForImageTextToText.from_pretrained(self.model_id, device_map="cuda", torch_dtype=torch.float16)
23
+ print("[SIMBOTI] Model loaded successfully.")
24
 
25
+ def translate_text(self, text, source_lang_name, target_lang_name):
26
+ src_code = self.LANG_MAP.get(source_lang_name)
27
+ tgt_code = self.LANG_MAP.get(target_lang_name)
28
+ if not src_code or not tgt_code:
29
+ return "Error: Language not supported."
30
+ message = {"role": "user", "content": [{"type": "text", "source_lang_code": src_code, "target_lang_code": tgt_code, "text": text}]}
31
  return self._run_inference([message])
32
 
33
  def translate_image(self, image_path, source_lang_name, target_lang_name):
34
+ src_code = self.LANG_MAP.get(source_lang_name)
35
  tgt_code = self.LANG_MAP.get(target_lang_name)
36
  if not src_code or not tgt_code:
37
+ return "Error: Language not supported."
38
  if isinstance(image_path, str):
39
+ image = Image.open(image_path)
40
+ else:
41
  image = image_path
42
  message = {"role": "user", "content": [{"type": "image", "source_lang_code": src_code, "target_lang_code": tgt_code, "image": image}]}
43
  return self._run_inference([message])
44
 
45
  def translate_audio(self, audio_path, source_lang_name, target_lang_name):
46
+ src_code = self.LANG_MAP.get(source_lang_name)
47
  tgt_code = self.LANG_MAP.get(target_lang_name)
48
  if not src_code or not tgt_code:
49
+ return "Error: Language not supported."
50
  audio, sr = librosa.load(audio_path, sr=16000)
51
  message = {"role": "user", "content": [{"type": "audio", "source_lang_code": src_code, "target_lang_code": tgt_code, "audio": audio}]}
52
  return self._run_inference([message])
53
 
54
  def speak_text(self, text, lang_name):
55
+ lang_code = self.LANG_MAP.get(lang_name, "en")
56
  try:
57
+ tts = gTTS(text=text, lang=lang_code)
58
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
59
  tts.save(temp_file.name)
60
  return temp_file.name
61
+ except Exception as e:
62
  print(f"TTS Error: {e}")
63
  return None
64
 
65
  @spaces.GPU()
66
  def _run_inference(self, messages):
67
+ self._load_model()
68
  inputs = self.processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to("cuda")
69
  with torch.no_grad():
70
+ outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=False)
71
  input_len = inputs["input_ids"].shape[-1]
72
  decoded = self.processor.decode(outputs[0][input_len:], skip_special_tokens=True)
73
+ return decoded.strip()