drrobot9 commited on
Commit
8ca4d6a
·
1 Parent(s): 7e7d098

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +62 -62
app/agents/crew_pipeline.py CHANGED
@@ -61,7 +61,7 @@ def detect_language(text: str, top_k: int = 1):
61
 
62
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
63
 
64
-
65
  translation_tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)
66
  translation_model = AutoModelForSeq2SeqLM.from_pretrained(
67
  config.TRANSLATION_MODEL_NAME,
@@ -69,18 +69,18 @@ translation_model = AutoModelForSeq2SeqLM.from_pretrained(
69
  device_map="auto" if DEVICE == "cuda" else None
70
  )
71
 
 
 
72
 
73
  LANG_CODE_MAP = {
74
  "eng_Latn": "eng_Latn", # English
75
  "ibo_Latn": "ibo_Latn", # Igbo
76
  "yor_Latn": "yor_Latn", # Yoruba
77
  "hau_Latn": "hau_Latn", # Hausa
78
- "swh_Latn": "swh_Latn", # Swahili
79
- "amh_Latn": "amh_Latn", # Amharic
80
  }
81
 
82
-
83
-
84
  SUPPORTED_LANGS = {
85
  "eng_Latn": "English",
86
  "ibo_Latn": "Igbo",
@@ -113,44 +113,67 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
113
 
114
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
115
  """
116
- Translate text using the custom NLLB model directly
 
117
  """
 
 
 
118
  if not text.strip() or src_lang == tgt_lang:
 
119
  return text
120
 
121
  # Get language codes
122
  src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
123
  tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")
124
 
 
 
125
 
126
- if not hasattr(translation_tokenizer, 'lang_code_to_id'):
127
- print("Warning: Tokenizer doesn't have lang_code_to_id attribute")
128
- print(f"Available tokenizer special tokens: {translation_tokenizer.special_tokens_map}")
 
129
 
130
- return translate_text_simple(text, src_lang, tgt_lang, max_chunk_len)
 
131
 
132
 
133
- if src_code not in translation_tokenizer.lang_code_to_id:
134
- print(f"Warning: Source language code '{src_code}' not found in tokenizer")
135
- src_code = "eng_Latn"
136
-
137
- if tgt_code not in translation_tokenizer.lang_code_to_id:
138
- print(f"Warning: Target language code '{tgt_code}' not found in tokenizer")
139
- tgt_code = "eng_Latn"
140
-
141
 
142
- translation_tokenizer.src_lang = src_code
143
-
 
 
144
 
145
- forced_bos_token_id = translation_tokenizer.lang_code_to_id[tgt_code]
 
 
 
 
 
 
 
 
 
146
 
147
  chunks = chunk_text(text, max_len=max_chunk_len)
148
  translated_parts = []
149
 
150
- for chunk in chunks:
 
 
151
  try:
152
 
153
- inputs = translation_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
 
 
 
 
 
 
 
 
154
 
155
  if DEVICE == "cuda":
156
  inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
@@ -158,7 +181,6 @@ def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int =
158
  # Generate translation
159
  generated_tokens = translation_model.generate(
160
  **inputs,
161
- forced_bos_token_id=forced_bos_token_id,
162
  max_new_tokens=400,
163
  num_beams=4,
164
  early_stopping=True
@@ -170,49 +192,20 @@ def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int =
170
  skip_special_tokens=True
171
  )[0]
172
 
173
- translated_parts.append(result)
174
-
175
- except Exception as e:
176
- print(f"Translation error ({src_code}->{tgt_code}): {e}")
177
 
178
- translated_parts.append(chunk)
179
-
180
- return " ".join(translated_parts).strip()
181
-
182
- def translate_text_simple(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
183
- """
184
- Simple fallback translation function if the main one fails
185
- """
186
- if not text.strip() or src_lang == tgt_lang:
187
- return text
188
-
189
- chunks = chunk_text(text, max_len=max_chunk_len)
190
- translated_parts = []
191
-
192
- for chunk in chunks:
193
- try:
194
-
195
- inputs = translation_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
196
 
197
- if DEVICE == "cuda":
198
- inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
199
 
200
- generated_tokens = translation_model.generate(
201
- **inputs,
202
- max_new_tokens=400
203
- )
204
-
205
- result = translation_tokenizer.batch_decode(
206
- generated_tokens,
207
- skip_special_tokens=True
208
- )[0]
209
-
210
- translated_parts.append(result)
211
  except Exception as e:
212
- print(f"Simple translation error: {e}")
213
- translated_parts.append(chunk)
214
 
215
- return " ".join(translated_parts).strip()
 
 
216
 
217
  # RAG retrieval
218
  def retrieve_docs(query: str, vs_path: str):
@@ -307,6 +300,8 @@ def run_pipeline(user_query: str, session_id: str = None):
307
  lang_label, prob = detect_language(user_query, top_k=1)[0]
308
  if lang_label not in SUPPORTED_LANGS:
309
  lang_label = "eng_Latn"
 
 
310
 
311
 
312
  translated_query = (
@@ -314,6 +309,8 @@ def run_pipeline(user_query: str, session_id: str = None):
314
  if lang_label != "eng_Latn"
315
  else user_query
316
  )
 
 
317
 
318
  intent, extra = detect_intent(translated_query)
319
 
@@ -363,6 +360,9 @@ def run_pipeline(user_query: str, session_id: str = None):
363
  else english_answer
364
  )
365
  final_answer = strip_markdown(final_answer)
 
 
 
366
  return {
367
  "session_id": session_id,
368
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),
 
61
 
62
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
63
 
64
+ # Load tokenizer and model
65
  translation_tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)
66
  translation_model = AutoModelForSeq2SeqLM.from_pretrained(
67
  config.TRANSLATION_MODEL_NAME,
 
69
  device_map="auto" if DEVICE == "cuda" else None
70
  )
71
 
72
+ print(" Translation model loaded successfully")
73
+
74
 
75
  LANG_CODE_MAP = {
76
  "eng_Latn": "eng_Latn", # English
77
  "ibo_Latn": "ibo_Latn", # Igbo
78
  "yor_Latn": "yor_Latn", # Yoruba
79
  "hau_Latn": "hau_Latn", # Hausa
80
+ "swh_Latn": "swa_Latn", # Swahili
81
+ "amh_Latn": "amh_Ethi", # Amharic
82
  }
83
 
 
 
84
  SUPPORTED_LANGS = {
85
  "eng_Latn": "English",
86
  "ibo_Latn": "Igbo",
 
113
 
114
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
115
  """
116
+ Translate text using
117
+ IMPORTANT: Model expects format "src_lang text" -> "tgt_lang translation"
118
  """
119
+ print(f"\n[TRANSLATION] {src_lang} → {tgt_lang}")
120
+ print(f" Input: {text[:100]}...")
121
+
122
  if not text.strip() or src_lang == tgt_lang:
123
+ print(" No translation needed (same language)")
124
  return text
125
 
126
  # Get language codes
127
  src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
128
  tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")
129
 
130
+ print(f" Using codes: {src_code} → {tgt_code}")
131
+
132
 
133
+ if src_code != "eng_Latn" and tgt_code != "eng_Latn":
134
+ print(f" WARNING: Model wasn't trained on {src_code}→{tgt_code}")
135
+ print(f" Will translate {src_code}→eng_Latn→{tgt_code}")
136
+
137
 
138
+ to_english = translate_text_single(text, src_code, "eng_Latn", max_chunk_len)
139
+ return translate_text_single(to_english, "eng_Latn", tgt_code, max_chunk_len)
140
 
141
 
142
+ return translate_text_single(text, src_code, tgt_code, max_chunk_len)
 
 
 
 
 
 
 
143
 
144
+ def translate_text_single(text: str, src_code: str, tgt_code: str, max_chunk_len: int = 400) -> str:
145
+ """
146
+ Perform single translation step
147
+ """
148
 
149
+ supported_pairs = [
150
+ ("eng_Latn", "ibo_Latn"), ("ibo_Latn", "eng_Latn"),
151
+ ("eng_Latn", "yor_Latn"), ("yor_Latn", "eng_Latn"),
152
+ ("eng_Latn", "hau_Latn"), ("hau_Latn", "eng_Latn"),
153
+ ("eng_Latn", "swa_Latn"), ("swa_Latn", "eng_Latn"),
154
+ ("eng_Latn", "amh_Ethi"), ("amh_Ethi", "eng_Latn"),
155
+ ]
156
+
157
+ if (src_code, tgt_code) not in supported_pairs:
158
+ print(f" WARNING: Pair {src_code}→{tgt_code} may not work well")
159
 
160
  chunks = chunk_text(text, max_len=max_chunk_len)
161
  translated_parts = []
162
 
163
+ for i, chunk in enumerate(chunks):
164
+ print(f" Chunk {i+1}/{len(chunks)}: '{chunk[:50]}...'")
165
+
166
  try:
167
 
168
+ input_text = f"{src_code} {chunk}"
169
+
170
+ # Tokenize
171
+ inputs = translation_tokenizer(
172
+ input_text,
173
+ return_tensors="pt",
174
+ truncation=True,
175
+ max_length=512
176
+ )
177
 
178
  if DEVICE == "cuda":
179
  inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
 
181
  # Generate translation
182
  generated_tokens = translation_model.generate(
183
  **inputs,
 
184
  max_new_tokens=400,
185
  num_beams=4,
186
  early_stopping=True
 
192
  skip_special_tokens=True
193
  )[0]
194
 
 
 
 
 
195
 
196
+ if result.startswith(tgt_code + " "):
197
+ result = result[len(tgt_code) + 1:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
+ print(f" → '{result[:50]}...'")
200
+ translated_parts.append(result.strip())
201
 
 
 
 
 
 
 
 
 
 
 
 
202
  except Exception as e:
203
+ print(f" ERROR: {e}")
204
+ translated_parts.append(chunk) # Return original as fallback
205
 
206
+ final_result = " ".join(translated_parts).strip()
207
+ print(f" Final: '{final_result[:100]}...'")
208
+ return final_result
209
 
210
  # RAG retrieval
211
  def retrieve_docs(query: str, vs_path: str):
 
300
  lang_label, prob = detect_language(user_query, top_k=1)[0]
301
  if lang_label not in SUPPORTED_LANGS:
302
  lang_label = "eng_Latn"
303
+
304
+ print(f"Detected language: {SUPPORTED_LANGS.get(lang_label, 'Unknown')}")
305
 
306
 
307
  translated_query = (
 
309
  if lang_label != "eng_Latn"
310
  else user_query
311
  )
312
+
313
+ print(f"Translated to English: {translated_query[:100]}...")
314
 
315
  intent, extra = detect_intent(translated_query)
316
 
 
360
  else english_answer
361
  )
362
  final_answer = strip_markdown(final_answer)
363
+
364
+ print(f"Final answer: {final_answer[:100]}...")
365
+
366
  return {
367
  "session_id": session_id,
368
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),