drrobot9 committed
Commit 527b3c5 · 1 Parent(s): b5c61f9

Update app/agents/crew_pipeline.py

Files changed (1)
  1. app/agents/crew_pipeline.py +84 -44
app/agents/crew_pipeline.py CHANGED
@@ -59,30 +59,36 @@ def detect_language(text: str, top_k: int = 1):
 
 print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
 
-# SIMPLIFIED: Directly load the NLLB model
-translation_tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)
-translation_model = AutoModelForSeq2SeqLM.from_pretrained(
-    config.TRANSLATION_MODEL_NAME,
-    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
-    device_map="auto" if DEVICE == "cuda" else None
-)
-
-print(f"✓ Translation model loaded: {config.TRANSLATION_MODEL_NAME}")
-
-# Verify language codes are available
-if hasattr(translation_tokenizer, 'lang_code_to_id'):
-    print(f"Available language codes in tokenizer:")
-    # Show languages we care about
-    target_langs = ["eng_Latn", "ibo_Latn", "yor_Latn", "hau_Latn", "swa_Latn", "amh_Ethi"]
-    for lang in target_langs:
-        if lang in translation_tokenizer.lang_code_to_id:
-            print(f" ✓ {lang}")
-        else:
-            print(f" ✗ {lang} (not found)")
-else:
-    print("Warning: Tokenizer doesn't have lang_code_to_id attribute")
+NLLB_MODEL = "facebook/nllb-200-distilled-600M"
+print(f"Using model: {NLLB_MODEL}")
+
+try:
+    translation_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL)
+    translation_model = AutoModelForSeq2SeqLM.from_pretrained(
+        NLLB_MODEL,
+        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+        device_map="auto" if DEVICE == "cuda" else None
+    )
+
+    print(f"✓ Translation model loaded successfully")
+
+    # DEBUG: Check tokenizer properties
+    print(f"Tokenizer type: {type(translation_tokenizer).__name__}")
+    print(f"Has lang_code_to_id: {hasattr(translation_tokenizer, 'lang_code_to_id')}")
+
+    if hasattr(translation_tokenizer, 'lang_code_to_id'):
+        print(f"Sample language codes: {list(translation_tokenizer.lang_code_to_id.keys())[:10]}")
+    else:
+        from transformers import AutoConfig
+        config_model = AutoConfig.from_pretrained(NLLB_MODEL)
+        print(f"Model config: {config_model}")
+
+except Exception as e:
+    print(f"✗ Error loading translation model: {e}")
+    raise
 
-# Correct language code mapping for NLLB
+# Language code mapping
 LANG_CODE_MAP = {
     "eng_Latn": "eng_Latn",  # English
     "ibo_Latn": "ibo_Latn",  # Igbo
@@ -92,6 +98,16 @@ LANG_CODE_MAP = {
     "amh_Latn": "amh_Ethi",  # Amharic
 }
 
+# Alternative mapping
+LANG_CODE_MAP_ALT = {
+    "eng_Latn": "en",  # English
+    "ibo_Latn": "ig",  # Igbo
+    "yor_Latn": "yo",  # Yoruba
+    "hau_Latn": "ha",  # Hausa
+    "swh_Latn": "sw",  # Swahili
+    "amh_Latn": "am",  # Amharic
+}
+
 SUPPORTED_LANGS = {
     "eng_Latn": "English",
     "ibo_Latn": "Igbo",
@@ -129,43 +145,66 @@ def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int =
         print(" No translation needed (same language)")
         return text
 
+
     src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
     tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")
 
     print(f" Using codes: {src_code} → {tgt_code}")
 
-    # Check if codes are available
+
     if not hasattr(translation_tokenizer, 'lang_code_to_id'):
-        print(" ERROR: Tokenizer doesn't support language codes")
-        return text
+        print(" WARNING: Tokenizer doesn't have lang_code_to_id")
+        print(" Trying alternative method...")
+
+        src_code_alt = LANG_CODE_MAP_ALT.get(src_lang, "en")
+        tgt_code_alt = LANG_CODE_MAP_ALT.get(tgt_lang, "en")
+
+        try:
+            from transformers import pipeline
+            translator = pipeline(
+                "translation",
+                model=translation_model,
+                tokenizer=translation_tokenizer,
+                src_lang=src_code_alt,
+                tgt_lang=tgt_code_alt,
+                device=0 if DEVICE == "cuda" else -1,
+                max_length=400
+            )
+
+            chunks = chunk_text(text, max_len=max_chunk_len)
+            translated_parts = []
+
+            for chunk in chunks:
+                result = translator(chunk)
+                translated_parts.append(result[0]["translation_text"])
+
+            return " ".join(translated_parts).strip()
+
+        except Exception as e:
+            print(f" Pipeline translation failed: {e}")
+            return text
+
 
     if src_code not in translation_tokenizer.lang_code_to_id:
-        print(f" WARNING: Source code {src_code} not found, using English")
-        src_code = "eng_Latn"
+        print(f" WARNING: Source code {src_code} not found, trying alternatives...")
+        src_code = LANG_CODE_MAP_ALT.get(src_lang, "eng_Latn")
 
     if tgt_code not in translation_tokenizer.lang_code_to_id:
-        print(f" WARNING: Target code {tgt_code} not found, using English")
-        tgt_code = "eng_Latn"
+        print(f" WARNING: Target code {tgt_code} not found, trying alternatives...")
+        tgt_code = LANG_CODE_MAP_ALT.get(tgt_lang, "eng_Latn")
 
-    # Handle non-English to non-English translation
-    if src_code != "eng_Latn" and tgt_code != "eng_Latn":
-        print(f" Two-step translation: {src_code}→eng_Latn→{tgt_code}")
-        to_english = translate_direct(text, src_code, "eng_Latn", max_chunk_len)
-        return translate_direct(to_english, "eng_Latn", tgt_code, max_chunk_len)
-
-    return translate_direct(text, src_code, tgt_code, max_chunk_len)
-
-def translate_direct(text: str, src_code: str, tgt_code: str, max_chunk_len: int = 400) -> str:
-    """Direct translation using forced_bos_token_id"""
-    chunks = chunk_text(text, max_len=max_chunk_len)
-    translated_parts = []
-
-    # Set source language
+
     translation_tokenizer.src_lang = src_code
 
-    # Get forced bos token ID
     forced_bos_token_id = translation_tokenizer.lang_code_to_id[tgt_code]
 
+    chunks = chunk_text(text, max_len=max_chunk_len)
+    translated_parts = []
+
     for i, chunk in enumerate(chunks):
         try:
             inputs = translation_tokenizer(
@@ -200,6 +239,7 @@ def translate_direct(text: str, src_code: str, tgt_code: str, max_chunk_len: int
 
     return " ".join(translated_parts).strip()
 
+
 def retrieve_docs(query: str, vs_path: str):
     if not vs_path or not os.path.exists(vs_path):
         return None
 
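For context on the lang_code_to_id workaround in this commit: some newer transformers releases no longer expose that attribute on the NLLB tokenizers, but the target-language token ID can still be looked up with convert_tokens_to_ids, since NLLB language codes are ordinary tokens in the vocabulary. Below is a minimal standalone sketch, assuming the same facebook/nllb-200-distilled-600M checkpoint used above; sample_text and the yor_Latn target are illustrative choices, not values from the repository.

# Hedged sketch: resolve the NLLB target-language token with or without lang_code_to_id.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

sample_text = "Good morning, how are you?"  # illustrative input
inputs = tokenizer(sample_text, return_tensors="pt")

tgt_code = "yor_Latn"
if hasattr(tokenizer, "lang_code_to_id"):
    forced_bos_token_id = tokenizer.lang_code_to_id[tgt_code]
else:
    # Language codes are regular tokens, so a plain vocabulary lookup works too.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)

generated = model.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=128)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])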