LiamKhoaLe committed on
Commit
f359dc2
·
1 Parent(s): e2fd976

Update SFT translation saver

Browse files
Files changed (4) hide show
  1. app.py +15 -1
  2. utils/rag.py +5 -1
  3. vi/processing.py +71 -14
  4. vi/translator.py +14 -0
app.py CHANGED
@@ -390,12 +390,20 @@ def _run_job(dataset_key: str, params: ProcessParams):
390
  if params.vietnamese_translation:
391
  set_state(message="Loading Vietnamese translator", progress=0.05)
392
  try:
 
 
 
 
 
393
  vietnamese_translator.load_model()
394
  translator = vietnamese_translator
395
  logger.info("✅ Vietnamese translator loaded successfully")
396
  except Exception as e:
397
  logger.error(f"❌ Failed to load Vietnamese translator: {e}")
398
- set_state(message=f"Warning: Vietnamese translation failed - {e}", progress=0.1)
 
 
 
399
 
400
  if params.rag_processing:
401
  # RAG processing mode
@@ -429,6 +437,12 @@ def _run_job(dataset_key: str, params: ProcessParams):
429
  progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
430
  translator=translator
431
  )
 
 
 
 
 
 
432
  logger.info(f"[JOB] Processed dataset={dataset_key} rows={count} stats={stats}")
433
  writer.close()
434
 
 
390
  if params.vietnamese_translation:
391
  set_state(message="Loading Vietnamese translator", progress=0.05)
392
  try:
393
+ # Ensure cache directories are set up properly
394
+ cache_dir = os.path.abspath("cache/huggingface")
395
+ os.makedirs(cache_dir, exist_ok=True)
396
+ os.environ["HF_HOME"] = cache_dir
397
+
398
  vietnamese_translator.load_model()
399
  translator = vietnamese_translator
400
  logger.info("✅ Vietnamese translator loaded successfully")
401
  except Exception as e:
402
  logger.error(f"❌ Failed to load Vietnamese translator: {e}")
403
+ logger.warning("Continuing without Vietnamese translation...")
404
+ set_state(message=f"Warning: Vietnamese translation disabled - {e}", progress=0.1)
405
+ # Don't fail the entire job, just disable translation
406
+ translator = None
407
 
408
  if params.rag_processing:
409
  # RAG processing mode
 
437
  progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
438
  translator=translator
439
  )
440
+ # Log translation statistics if translator was used
441
+ if translator and hasattr(translator, 'get_stats'):
442
+ translation_stats = translator.get_stats()
443
+ logger.info(f"[JOB] Translation stats: {translation_stats}")
444
+ stats["translation_stats"] = translation_stats
445
+
446
  logger.info(f"[JOB] Processed dataset={dataset_key} rows={count} stats={stats}")
447
  writer.close()
448
 
utils/rag.py CHANGED
@@ -309,9 +309,13 @@ class RAGProcessor:
309
  if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
310
  try:
311
  row = translate_rag_row(row, translator, ["question", "answer", "context"])
312
- row["vi_translated"] = True
 
 
 
313
  except Exception as e:
314
  logger.error(f"Failed to translate RAG row: {e}")
 
315
 
316
  writer.write(row)
317
  stats["written"] = stats.get("written", 0) + 1
 
309
  if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
310
  try:
311
  row = translate_rag_row(row, translator, ["question", "answer", "context"])
312
+ # Add translation metadata
313
+ if "meta" not in row:
314
+ row["meta"] = {}
315
+ row["meta"]["vietnamese_translated"] = True
316
  except Exception as e:
317
  logger.error(f"Failed to translate RAG row: {e}")
318
+ # Continue with original row if translation fails
319
 
320
  writer.write(row)
321
  stats["written"] = stats.get("written", 0) + 1
vi/processing.py CHANGED
@@ -114,11 +114,17 @@ def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] =
114
  if _validate_vi_translation(original, translated):
115
  translated_sft[field] = _vi_sanitize_text(translated)
116
  logger.debug(f"✅ Successfully translated field '{field}'")
 
 
 
117
  else:
118
  logger.warning(f"❌ Invalid Vietnamese translation for field {field}, keeping original")
119
  logger.warning(f" Original: '{original[:50]}...'")
120
  logger.warning(f" Translated: '{translated[:50]}...'")
121
  translated_sft[field] = original
 
 
 
122
  except Exception as e:
123
  logger.error(f"Failed to translate field '{field}': {e}")
124
  translated_sft[field] = sft_data[field]
@@ -156,17 +162,44 @@ def translate_rag_row(row: Dict[str, Any], translator, text_fields: List[str] =
156
  text_fields = ["question", "answer", "context"]
157
 
158
  try:
159
- translated_row = translator.translate_dict(row, text_fields)
160
- # Validate and sanitize translated fields
161
- for f in text_fields:
162
- if f in translated_row:
163
- original = row.get(f, "")
164
- translated = translated_row[f]
165
- if _validate_vi_translation(original, translated):
166
- translated_row[f] = _vi_sanitize_text(translated)
167
- else:
168
- logger.warning(f"Invalid Vietnamese translation for field {f}, keeping original")
169
- translated_row[f] = original
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  logger.debug(f"Translated RAG row with fields: {text_fields}")
171
  return translated_row
172
  except Exception as e:
@@ -187,8 +220,12 @@ def should_translate(vietnamese_translation: bool, translator) -> bool:
187
  if not vietnamese_translation:
188
  return False
189
 
190
- if not translator or not translator.is_loaded():
191
- logger.warning("Vietnamese translation requested but translator not available")
 
 
 
 
192
  return False
193
 
194
  return True
@@ -202,4 +239,24 @@ def log_translation_stats(stats: Dict[str, Any], translated_count: int) -> None:
202
  translated_count: Number of items translated
203
  """
204
  stats["vietnamese_translated"] = translated_count
205
- logger.info(f"Vietnamese translation completed: {translated_count} items translated")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  if _validate_vi_translation(original, translated):
115
  translated_sft[field] = _vi_sanitize_text(translated)
116
  logger.debug(f"✅ Successfully translated field '{field}'")
117
+ # Add success statistics if stats available
118
+ if hasattr(translator, '_stats'):
119
+ add_translation_stats(translator._stats, f"sft_{field}", True)
120
  else:
121
  logger.warning(f"❌ Invalid Vietnamese translation for field {field}, keeping original")
122
  logger.warning(f" Original: '{original[:50]}...'")
123
  logger.warning(f" Translated: '{translated[:50]}...'")
124
  translated_sft[field] = original
125
+ # Add failure statistics if stats available
126
+ if hasattr(translator, '_stats'):
127
+ add_translation_stats(translator._stats, f"sft_{field}", False)
128
  except Exception as e:
129
  logger.error(f"Failed to translate field '{field}': {e}")
130
  translated_sft[field] = sft_data[field]
 
162
  text_fields = ["question", "answer", "context"]
163
 
164
  try:
165
+ # Create a copy of the row to avoid modifying the original
166
+ translated_row = row.copy()
167
+
168
+ # Translate each field individually with proper validation
169
+ for field in text_fields:
170
+ if field in row and isinstance(row[field], str) and row[field].strip():
171
+ try:
172
+ original = row[field]
173
+ translated = translator.translate_text(original)
174
+
175
+ # Debug logging
176
+ logger.debug(f"RAG Translation attempt for field '{field}':")
177
+ logger.debug(f" Original: '{original[:50]}...'")
178
+ logger.debug(f" Translated: '{translated[:50]}...'")
179
+ logger.debug(f" Are they the same? {original == translated}")
180
+
181
+ # Validate and sanitize translated field
182
+ if _validate_vi_translation(original, translated):
183
+ translated_row[field] = _vi_sanitize_text(translated)
184
+ logger.debug(f"✅ Successfully translated RAG field '{field}'")
185
+ # Add success statistics if stats available
186
+ if hasattr(translator, '_stats'):
187
+ add_translation_stats(translator._stats, f"rag_{field}", True)
188
+ else:
189
+ logger.warning(f"❌ Invalid Vietnamese translation for RAG field {field}, keeping original")
190
+ logger.warning(f" Original: '{original[:50]}...'")
191
+ logger.warning(f" Translated: '{translated[:50]}...'")
192
+ translated_row[field] = original
193
+ # Add failure statistics if stats available
194
+ if hasattr(translator, '_stats'):
195
+ add_translation_stats(translator._stats, f"rag_{field}", False)
196
+ except Exception as e:
197
+ logger.error(f"Failed to translate RAG field '{field}': {e}")
198
+ translated_row[field] = row[field]
199
+ else:
200
+ # Keep original if field doesn't exist or is empty
201
+ translated_row[field] = row.get(field, "")
202
+
203
  logger.debug(f"Translated RAG row with fields: {text_fields}")
204
  return translated_row
205
  except Exception as e:
 
220
  if not vietnamese_translation:
221
  return False
222
 
223
+ if not translator:
224
+ logger.warning("Vietnamese translation requested but translator is None")
225
+ return False
226
+
227
+ if not hasattr(translator, 'is_loaded') or not translator.is_loaded():
228
+ logger.warning("Vietnamese translation requested but translator not loaded")
229
  return False
230
 
231
  return True
 
239
  translated_count: Number of items translated
240
  """
241
  stats["vietnamese_translated"] = translated_count
242
+ logger.info(f"Vietnamese translation completed: {translated_count} items translated")
243
+
244
+ def add_translation_stats(stats: Dict[str, Any], field: str, success: bool) -> None:
245
+ """
246
+ Add translation statistics for individual fields.
247
+
248
+ Args:
249
+ stats: Statistics dictionary to update
250
+ field: Field name that was translated
251
+ success: Whether translation was successful
252
+ """
253
+ if "translation_stats" not in stats:
254
+ stats["translation_stats"] = {}
255
+
256
+ if field not in stats["translation_stats"]:
257
+ stats["translation_stats"][field] = {"success": 0, "failed": 0}
258
+
259
+ if success:
260
+ stats["translation_stats"][field]["success"] += 1
261
+ else:
262
+ stats["translation_stats"][field]["failed"] += 1
vi/translator.py CHANGED
@@ -31,6 +31,7 @@ class VietnameseTranslator:
31
  self.model = None
32
  self.tokenizer = None
33
  self._is_loaded = False
 
34
 
35
  logger.info(f"VietnameseTranslator initialized with model: {self.model_name}")
36
  logger.info(f"Using device: {self.device}")
@@ -101,6 +102,8 @@ class VietnameseTranslator:
101
  return text
102
 
103
  try:
 
 
104
  # Prepare input with target language token
105
  # The model requires a target language token in the format >>id<<
106
  input_text = f">>vie<< {text.strip()}"
@@ -130,10 +133,13 @@ class VietnameseTranslator:
130
  logger.debug(f"Translation result: '{text[:50]}...' -> '{translated[:50]}...'")
131
  logger.debug(f"Are original and translated the same? {text.strip() == translated.strip()}")
132
 
 
 
133
  return translated.strip()
134
 
135
  except Exception as e:
136
  logger.error(f"Translation failed for text: '{text[:100]}...' - Error: {e}")
 
137
  # Return original text if translation fails
138
  return text
139
 
@@ -270,3 +276,11 @@ class VietnameseTranslator:
270
  "device": self.device,
271
  "is_loaded": self._is_loaded
272
  }
 
 
 
 
 
 
 
 
 
31
  self.model = None
32
  self.tokenizer = None
33
  self._is_loaded = False
34
+ self._stats = {"total_translations": 0, "successful_translations": 0, "failed_translations": 0}
35
 
36
  logger.info(f"VietnameseTranslator initialized with model: {self.model_name}")
37
  logger.info(f"Using device: {self.device}")
 
102
  return text
103
 
104
  try:
105
+ self._stats["total_translations"] += 1
106
+
107
  # Prepare input with target language token
108
  # The model requires a target language token in the format >>id<<
109
  input_text = f">>vie<< {text.strip()}"
 
133
  logger.debug(f"Translation result: '{text[:50]}...' -> '{translated[:50]}...'")
134
  logger.debug(f"Are original and translated the same? {text.strip() == translated.strip()}")
135
 
136
+ # Track success
137
+ self._stats["successful_translations"] += 1
138
  return translated.strip()
139
 
140
  except Exception as e:
141
  logger.error(f"Translation failed for text: '{text[:100]}...' - Error: {e}")
142
+ self._stats["failed_translations"] += 1
143
  # Return original text if translation fails
144
  return text
145
 
 
276
  "device": self.device,
277
  "is_loaded": self._is_loaded
278
  }
279
+
280
+ def get_stats(self) -> Dict[str, Any]:
281
+ """Get translation statistics."""
282
+ return self._stats.copy()
283
+
284
+ def reset_stats(self) -> None:
285
+ """Reset translation statistics."""
286
+ self._stats = {"total_translations": 0, "successful_translations": 0, "failed_translations": 0}