RamizXhah commited on
Commit
4268eec
·
verified ·
1 Parent(s): fe61e5c

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +15 -22
translation.py CHANGED
@@ -2,15 +2,14 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
 
4
  # --- Model Definition ---
5
- # Using NLLB-200 Distilled (600M) for high-quality, natural translations.
6
- # This single model handles both English (eng_Latn) and Urdu (urd_Arab).
7
  MODEL_NAME = "facebook/nllb-200-distilled-600M"
8
 
9
  _tokenizer = None
10
  _model = None
11
 
12
  def _load_model_resources():
13
- """Loads the NLLB tokenizer and model (cached)."""
14
  global _tokenizer, _model
15
  if _tokenizer is None or _model is None:
16
  print(f"Loading translation model: {MODEL_NAME}...")
@@ -29,15 +28,18 @@ def translate_to_urdu(text):
29
  # 2. Prepare inputs
30
  inputs = tokenizer(text, return_tensors="pt")
31
 
32
- # 3. Generate output with target language forced to Urdu
 
 
 
 
33
  generated_tokens = model.generate(
34
  **inputs,
35
- forced_bos_token_id=tokenizer.lang_code_to_id["urd_Arab"],
36
  max_length=128,
37
  num_beams=5
38
  )
39
 
40
- # 4. Decode result
41
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
42
 
43
  except Exception as exc:
@@ -54,28 +56,19 @@ def translate_to_english(text):
54
  # 2. Prepare inputs
55
  inputs = tokenizer(text, return_tensors="pt")
56
 
57
- # 3. Generate output with target language forced to English
 
 
 
 
58
  generated_tokens = model.generate(
59
  **inputs,
60
- forced_bos_token_id=tokenizer.lang_code_to_id["eng_Latn"],
61
  max_length=128,
62
  num_beams=5
63
  )
64
 
65
- # 4. Decode result
66
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
67
 
68
  except Exception as exc:
69
- raise RuntimeError(f"NLLB Translation to English failed: {str(exc)}")
70
-
71
- # --- Test Logic (Runs only if you execute this file directly) ---
72
- if __name__ == "__main__":
73
- print("--- Testing NLLB Model ---")
74
- sample_text = "The quick brown fox jumps over the lazy dog."
75
- print(f"Original: {sample_text}")
76
-
77
- urdu_text = translate_to_urdu(sample_text)
78
- print(f"Urdu: {urdu_text}")
79
-
80
- english_text = translate_to_english(urdu_text)
81
- print(f"Back to English: {english_text}")
 
2
  import torch
3
 
4
  # --- Model Definition ---
5
+ # Using NLLB-200 Distilled (600M)
 
6
  MODEL_NAME = "facebook/nllb-200-distilled-600M"
7
 
8
  _tokenizer = None
9
  _model = None
10
 
11
  def _load_model_resources():
12
+ """Loads the NLLB tokenizer and model."""
13
  global _tokenizer, _model
14
  if _tokenizer is None or _model is None:
15
  print(f"Loading translation model: {MODEL_NAME}...")
 
28
  # 2. Prepare inputs
29
  inputs = tokenizer(text, return_tensors="pt")
30
 
31
+ # 3. Get the token ID for Urdu
32
+ # FIX: Use convert_tokens_to_ids instead of lang_code_to_id
33
+ target_lang_id = tokenizer.convert_tokens_to_ids("urd_Arab")
34
+
35
+ # 4. Generate output
36
  generated_tokens = model.generate(
37
  **inputs,
38
+ forced_bos_token_id=target_lang_id,
39
  max_length=128,
40
  num_beams=5
41
  )
42
 
 
43
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
44
 
45
  except Exception as exc:
 
56
  # 2. Prepare inputs
57
  inputs = tokenizer(text, return_tensors="pt")
58
 
59
+ # 3. Get the token ID for English
60
+ # FIX: Use convert_tokens_to_ids instead of lang_code_to_id
61
+ target_lang_id = tokenizer.convert_tokens_to_ids("eng_Latn")
62
+
63
+ # 4. Generate output
64
  generated_tokens = model.generate(
65
  **inputs,
66
+ forced_bos_token_id=target_lang_id,
67
  max_length=128,
68
  num_beams=5
69
  )
70
 
 
71
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
72
 
73
  except Exception as exc:
74
+ raise RuntimeError(f"NLLB Translation to English failed: {str(exc)}")