drrobot9 commited on
Commit
b5c61f9
·
1 Parent(s): 0a1fadc

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +82 -97
app/agents/crew_pipeline.py CHANGED
@@ -59,13 +59,37 @@ def detect_language(text: str, top_k: int = 1):
59
 
60
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
# Map detected language codes to NLLB-200 (FLORES-200) codes.
# Fixed: NLLB-200's Swahili code is "swh_Latn" — the previous value
# "swa_Latn" is not a valid NLLB code and fails the tokenizer lookup.
# NOTE(review): the "amh_Latn" key looks suspicious — Amharic is written in
# Ethiopic script, so detectors normally emit "amh_Ethi"; confirm against
# the language detector's actual output.
LANG_CODE_MAP = {
    "eng_Latn": "eng_Latn",  # English
    "ibo_Latn": "ibo_Latn",  # Igbo
    "yor_Latn": "yor_Latn",  # Yoruba
    "hau_Latn": "hau_Latn",  # Hausa
    "swh_Latn": "swh_Latn",  # Swahili (was "swa_Latn")
    "amh_Latn": "amh_Ethi",  # Amharic
}
70
 
71
  SUPPORTED_LANGS = {
@@ -97,96 +121,77 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
97
  chunks.append(current.strip())
98
  return chunks
99
 
100
def load_translation_model():
    """Load the configured translation model, falling back to NLLB-200.

    Tries ``config.TRANSLATION_MODEL_NAME`` first; on any failure loads
    ``facebook/nllb-200-distilled-600M`` instead.  Uses fp16 and automatic
    device placement on CUDA, fp32 on CPU.

    Returns:
        tuple: ``(tokenizer, model)`` ready for generation.
    """

    def _load(model_name):
        # Shared loader so the primary and fallback paths cannot drift apart
        # (the original duplicated these five lines in both branches).
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
        )
        return tokenizer, model

    try:
        tokenizer, model = _load(config.TRANSLATION_MODEL_NAME)
        print("✓ Custom translation model loaded")
        return tokenizer, model
    except Exception as e:
        # Broad catch is deliberate: any load failure (network, missing
        # weights, bad config) should degrade to the public NLLB checkpoint.
        print(f"✗ Error loading custom model: {e}")
        print("Loading standard NLLB model as fallback...")
        tokenizer, model = _load("facebook/nllb-200-distilled-600M")
        print("✓ Standard NLLB model loaded as fallback")
        return tokenizer, model

# Load the model once at import time; used by the translate helpers below.
translation_tokenizer, translation_model = load_translation_model()
125
-
126
def translate_with_nllb(text: str, src_code: str, tgt_code: str, max_chunk_len: int = 400) -> str:
    """Translate ``text`` with the NLLB model, forcing the target language.

    The text is split into chunks (see ``chunk_text``); chunks are translated
    independently and re-joined with spaces.  Best-effort: a chunk that fails
    to translate is kept untranslated.  If the target-language token cannot
    be resolved at all, falls back to ``translate_simple`` (no forcing).

    Args:
        text: source text to translate.
        src_code: NLLB/FLORES source language code (e.g. "eng_Latn").
        tgt_code: NLLB/FLORES target language code.
        max_chunk_len: maximum characters per chunk passed to chunk_text.

    Returns:
        The translated text (or partially/untranslated text on errors).
    """
    # Resolve the forced BOS token id for the target language.  Newer
    # transformers releases removed `lang_code_to_id` from fast NLLB
    # tokenizers, so fall back to convert_tokens_to_ids, which maps
    # language-code tokens (e.g. "ibo_Latn") directly to their ids.
    forced_bos_token_id = None
    try:
        if hasattr(translation_tokenizer, 'lang_code_to_id'):
            forced_bos_token_id = translation_tokenizer.lang_code_to_id[tgt_code]
        else:
            candidate = translation_tokenizer.convert_tokens_to_ids(tgt_code)
            if candidate != translation_tokenizer.unk_token_id:
                forced_bos_token_id = candidate
    except Exception as e:
        print(f" Language code error: {e}")

    if forced_bos_token_id is None:
        # Cannot force the target language; translate without codes.
        return translate_simple(text, max_chunk_len)

    # Tell the tokenizer which language the input is in.
    translation_tokenizer.src_lang = src_code

    chunks = chunk_text(text, max_len=max_chunk_len)
    translated_parts = []
    for i, chunk in enumerate(chunks):
        try:
            inputs = translation_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)

            if DEVICE == "cuda":
                inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}

            generated_tokens = translation_model.generate(
                **inputs,
                forced_bos_token_id=forced_bos_token_id,
                max_new_tokens=400,
                num_beams=4,
                early_stopping=True
            )

            result = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            translated_parts.append(result.strip())

        except Exception as e:
            # Best-effort: keep the untranslated chunk instead of failing
            # the whole request.
            print(f" Chunk {i+1} error: {e}")
            translated_parts.append(chunk)

    return " ".join(translated_parts).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- def translate_simple(text: str, max_chunk_len: int = 400) -> str:
171
- """Simple translation without language codes"""
172
  chunks = chunk_text(text, max_len=max_chunk_len)
173
  translated_parts = []
174
 
 
 
 
 
 
 
175
  for i, chunk in enumerate(chunks):
176
  try:
177
- inputs = translation_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
 
 
 
 
 
178
 
179
  if DEVICE == "cuda":
180
  inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
181
 
182
  generated_tokens = translation_model.generate(
183
  **inputs,
 
184
  max_new_tokens=400,
185
  num_beams=4,
186
  early_stopping=True
187
  )
188
 
189
- result = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
 
 
 
 
 
190
  translated_parts.append(result.strip())
191
 
192
  except Exception as e:
@@ -195,26 +200,6 @@ def translate_simple(text: str, max_chunk_len: int = 400) -> str:
195
 
196
  return " ".join(translated_parts).strip()
197
 
198
def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
    """Translate text between two languages, pivoting through English when
    neither endpoint maps to English."""
    print(f"\n[TRANSLATION] {src_lang} → {tgt_lang}")
    print(f" Input: {text[:100]}...")

    # Guard clause: empty input or identical languages need no work.
    if src_lang == tgt_lang or not text.strip():
        print(" No translation needed (same language)")
        return text

    pivot = "eng_Latn"
    src_code = LANG_CODE_MAP.get(src_lang, pivot)
    tgt_code = LANG_CODE_MAP.get(tgt_lang, pivot)

    print(f" Using codes: {src_code} → {tgt_code}")

    # One direct pass when either side already is English.
    if pivot in (src_code, tgt_code):
        return translate_with_nllb(text, src_code, tgt_code, max_chunk_len)

    # Neither side is English: translate into English first, then onward.
    print(f" Two-step translation: {src_code}→eng_Latn→{tgt_code}")
    intermediate = translate_with_nllb(text, src_code, pivot, max_chunk_len)
    return translate_with_nllb(intermediate, pivot, tgt_code, max_chunk_len)
218
  def retrieve_docs(query: str, vs_path: str):
219
  if not vs_path or not os.path.exists(vs_path):
220
  return None
 
59
 
60
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
61
 
62
# SIMPLIFIED: Directly load the NLLB model (no try/except fallback path).
translation_tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    config.TRANSLATION_MODEL_NAME,
    # fp16 halves GPU memory; CPU stays fp32 for compatibility.
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto" if DEVICE == "cuda" else None
)

print(f"✓ Translation model loaded: {config.TRANSLATION_MODEL_NAME}")

# Sanity-check that the tokenizer knows every language code we will request.
# NOTE(review): newer transformers fast NLLB tokenizers no longer expose
# `lang_code_to_id` — confirm the pinned transformers version still has it.
if hasattr(translation_tokenizer, 'lang_code_to_id'):
    print(f"Available language codes in tokenizer:")
    # Codes this app translates between.  Fixed: "swh_Latn" is NLLB-200's
    # Swahili code — the previous "swa_Latn" is invalid and always printed ✗.
    target_langs = ["eng_Latn", "ibo_Latn", "yor_Latn", "hau_Latn", "swh_Latn", "amh_Ethi"]
    for lang in target_langs:
        if lang in translation_tokenizer.lang_code_to_id:
            print(f" ✓ {lang}")
        else:
            print(f" ✗ {lang} (not found)")
else:
    print("Warning: Tokenizer doesn't have lang_code_to_id attribute")
84
+
85
# Correct language code mapping for NLLB (detected code -> NLLB-200 code).
# Fixed: NLLB-200's Swahili code is "swh_Latn" — the previous value
# "swa_Latn" is not a valid NLLB code and fails the tokenizer lookup.
# NOTE(review): the "amh_Latn" key looks suspicious — Amharic detectors
# normally emit "amh_Ethi" (Ethiopic script); confirm against the detector.
LANG_CODE_MAP = {
    "eng_Latn": "eng_Latn",  # English
    "ibo_Latn": "ibo_Latn",  # Igbo
    "yor_Latn": "yor_Latn",  # Yoruba
    "hau_Latn": "hau_Latn",  # Hausa
    "swh_Latn": "swh_Latn",  # Swahili (was "swa_Latn")
    "amh_Latn": "amh_Ethi",  # Amharic
}
94
 
95
  SUPPORTED_LANGS = {
 
121
  chunks.append(current.strip())
122
  return chunks
123
 
124
def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
    """Resolve NLLB language codes and dispatch to translate_direct,
    pivoting through English when neither endpoint is English."""
    print(f"\n[TRANSLATION] {src_lang} {tgt_lang}")
    print(f" Input: {text[:100]}...")

    # Guard clause: empty input or identical languages need no work.
    if src_lang == tgt_lang or not text.strip():
        print(" No translation needed (same language)")
        return text

    english = "eng_Latn"
    src_code = LANG_CODE_MAP.get(src_lang, english)
    tgt_code = LANG_CODE_MAP.get(tgt_lang, english)

    print(f" Using codes: {src_code} → {tgt_code}")

    # Without lang_code_to_id the tokenizer cannot force target languages.
    if not hasattr(translation_tokenizer, 'lang_code_to_id'):
        print(" ERROR: Tokenizer doesn't support language codes")
        return text

    known_codes = translation_tokenizer.lang_code_to_id
    if src_code not in known_codes:
        print(f" WARNING: Source code {src_code} not found, using English")
        src_code = english
    if tgt_code not in known_codes:
        print(f" WARNING: Target code {tgt_code} not found, using English")
        tgt_code = english

    # Single direct pass whenever one side is English.
    if english in (src_code, tgt_code):
        return translate_direct(text, src_code, tgt_code, max_chunk_len)

    # Neither side is English: translate into English, then onward.
    print(f" Two-step translation: {src_code}→eng_Latn→{tgt_code}")
    to_english = translate_direct(text, src_code, english, max_chunk_len)
    return translate_direct(to_english, english, tgt_code, max_chunk_len)
157
 
158
+ def translate_direct(text: str, src_code: str, tgt_code: str, max_chunk_len: int = 400) -> str:
159
+ """Direct translation using forced_bos_token_id"""
160
  chunks = chunk_text(text, max_len=max_chunk_len)
161
  translated_parts = []
162
 
163
+ # Set source language
164
+ translation_tokenizer.src_lang = src_code
165
+
166
+ # Get forced bos token ID
167
+ forced_bos_token_id = translation_tokenizer.lang_code_to_id[tgt_code]
168
+
169
  for i, chunk in enumerate(chunks):
170
  try:
171
+ inputs = translation_tokenizer(
172
+ chunk,
173
+ return_tensors="pt",
174
+ truncation=True,
175
+ max_length=512
176
+ )
177
 
178
  if DEVICE == "cuda":
179
  inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
180
 
181
  generated_tokens = translation_model.generate(
182
  **inputs,
183
+ forced_bos_token_id=forced_bos_token_id,
184
  max_new_tokens=400,
185
  num_beams=4,
186
  early_stopping=True
187
  )
188
 
189
+ result = translation_tokenizer.batch_decode(
190
+ generated_tokens,
191
+ skip_special_tokens=True
192
+ )[0]
193
+
194
+ print(f" Chunk {i+1}: '{chunk[:30]}...' → '{result[:30]}...'")
195
  translated_parts.append(result.strip())
196
 
197
  except Exception as e:
 
200
 
201
  return " ".join(translated_parts).strip()
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def retrieve_docs(query: str, vs_path: str):
204
  if not vs_path or not os.path.exists(vs_path):
205
  return None