jb100 committed on
Commit
7f47db7
·
verified ·
1 Parent(s): c9eafad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -37
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
@@ -183,14 +185,16 @@ class NLLBTranslator:
183
  max_length=512
184
  ).to(self.device)
185
 
186
- # Generate without forced language token first
187
  with torch.no_grad():
188
  outputs = self.model.generate(
189
  **inputs,
190
  max_length=512,
191
  num_beams=5,
192
  early_stopping=True,
193
- do_sample=False
 
 
194
  )
195
 
196
  # Decode
@@ -231,25 +235,38 @@ class NLLBTranslator:
231
  max_length=512
232
  ).to(self.device)
233
 
234
- # Get target language token ID
 
235
  try:
236
- target_token_id = self.tokenizer.lang_code_to_id[target_code]
237
- except KeyError:
238
- logger.warning(f"Language code {target_code} not found in tokenizer, using default")
239
- target_token_id = self.tokenizer.pad_token_id
 
 
 
 
 
 
 
 
240
 
241
  # Generate translation
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  with torch.no_grad():
243
- translated_tokens = self.model.generate(
244
- **inputs,
245
- forced_bos_token_id=target_token_id,
246
- max_length=512,
247
- num_beams=4,
248
- early_stopping=True,
249
- do_sample=False,
250
- pad_token_id=self.tokenizer.pad_token_id,
251
- eos_token_id=self.tokenizer.eos_token_id
252
- )
253
 
254
  # Decode translations
255
  translations = self.tokenizer.batch_decode(
@@ -257,10 +274,9 @@ class NLLBTranslator:
257
  skip_special_tokens=True
258
  )
259
 
260
- # Clean up translations (remove source language tokens if present)
261
  cleaned_translations = []
262
  for trans in translations:
263
- # Remove any language tokens that might be in the output
264
  cleaned = trans.strip()
265
  if cleaned:
266
  cleaned_translations.append(cleaned)
@@ -290,25 +306,17 @@ class NLLBTranslator:
290
  max_length=512
291
  ).to(self.device)
292
 
293
- # Try different approaches for target language
294
- generation_kwargs = {
295
- "max_length": 512,
296
- "num_beams": 2,
297
- "early_stopping": True,
298
- "do_sample": False,
299
- "pad_token_id": self.tokenizer.pad_token_id,
300
- "eos_token_id": self.tokenizer.eos_token_id
301
- }
302
-
303
- # Try with target language token first
304
- try:
305
- target_token_id = self.tokenizer.lang_code_to_id[target_code]
306
- generation_kwargs["forced_bos_token_id"] = target_token_id
307
- except KeyError:
308
- logger.warning(f"Target language {target_code} not in tokenizer, trying without forced_bos_token_id")
309
-
310
  with torch.no_grad():
311
- translated_tokens = self.model.generate(**inputs, **generation_kwargs)
 
 
 
 
 
 
 
 
312
 
313
  translation = self.tokenizer.decode(
314
  translated_tokens[0],
 
1
+ # code v13
2
+
3
  import gradio as gr
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
185
  max_length=512
186
  ).to(self.device)
187
 
188
+ # Generate without forced language token to avoid tokenizer issues
189
  with torch.no_grad():
190
  outputs = self.model.generate(
191
  **inputs,
192
  max_length=512,
193
  num_beams=5,
194
  early_stopping=True,
195
+ do_sample=False,
196
+ pad_token_id=self.tokenizer.pad_token_id,
197
+ eos_token_id=self.tokenizer.eos_token_id
198
  )
199
 
200
  # Decode
 
235
  max_length=512
236
  ).to(self.device)
237
 
238
+ # Get target language token ID using different methods
239
+ target_token_id = None
240
  try:
241
+ # Method 1: Try lang_code_to_id
242
+ if hasattr(self.tokenizer, 'lang_code_to_id'):
243
+ target_token_id = self.tokenizer.lang_code_to_id[target_code]
244
+ # Method 2: Try convert_tokens_to_ids
245
+ elif hasattr(self.tokenizer, 'convert_tokens_to_ids'):
246
+ target_token_id = self.tokenizer.convert_tokens_to_ids(target_code)
247
+ # Method 3: Try getting from vocabulary
248
+ else:
249
+ target_token_id = self.tokenizer.get_vocab().get(target_code)
250
+ except (KeyError, AttributeError):
251
+ logger.warning(f"Could not find target language token for {target_code}")
252
+ target_token_id = None
253
 
254
  # Generate translation
255
+ generation_kwargs = {
256
+ "max_length": 512,
257
+ "num_beams": 4,
258
+ "early_stopping": True,
259
+ "do_sample": False,
260
+ "pad_token_id": self.tokenizer.pad_token_id,
261
+ "eos_token_id": self.tokenizer.eos_token_id
262
+ }
263
+
264
+ # Only add forced_bos_token_id if we found a valid target token
265
+ if target_token_id is not None:
266
+ generation_kwargs["forced_bos_token_id"] = target_token_id
267
+
268
  with torch.no_grad():
269
+ translated_tokens = self.model.generate(**inputs, **generation_kwargs)
 
 
 
 
 
 
 
 
 
270
 
271
  # Decode translations
272
  translations = self.tokenizer.batch_decode(
 
274
  skip_special_tokens=True
275
  )
276
 
277
+ # Clean up translations
278
  cleaned_translations = []
279
  for trans in translations:
 
280
  cleaned = trans.strip()
281
  if cleaned:
282
  cleaned_translations.append(cleaned)
 
306
  max_length=512
307
  ).to(self.device)
308
 
309
+ # Use simple generation without forced language tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  with torch.no_grad():
311
+ translated_tokens = self.model.generate(
312
+ **inputs,
313
+ max_length=512,
314
+ num_beams=2,
315
+ early_stopping=True,
316
+ do_sample=False,
317
+ pad_token_id=self.tokenizer.pad_token_id,
318
+ eos_token_id=self.tokenizer.eos_token_id
319
+ )
320
 
321
  translation = self.tokenizer.decode(
322
  translated_tokens[0],