jb100 committed on
Commit
c9eafad
·
verified ·
1 Parent(s): 46c5527

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -20
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # code revision v9
2
-
3
  import gradio as gr
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
@@ -147,7 +145,12 @@ class NLLBTranslator:
147
  if source_lang == target_lang:
148
  return text
149
 
150
- logger.info(f"Translating from {source_lang} to {target_lang}")
 
 
 
 
 
151
 
152
  # Check if simple or complex text
153
  if '\n' not in text and len(text.split('.')) <= 2:
@@ -162,8 +165,44 @@ class NLLBTranslator:
162
 
163
  except Exception as e:
164
  logger.error(f"Translation error: {str(e)}")
 
 
165
  return f"Error during translation: {str(e)}"
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def perform_translation(self, input_sentences: list, source_code: str, target_code: str, paragraph_markers: list) -> str:
168
  """Perform the actual translation using NLLB model"""
169
  batch_size = 2 # Conservative batch size for stability
@@ -174,6 +213,7 @@ class NLLBTranslator:
174
  batch_size = 1
175
 
176
  logger.info(f"Using batch size {batch_size} for average sentence length {avg_sentence_length:.1f} words")
 
177
 
178
  all_translations = []
179
 
@@ -181,7 +221,8 @@ class NLLBTranslator:
181
  batch_sentences = input_sentences[i:i + batch_size]
182
 
183
  try:
184
- # Tokenize input
 
185
  inputs = self.tokenizer(
186
  batch_sentences,
187
  return_tensors="pt",
@@ -190,15 +231,24 @@ class NLLBTranslator:
190
  max_length=512
191
  ).to(self.device)
192
 
 
 
 
 
 
 
 
193
  # Generate translation
194
  with torch.no_grad():
195
  translated_tokens = self.model.generate(
196
  **inputs,
197
- forced_bos_token_id=self.tokenizer.lang_code_to_id.get(target_code, self.tokenizer.eos_token_id),
198
  max_length=512,
199
  num_beams=4,
200
  early_stopping=True,
201
- do_sample=False
 
 
202
  )
203
 
204
  # Decode translations
@@ -207,7 +257,17 @@ class NLLBTranslator:
207
  skip_special_tokens=True
208
  )
209
 
210
- all_translations.extend(translations)
 
 
 
 
 
 
 
 
 
 
211
 
212
  # Progress logging
213
  if len(input_sentences) > 10:
@@ -217,9 +277,12 @@ class NLLBTranslator:
217
  except Exception as e:
218
  logger.error(f"Translation error in batch: {str(e)}")
219
 
220
- # Fallback: process sentences individually
221
  for single_sentence in batch_sentences:
222
  try:
 
 
 
223
  inputs = self.tokenizer(
224
  single_sentence,
225
  return_tensors="pt",
@@ -227,31 +290,47 @@ class NLLBTranslator:
227
  max_length=512
228
  ).to(self.device)
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  with torch.no_grad():
231
- translated_tokens = self.model.generate(
232
- **inputs,
233
- forced_bos_token_id=self.tokenizer.lang_code_to_id.get(target_code, self.tokenizer.eos_token_id),
234
- max_length=512,
235
- num_beams=4,
236
- early_stopping=True
237
- )
238
 
239
  translation = self.tokenizer.decode(
240
  translated_tokens[0],
241
  skip_special_tokens=True
242
  )
243
 
244
- all_translations.append(translation)
 
 
 
 
 
245
 
246
  except Exception as single_e:
247
- logger.error(f"Failed to translate sentence: {str(single_e)}")
248
- all_translations.append(f"[Translation failed for: {single_sentence[:50]}...]")
249
 
250
  # Reconstruct formatting
251
  if paragraph_markers and len(all_translations) == len(paragraph_markers):
252
  final_translation = self.reconstruct_formatting(all_translations, paragraph_markers)
253
  else:
254
- final_translation = ' '.join(all_translations) if all_translations else "Translation failed"
255
 
256
  return final_translation
257
 
@@ -803,7 +882,7 @@ def translate_document(file, source_lang: str, target_lang: str, session_id: str
803
 
804
  # Initialize translator
805
  print("Initializing NLLB Translator...")
806
- translator = NLLBTranslator(model_size="3.3B") # Use smaller model for stability
807
 
808
  # Create the Gradio app
809
  with gr.Blocks(title="NLLB Universal Translator", theme=gr.themes.Soft()) as demo:
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
145
  if source_lang == target_lang:
146
  return text
147
 
148
+ logger.info(f"Translating from {source_lang} ({source_code}) to {target_lang} ({target_code})")
149
+
150
+ # For simple test, try a direct approach first
151
+ if text.strip() == "Hello, how are you today?":
152
+ logger.info("Using simple test translation")
153
+ return self.simple_translate(text, source_code, target_code)
154
 
155
  # Check if simple or complex text
156
  if '\n' not in text and len(text.split('.')) <= 2:
 
165
 
166
  except Exception as e:
167
  logger.error(f"Translation error: {str(e)}")
168
+ import traceback
169
+ traceback.print_exc()
170
  return f"Error during translation: {str(e)}"
171
 
172
def simple_translate(self, text: str, source_code: str, target_code: str) -> str:
    """Translate a single short text from source_code to target_code.

    Args:
        text: The raw input string to translate.
        source_code: NLLB/FLORES language code of the input (e.g. "eng_Latn").
        target_code: NLLB/FLORES language code of the desired output.

    Returns:
        The decoded translation, or a human-readable error string on failure
        (this method never raises; errors are logged and reported in-band).
    """
    try:
        # Tell the tokenizer which language the input is in so it emits
        # the correct source-language prefix token.
        self.tokenizer.src_lang = source_code

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self.device)

        # BUG FIX: NLLB requires the target language to be forced as the
        # first generated token; without forced_bos_token_id the model
        # tends to echo the source language instead of translating.
        lang_map = getattr(self.tokenizer, "lang_code_to_id", None)
        if lang_map is not None and target_code in lang_map:
            target_token_id = lang_map[target_code]
        else:
            # Newer transformers tokenizers dropped lang_code_to_id; the
            # language codes are plain vocabulary tokens there.
            target_token_id = self.tokenizer.convert_tokens_to_ids(target_code)

        # Deterministic beam search; no gradients needed for inference.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                forced_bos_token_id=target_token_id,
                max_length=512,
                num_beams=5,
                early_stopping=True,
                do_sample=False
            )

        translation = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Simple translation result: {translation}")

        return translation.strip() if translation.strip() else "Translation produced empty result"

    except Exception as e:
        logger.error(f"Simple translation failed: {str(e)}")
        return f"Simple translation failed: {str(e)}"
205
+
206
  def perform_translation(self, input_sentences: list, source_code: str, target_code: str, paragraph_markers: list) -> str:
207
  """Perform the actual translation using NLLB model"""
208
  batch_size = 2 # Conservative batch size for stability
 
213
  batch_size = 1
214
 
215
  logger.info(f"Using batch size {batch_size} for average sentence length {avg_sentence_length:.1f} words")
216
+ logger.info(f"Translating from {source_code} to {target_code}")
217
 
218
  all_translations = []
219
 
 
221
  batch_sentences = input_sentences[i:i + batch_size]
222
 
223
  try:
224
+ # Tokenize input with source language
225
+ self.tokenizer.src_lang = source_code
226
  inputs = self.tokenizer(
227
  batch_sentences,
228
  return_tensors="pt",
 
231
  max_length=512
232
  ).to(self.device)
233
 
234
+ # Get target language token ID
235
+ try:
236
+ target_token_id = self.tokenizer.lang_code_to_id[target_code]
237
+ except KeyError:
238
+ logger.warning(f"Language code {target_code} not found in tokenizer, using default")
239
+ target_token_id = self.tokenizer.pad_token_id
240
+
241
  # Generate translation
242
  with torch.no_grad():
243
  translated_tokens = self.model.generate(
244
  **inputs,
245
+ forced_bos_token_id=target_token_id,
246
  max_length=512,
247
  num_beams=4,
248
  early_stopping=True,
249
+ do_sample=False,
250
+ pad_token_id=self.tokenizer.pad_token_id,
251
+ eos_token_id=self.tokenizer.eos_token_id
252
  )
253
 
254
  # Decode translations
 
257
  skip_special_tokens=True
258
  )
259
 
260
+ # Clean up translations (remove source language tokens if present)
261
+ cleaned_translations = []
262
+ for trans in translations:
263
+ # Remove any language tokens that might be in the output
264
+ cleaned = trans.strip()
265
+ if cleaned:
266
+ cleaned_translations.append(cleaned)
267
+ else:
268
+ cleaned_translations.append("Translation produced empty result")
269
+
270
+ all_translations.extend(cleaned_translations)
271
 
272
  # Progress logging
273
  if len(input_sentences) > 10:
 
277
  except Exception as e:
278
  logger.error(f"Translation error in batch: {str(e)}")
279
 
280
+ # Fallback: process sentences individually with simpler approach
281
  for single_sentence in batch_sentences:
282
  try:
283
+ # Set source language
284
+ self.tokenizer.src_lang = source_code
285
+
286
  inputs = self.tokenizer(
287
  single_sentence,
288
  return_tensors="pt",
 
290
  max_length=512
291
  ).to(self.device)
292
 
293
+ # Try different approaches for target language
294
+ generation_kwargs = {
295
+ "max_length": 512,
296
+ "num_beams": 2,
297
+ "early_stopping": True,
298
+ "do_sample": False,
299
+ "pad_token_id": self.tokenizer.pad_token_id,
300
+ "eos_token_id": self.tokenizer.eos_token_id
301
+ }
302
+
303
+ # Try with target language token first
304
+ try:
305
+ target_token_id = self.tokenizer.lang_code_to_id[target_code]
306
+ generation_kwargs["forced_bos_token_id"] = target_token_id
307
+ except KeyError:
308
+ logger.warning(f"Target language {target_code} not in tokenizer, trying without forced_bos_token_id")
309
+
310
  with torch.no_grad():
311
+ translated_tokens = self.model.generate(**inputs, **generation_kwargs)
 
 
 
 
 
 
312
 
313
  translation = self.tokenizer.decode(
314
  translated_tokens[0],
315
  skip_special_tokens=True
316
  )
317
 
318
+ # Clean the translation
319
+ cleaned_translation = translation.strip()
320
+ if cleaned_translation:
321
+ all_translations.append(cleaned_translation)
322
+ else:
323
+ all_translations.append("Empty translation result")
324
 
325
  except Exception as single_e:
326
+ logger.error(f"Failed to translate sentence '{single_sentence}': {str(single_e)}")
327
+ all_translations.append(f"Translation failed: {str(single_e)}")
328
 
329
  # Reconstruct formatting
330
  if paragraph_markers and len(all_translations) == len(paragraph_markers):
331
  final_translation = self.reconstruct_formatting(all_translations, paragraph_markers)
332
  else:
333
+ final_translation = ' '.join(all_translations) if all_translations else "Translation failed - no output generated"
334
 
335
  return final_translation
336
 
 
882
 
883
  # Initialize translator
884
  print("Initializing NLLB Translator...")
885
+ translator = NLLBTranslator(model_size="600M") # Use smaller model for stability
886
 
887
  # Create the Gradio app
888
  with gr.Blocks(title="NLLB Universal Translator", theme=gr.themes.Soft()) as demo: