RamizXhah committed on
Commit
a34def2
·
verified ·
1 Parent(s): ba8bdb5

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +62 -67
translation.py CHANGED
@@ -1,56 +1,54 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
2
 
3
- # 1. TEMPORARY FIX: Switched to the smaller MarianMT model to avoid memory crashes.
4
- # If this works, the NLLB model is too large for the current Space hardware.
5
- # NLLB MODEL (Large): model_name = "facebook/nllb-200-distilled-600M"
6
- model_name = "Helsinki-NLP/opus-mt-en-ur"
7
- _tokenizer = None
8
- _model = None
9
-
10
- # MarianMT models handle the reverse translation (Urdu-English) by using a separate model pair.
11
- # We will load the reverse model on demand.
12
- REVERSE_MODEL_NAME = "Helsinki-NLP/opus-mt-ur-en"
13
- _reverse_tokenizer = None
14
- _reverse_model = None
15
-
16
-
17
- def _load_translation_resources():
18
- """Loads the main EN-UR model resources (Helsinki-NLP/opus-mt-en-ur)."""
19
- global _tokenizer, _model
20
- if _tokenizer is None or _model is None:
21
- _tokenizer = AutoTokenizer.from_pretrained(model_name)
22
- _model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
23
- return _tokenizer, _model
24
-
25
- def _load_reverse_translation_resources():
26
- """Loads the UR-EN model resources (Helsinki-NLP/opus-mt-ur-en)."""
27
- global _reverse_tokenizer, _reverse_model
28
- if _reverse_tokenizer is None or _reverse_model is None:
29
- _reverse_tokenizer = AutoTokenizer.from_pretrained(REVERSE_MODEL_NAME)
30
- _reverse_model = AutoModelForSeq2SeqLM.from_pretrained(REVERSE_MODEL_NAME)
31
- return _reverse_tokenizer, _reverse_model
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def translate_to_urdu(text):
35
- """Translates English text to Urdu using the Helsinki-NLP/opus-mt-en-ur model."""
36
- # MarianMT models often don't require explicit src_lang or forced_bos_token_id
37
- # for single-pair models, but we use the target language code for safety.
38
- tokenizer, model = _load_translation_resources()
39
 
40
  try:
41
- # NOTE: MarianMT tokens are often '>>ur<<' for the target language.
42
- input_ids = tokenizer(
43
- text,
44
- # For MarianMT, the source language is implicit in the model name (en-ur)
45
- # but we use the target language token to guide generation.
46
- text_target=[""] * len(text) if isinstance(text, list) else "",
47
- return_tensors='pt'
48
- ).input_ids
49
-
50
- # We set the forced_bos_token_id to the target language code 'ur'
51
  generated_tokens = model.generate(
52
  input_ids,
53
- forced_bos_token_id=tokenizer.get_lang_id("ur"),
 
54
  num_beams=5,
55
  max_length=128
56
  )
@@ -58,26 +56,23 @@ def translate_to_urdu(text):
58
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
59
 
60
  except Exception as exc:
61
- # Raised as RuntimeError to be caught in app.py and sent as 500
62
- raise RuntimeError("Translation to Urdu failed (MarianMT check)") from exc
63
-
64
 
65
  def translate_to_english(text):
66
- """Translates Urdu text to English using the Helsinki-NLP/opus-mt-ur-en model."""
67
- # Note: This loads a separate UR-EN model pair.
68
- tokenizer, model = _load_reverse_translation_resources()
69
 
70
  try:
71
- input_ids = tokenizer(
72
- text,
73
- text_target=[""] * len(text) if isinstance(text, list) else "",
74
- return_tensors='pt'
75
- ).input_ids
76
 
77
- # We set the forced_bos_token_id to the target language code 'en'
78
  generated_tokens = model.generate(
79
  input_ids,
80
- forced_bos_token_id=tokenizer.get_lang_id("en"),
 
81
  num_beams=5,
82
  max_length=128
83
  )
@@ -85,19 +80,19 @@ def translate_to_english(text):
85
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
86
 
87
  except Exception as exc:
88
- raise RuntimeError("Translation to English failed (MarianMT check)") from exc
89
-
90
 
91
  # --- Example Usage ---
92
  if __name__ == "__main__":
93
- input_text = "The study investigates the correlation between socioeconomic status and academic achievement."
94
- translated_text = translate_to_urdu(input_text)
 
95
 
96
- print(f"Original (English): {input_text}")
97
- print(f"Translated (Urdu): {translated_text}")
98
 
99
  # Test Urdu to English translation
100
- urdu_text = translated_text
101
- back_to_english = translate_to_english(urdu_text)
102
- print(f"\nOriginal (Urdu): {urdu_text}")
103
- print(f"Translated back (English): {back_to_english}")
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import torch
3
 
# --- Model Definitions ---
# One small MarianMT checkpoint per translation direction; the small models
# are used here to avoid memory crashes.
EN_UR_MODEL_NAME = "Helsinki-NLP/opus-mt-en-ur"
UR_EN_MODEL_NAME = "Helsinki-NLP/opus-mt-ur-en"

# Lazily-populated caches. Each pair stays None until the matching loader
# function is called for the first time.
_en_ur_tokenizer = _en_ur_model = None  # English -> Urdu
_ur_en_tokenizer = _ur_en_model = None  # Urdu -> English
16
+
17
+ # --- Resource Loading Functions ---
18
+
def _load_en_ur_resources():
    """Return the cached English-to-Urdu ``(tokenizer, model)`` pair.

    The pair is fetched with ``from_pretrained(EN_UR_MODEL_NAME)`` on the
    first call and stored in module globals; later calls reuse the cached
    objects.
    """
    global _en_ur_tokenizer, _en_ur_model

    already_loaded = _en_ur_tokenizer is not None and _en_ur_model is not None
    if already_loaded:
        return _en_ur_tokenizer, _en_ur_model

    _en_ur_tokenizer = AutoTokenizer.from_pretrained(EN_UR_MODEL_NAME)
    _en_ur_model = AutoModelForSeq2SeqLM.from_pretrained(EN_UR_MODEL_NAME)
    return _en_ur_tokenizer, _en_ur_model
26
+
def _load_ur_en_resources():
    """Return the cached Urdu-to-English ``(tokenizer, model)`` pair.

    The pair is fetched with ``from_pretrained(UR_EN_MODEL_NAME)`` on the
    first call and stored in module globals; later calls reuse the cached
    objects.
    """
    global _ur_en_tokenizer, _ur_en_model

    already_loaded = _ur_en_tokenizer is not None and _ur_en_model is not None
    if already_loaded:
        return _ur_en_tokenizer, _ur_en_model

    _ur_en_tokenizer = AutoTokenizer.from_pretrained(UR_EN_MODEL_NAME)
    _ur_en_model = AutoModelForSeq2SeqLM.from_pretrained(UR_EN_MODEL_NAME)
    return _ur_en_tokenizer, _ur_en_model
34
+
35
+ # --- Translation Functions ---
36
 
def translate_to_urdu(text):
    """Translate English text to Urdu with the EN->UR MarianMT model.

    Args:
        text: English source text (a single string).

    Returns:
        The Urdu translation as a string. An empty or whitespace-only
        string yields "" without loading the model.

    Raises:
        RuntimeError: If tokenization, generation or decoding fails. The
            underlying exception is chained as ``__cause__``.
    """
    # Nothing to translate: skip the expensive model load entirely.
    if isinstance(text, str) and not text.strip():
        return ""

    tokenizer, model = _load_en_ur_resources()

    try:
        # The EN->UR checkpoint is a single-pair model, so no target-language
        # token is needed. MarianTokenizer has no ``lang_code_to_id`` mapping
        # (that belongs to the mBART/NLLB tokenizers); looking it up raised
        # AttributeError on every call. The decoder start token is taken
        # from the model's own generation config.
        # ``truncation=True`` keeps over-long input within the model limit.
        inputs = tokenizer(text, return_tensors='pt', truncation=True)

        generated_tokens = model.generate(
            **inputs,  # input_ids plus attention_mask
            num_beams=5,
            max_length=128,
        )

        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    except Exception as exc:
        raise RuntimeError("Translation to Urdu failed (MarianMT Final)") from exc
 
 
60
 
def translate_to_english(text):
    """Translate Urdu text to English with the UR->EN MarianMT model.

    Args:
        text: Urdu source text (a single string).

    Returns:
        The English translation as a string. An empty or whitespace-only
        string yields "" without loading the model.

    Raises:
        RuntimeError: If tokenization, generation or decoding fails. The
            underlying exception is chained as ``__cause__``.
    """
    # Nothing to translate: skip the expensive model load entirely.
    if isinstance(text, str) and not text.strip():
        return ""

    tokenizer, model = _load_ur_en_resources()

    try:
        # The UR->EN checkpoint is a single-pair model, so no target-language
        # token is needed. MarianTokenizer has no ``lang_code_to_id`` mapping
        # (that belongs to the mBART/NLLB tokenizers); looking it up raised
        # AttributeError on every call. The decoder start token is taken
        # from the model's own generation config.
        # ``truncation=True`` keeps over-long input within the model limit.
        inputs = tokenizer(text, return_tensors='pt', truncation=True)

        generated_tokens = model.generate(
            **inputs,  # input_ids plus attention_mask
            num_beams=5,
            max_length=128,
        )

        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    except Exception as exc:
        raise RuntimeError("Translation to English failed (MarianMT Final)") from exc
 
84
 
85
  # --- Example Usage ---
if __name__ == "__main__":
    # Smoke test 1: English -> Urdu.
    english_sample = "This is a final test of the translation API."
    urdu_result = translate_to_urdu(english_sample)

    print(f"Original (English): {english_sample}")
    print(f"Translated (Urdu): {urdu_result}")

    # Smoke test 2: Urdu -> English, using a fixed Urdu sentence rather than
    # the output of the first test.
    urdu_sample = "یہ ایپلیکیشن کامیابی سے چل رہی ہے۔"
    english_result = translate_to_english(urdu_sample)
    print(f"\nOriginal (Urdu): {urdu_sample}")
    print(f"Translated back (English): {english_result}")