drrobot9 committed on
Commit
a3aa01e
·
1 Parent(s): 8ca4d6a

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +11 -41
app/agents/crew_pipeline.py CHANGED
@@ -43,7 +43,6 @@ model = AutoModelForCausalLM.from_pretrained(
43
 
44
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
45
 
46
- # Language detector
47
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
48
  lang_model_path = hf_hub_download(
49
  repo_id=config.LANG_ID_MODEL_REPO,
@@ -58,10 +57,7 @@ def detect_language(text: str, top_k: int = 1):
58
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
59
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
60
 
61
-
62
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
63
-
64
- # Load tokenizer and model
65
  translation_tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)
66
  translation_model = AutoModelForSeq2SeqLM.from_pretrained(
67
  config.TRANSLATION_MODEL_NAME,
@@ -69,16 +65,15 @@ translation_model = AutoModelForSeq2SeqLM.from_pretrained(
69
  device_map="auto" if DEVICE == "cuda" else None
70
  )
71
 
72
- print(" Translation model loaded successfully")
73
-
74
 
75
  LANG_CODE_MAP = {
76
- "eng_Latn": "eng_Latn", # English
77
- "ibo_Latn": "ibo_Latn", # Igbo
78
- "yor_Latn": "yor_Latn", # Yoruba
79
- "hau_Latn": "hau_Latn", # Hausa
80
- "swh_Latn": "swa_Latn", # Swahili
81
- "amh_Latn": "amh_Ethi", # Amharic
82
  }
83
 
84
  SUPPORTED_LANGS = {
@@ -90,7 +85,6 @@ SUPPORTED_LANGS = {
90
  "amh_Latn": "Amharic",
91
  }
92
 
93
- # Text chunking
94
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
95
 
96
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
@@ -112,10 +106,6 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
112
  return chunks
113
 
114
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
115
- """
116
- Translate text using
117
- IMPORTANT: Model expects format "src_lang text" -> "tgt_lang translation"
118
- """
119
  print(f"\n[TRANSLATION] {src_lang} → {tgt_lang}")
120
  print(f" Input: {text[:100]}...")
121
 
@@ -123,35 +113,26 @@ def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int =
123
  print(" No translation needed (same language)")
124
  return text
125
 
126
- # Get language codes
127
  src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
128
  tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")
129
 
130
  print(f" Using codes: {src_code} → {tgt_code}")
131
 
132
-
133
  if src_code != "eng_Latn" and tgt_code != "eng_Latn":
134
  print(f" WARNING: Model wasn't trained on {src_code}→{tgt_code}")
135
  print(f" Will translate {src_code}→eng_Latn→{tgt_code}")
136
-
137
-
138
  to_english = translate_text_single(text, src_code, "eng_Latn", max_chunk_len)
139
  return translate_text_single(to_english, "eng_Latn", tgt_code, max_chunk_len)
140
 
141
-
142
  return translate_text_single(text, src_code, tgt_code, max_chunk_len)
143
 
144
  def translate_text_single(text: str, src_code: str, tgt_code: str, max_chunk_len: int = 400) -> str:
145
- """
146
- Perform single translation step
147
- """
148
-
149
- supported_pairs = [
150
  ("eng_Latn", "ibo_Latn"), ("ibo_Latn", "eng_Latn"),
151
  ("eng_Latn", "yor_Latn"), ("yor_Latn", "eng_Latn"),
152
  ("eng_Latn", "hau_Latn"), ("hau_Latn", "eng_Latn"),
153
- ("eng_Latn", "swa_Latn"), ("swa_Latn", "eng_Latn"),
154
- ("eng_Latn", "amh_Ethi"), ("amh_Ethi", "eng_Latn"),
155
  ]
156
 
157
  if (src_code, tgt_code) not in supported_pairs:
@@ -164,10 +145,8 @@ def translate_text_single(text: str, src_code: str, tgt_code: str, max_chunk_len
164
  print(f" Chunk {i+1}/{len(chunks)}: '{chunk[:50]}...'")
165
 
166
  try:
167
-
168
  input_text = f"{src_code} {chunk}"
169
 
170
- # Tokenize
171
  inputs = translation_tokenizer(
172
  input_text,
173
  return_tensors="pt",
@@ -178,7 +157,6 @@ def translate_text_single(text: str, src_code: str, tgt_code: str, max_chunk_len
178
  if DEVICE == "cuda":
179
  inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
180
 
181
- # Generate translation
182
  generated_tokens = translation_model.generate(
183
  **inputs,
184
  max_new_tokens=400,
@@ -186,13 +164,11 @@ def translate_text_single(text: str, src_code: str, tgt_code: str, max_chunk_len
186
  early_stopping=True
187
  )
188
 
189
- # Decode
190
  result = translation_tokenizer.batch_decode(
191
  generated_tokens,
192
  skip_special_tokens=True
193
  )[0]
194
 
195
-
196
  if result.startswith(tgt_code + " "):
197
  result = result[len(tgt_code) + 1:]
198
 
@@ -201,13 +177,12 @@ def translate_text_single(text: str, src_code: str, tgt_code: str, max_chunk_len
201
 
202
  except Exception as e:
203
  print(f" ERROR: {e}")
204
- translated_parts.append(chunk) # Return original as fallback
205
 
206
  final_result = " ".join(translated_parts).strip()
207
  print(f" Final: '{final_result[:100]}...'")
208
  return final_result
209
 
210
- # RAG retrieval
211
  def retrieve_docs(query: str, vs_path: str):
212
  if not vs_path or not os.path.exists(vs_path):
213
  return None
@@ -296,14 +271,12 @@ def run_pipeline(user_query: str, session_id: str = None):
296
  if session_id is None:
297
  session_id = str(uuid.uuid4())
298
 
299
- # Language detection
300
  lang_label, prob = detect_language(user_query, top_k=1)[0]
301
  if lang_label not in SUPPORTED_LANGS:
302
  lang_label = "eng_Latn"
303
 
304
  print(f"Detected language: {SUPPORTED_LANGS.get(lang_label, 'Unknown')}")
305
 
306
-
307
  translated_query = (
308
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
309
  if lang_label != "eng_Latn"
@@ -314,7 +287,6 @@ def run_pipeline(user_query: str, session_id: str = None):
314
 
315
  intent, extra = detect_intent(translated_query)
316
 
317
- # Load conversation history
318
  history = memory_store.get_history(session_id) or []
319
  if len(history) > MAX_HISTORY_MESSAGES:
320
  history = history[-MAX_HISTORY_MESSAGES:]
@@ -347,13 +319,11 @@ def run_pipeline(user_query: str, session_id: str = None):
347
  messages_for_qwen = build_messages_from_history(history, system_prompt)
348
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
349
 
350
-
351
  history.append({"role": "assistant", "content": english_answer})
352
  if len(history) > MAX_HISTORY_MESSAGES:
353
  history = history[-MAX_HISTORY_MESSAGES:]
354
  memory_store.save_history(session_id, history)
355
 
356
-
357
  final_answer = (
358
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
359
  if lang_label != "eng_Latn"
 
43
 
44
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
45
 
 
46
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
47
  lang_model_path = hf_hub_download(
48
  repo_id=config.LANG_ID_MODEL_REPO,
 
57
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
58
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
59
 
 
60
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
 
 
61
  translation_tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)
62
  translation_model = AutoModelForSeq2SeqLM.from_pretrained(
63
  config.TRANSLATION_MODEL_NAME,
 
65
  device_map="auto" if DEVICE == "cuda" else None
66
  )
67
 
68
+ print("Translation model loaded successfully")
 
69
 
70
  LANG_CODE_MAP = {
71
+ "eng_Latn": "eng_Latn",
72
+ "ibo_Latn": "ibo_Latn",
73
+ "yor_Latn": "yor_Latn",
74
+ "hau_Latn": "hau_Latn",
75
+ "swh_Latn": "swa_Latn",
76
+ "amh_Latn": "amh_Ethi",
77
  }
78
 
79
  SUPPORTED_LANGS = {
 
85
  "amh_Latn": "Amharic",
86
  }
87
 
 
88
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
89
 
90
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
 
106
  return chunks
107
 
108
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
 
 
 
 
109
  print(f"\n[TRANSLATION] {src_lang} → {tgt_lang}")
110
  print(f" Input: {text[:100]}...")
111
 
 
113
  print(" No translation needed (same language)")
114
  return text
115
 
 
116
  src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
117
  tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")
118
 
119
  print(f" Using codes: {src_code} → {tgt_code}")
120
 
 
121
  if src_code != "eng_Latn" and tgt_code != "eng_Latn":
122
  print(f" WARNING: Model wasn't trained on {src_code}→{tgt_code}")
123
  print(f" Will translate {src_code}→eng_Latn→{tgt_code}")
 
 
124
  to_english = translate_text_single(text, src_code, "eng_Latn", max_chunk_len)
125
  return translate_text_single(to_english, "eng_Latn", tgt_code, max_chunk_len)
126
 
 
127
  return translate_text_single(text, src_code, tgt_code, max_chunk_len)
128
 
129
  def translate_text_single(text: str, src_code: str, tgt_code: str, max_chunk_len: int = 400) -> str:
130
+ supported_pairs = [
 
 
 
 
131
  ("eng_Latn", "ibo_Latn"), ("ibo_Latn", "eng_Latn"),
132
  ("eng_Latn", "yor_Latn"), ("yor_Latn", "eng_Latn"),
133
  ("eng_Latn", "hau_Latn"), ("hau_Latn", "eng_Latn"),
134
+ ("eng_Latn", "swa_Latn"), ("swa_Latn", "eng_Latn"),
135
+ ("eng_Latn", "amh_Ethi"), ("amh_Ethi", "eng_Latn"),
136
  ]
137
 
138
  if (src_code, tgt_code) not in supported_pairs:
 
145
  print(f" Chunk {i+1}/{len(chunks)}: '{chunk[:50]}...'")
146
 
147
  try:
 
148
  input_text = f"{src_code} {chunk}"
149
 
 
150
  inputs = translation_tokenizer(
151
  input_text,
152
  return_tensors="pt",
 
157
  if DEVICE == "cuda":
158
  inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
159
 
 
160
  generated_tokens = translation_model.generate(
161
  **inputs,
162
  max_new_tokens=400,
 
164
  early_stopping=True
165
  )
166
 
 
167
  result = translation_tokenizer.batch_decode(
168
  generated_tokens,
169
  skip_special_tokens=True
170
  )[0]
171
 
 
172
  if result.startswith(tgt_code + " "):
173
  result = result[len(tgt_code) + 1:]
174
 
 
177
 
178
  except Exception as e:
179
  print(f" ERROR: {e}")
180
+ translated_parts.append(chunk)
181
 
182
  final_result = " ".join(translated_parts).strip()
183
  print(f" Final: '{final_result[:100]}...'")
184
  return final_result
185
 
 
186
  def retrieve_docs(query: str, vs_path: str):
187
  if not vs_path or not os.path.exists(vs_path):
188
  return None
 
271
  if session_id is None:
272
  session_id = str(uuid.uuid4())
273
 
 
274
  lang_label, prob = detect_language(user_query, top_k=1)[0]
275
  if lang_label not in SUPPORTED_LANGS:
276
  lang_label = "eng_Latn"
277
 
278
  print(f"Detected language: {SUPPORTED_LANGS.get(lang_label, 'Unknown')}")
279
 
 
280
  translated_query = (
281
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
282
  if lang_label != "eng_Latn"
 
287
 
288
  intent, extra = detect_intent(translated_query)
289
 
 
290
  history = memory_store.get_history(session_id) or []
291
  if len(history) > MAX_HISTORY_MESSAGES:
292
  history = history[-MAX_HISTORY_MESSAGES:]
 
319
  messages_for_qwen = build_messages_from_history(history, system_prompt)
320
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
321
 
 
322
  history.append({"role": "assistant", "content": english_answer})
323
  if len(history) > MAX_HISTORY_MESSAGES:
324
  history = history[-MAX_HISTORY_MESSAGES:]
325
  memory_store.save_history(session_id, history)
326
 
 
327
  final_answer = (
328
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
329
  if lang_label != "eng_Latn"