kleervoyans committed on
Commit
922e95d
Β·
verified Β·
1 Parent(s): d970a26

Update models/translation_loader.py

Browse files
Files changed (1) hide show
  1. models/translation_loader.py +62 -52
models/translation_loader.py CHANGED
@@ -10,51 +10,59 @@ class TranslationLoader:
10
  self,
11
  model_name: str = "facebook/nllb-200-distilled-600M",
12
  quantize: bool = True,
13
- tgt_lang: str = "tur_Latn",
14
  ):
15
  self.model_name = model_name
16
  self.quantize = quantize
17
- self.default_tgt = tgt_lang
18
 
19
- # 1) Load translation pipeline (with optional 8-bit quantization)
20
- self._load_pipeline()
21
-
22
- # 2) Separately load AutoTokenizer so we can access lang_code_to_id
23
- try:
24
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
25
- # This mapping is used in the HF NLLB examples:
26
- # tokenizer.lang_code_to_id["fra_Latn"] β†’ token ID :contentReference[oaicite:1]{index=1}
27
- self.lang_code_to_id = self.tokenizer.lang_code_to_id
28
- logging.info("Loaded tokenizer.lang_code_to_id mapping")
29
- except (AttributeError, ValueError):
30
- # Fallback: some pipelines don't expose it, but the model config does
31
- self.lang_code_to_id = self.pipeline.model.config.lang_code_to_id
32
- logging.info("Using model.config.lang_code_to_id mapping")
33
-
34
- # Precompute list of supported codes
35
- self.lang_codes = list(self.lang_code_to_id.keys())
36
- logging.info(f"Supported language codes (sample): {self.lang_codes[:5]}...")
37
-
38
- def _load_pipeline(self):
39
  try:
40
- bnb_config = BitsAndBytesConfig(load_in_8bit=True)
41
  self.pipeline = pipeline(
42
  "translation",
43
  model=self.model_name,
44
  tokenizer=self.model_name,
45
  device_map="auto",
46
- quantization_config=bnb_config,
47
  )
48
- logging.info(f"Loaded {self.model_name} in 8-bit mode")
49
  except Exception as e:
50
- logging.warning(f"8-bit quantization failed ({e}), loading FP32 model")
51
  self.pipeline = pipeline(
52
  "translation",
53
  model=self.model_name,
54
  tokenizer=self.model_name,
55
  device_map="auto",
56
  )
57
- logging.info(f"Loaded {self.model_name} in full precision")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def translate(
60
  self,
@@ -63,42 +71,44 @@ class TranslationLoader:
63
  tgt_lang: str = None,
64
  ):
65
  """
66
- Translate `text` (str or list) from src_lang β†’ tgt_lang.
67
- If src_lang is None, auto-detect via langdetect.
68
- If tgt_lang is None, use the default (Turkish).
69
  """
70
  tgt = tgt_lang or self.default_tgt
71
 
72
- # Auto-detect source if not provided
73
- if src_lang is None:
 
 
74
  sample = text[0] if isinstance(text, list) else text
75
  try:
76
  iso = detect(sample).lower()
77
- # Find matching NLLB codes that start with the ISO
78
- candidates = [c for c in self.lang_codes if c.lower().startswith(iso)]
79
- # Prefer Latin-script variant if available
80
- src = next((c for c in candidates if "Latn" in c), None)
81
- src = src or (candidates[0] if candidates else "eng_Latn")
82
- logging.info(f"Auto-detected src_lang={src} (iso='{iso}')")
83
- except LangDetectException as e:
84
- logging.warning(f"langdetect failed ({e}), defaulting to eng_Latn")
85
- src = "eng_Latn"
86
- else:
87
- src = src_lang
 
 
88
 
89
- # Call the pipeline with both src_lang and tgt_lang
90
  return self.pipeline(text, src_lang=src, tgt_lang=tgt)
91
 
92
  def get_info(self):
93
- """
94
- Returns metadata for display in the UI sidebar.
95
- """
96
- model = getattr(self.pipeline, "model", None)
97
- quantized = getattr(model, "is_loaded_in_8bit", False)
98
- device = getattr(model, "device", "auto")
99
  return {
100
  "model_name": self.model_name,
101
- "quantized": quantized,
102
  "device": str(device),
103
- "default_tgt": self.default_tgt,
104
  }
 
10
  self,
11
  model_name: str = "facebook/nllb-200-distilled-600M",
12
  quantize: bool = True,
13
+ tgt_lang: str = None, # if None, we’ll pick the Turkish code automatically
14
  ):
15
  self.model_name = model_name
16
  self.quantize = quantize
17
+ self.default_tgt = tgt_lang # may be None
18
 
19
+ # ─── Load the translation pipeline ───────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
21
+ bnb_cfg = BitsAndBytesConfig(load_in_8bit=self.quantize)
22
  self.pipeline = pipeline(
23
  "translation",
24
  model=self.model_name,
25
  tokenizer=self.model_name,
26
  device_map="auto",
27
+ quantization_config=bnb_cfg,
28
  )
29
+ logging.info(f"Loaded `{self.model_name}` with 8-bit={self.quantize}")
30
  except Exception as e:
31
+ logging.warning(f"8-bit load failed ({e}); falling back to full-precision")
32
  self.pipeline = pipeline(
33
  "translation",
34
  model=self.model_name,
35
  tokenizer=self.model_name,
36
  device_map="auto",
37
  )
38
+ logging.info(f"Loaded `{self.model_name}` in full precision")
39
+
40
+ # ─── Load tokenizer & grab the lang_code_to_id mapping ────────────
41
+ try:
42
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
43
+ logging.info(f"Tokenizer loaded for {self.model_name}")
44
+ except Exception as e:
45
+ logging.error(f"Cannot load tokenizer for {self.model_name}: {e}")
46
+ raise ValueError(f"Failed to load tokenizer: {e}")
47
+
48
+ if hasattr(self.tokenizer, "lang_code_to_id"):
49
+ self.lang_code_to_id = self.tokenizer.lang_code_to_id
50
+ logging.info("Using tokenizer.lang_code_to_id mapping")
51
+ else:
52
+ allowed = ", ".join(list(self.tokenizer.config.to_dict().keys())[:5])
53
+ raise AttributeError(
54
+ f"Model `{self.model_name}`’s tokenizer has no `lang_code_to_id`. "
55
+ "Use a model like NLLB-200 or M2M100 that supports language codes. "
56
+ f"(available config keys: {allowed}…)"
57
+ )
58
+
59
+ # ─── Auto-pick the Turkish target code if none was provided ───────
60
+ if self.default_tgt is None:
61
+ tur = [c for c in self.lang_code_to_id if c.lower().startswith("tr")]
62
+ if not tur:
63
+ raise ValueError(f"No Turkish code found in mapping for {self.model_name}")
64
+ self.default_tgt = tur[0]
65
+ logging.info(f"Default target set to `{self.default_tgt}`")
66
 
67
  def translate(
68
  self,
 
71
  tgt_lang: str = None,
72
  ):
73
  """
74
+ - Auto-detects src_lang via langdetect if not given
75
+ - Uses default_tgt if tgt_lang is not passed
76
+ - Returns pipeline output (list of dicts with 'translation_text')
77
  """
78
  tgt = tgt_lang or self.default_tgt
79
 
80
+ # ─── Source-language auto-detection ─────────────────────────────
81
+ if src_lang:
82
+ src = src_lang
83
+ else:
84
  sample = text[0] if isinstance(text, list) else text
85
  try:
86
  iso = detect(sample).lower()
87
+ # find codes starting with that ISO (e.g. "en"β†’["en","eng_Latn",…])
88
+ cand = [c for c in self.lang_code_to_id if c.lower().startswith(iso)]
89
+ if not cand:
90
+ raise LangDetectException(f"No mapping for ISO '{iso}'")
91
+ # prefer exact match, else first
92
+ exact = [c for c in cand if c.lower() == iso]
93
+ src = exact[0] if exact else cand[0]
94
+ logging.info(f"Detected src_lang={src} from ISO='{iso}'")
95
+ except Exception as e:
96
+ logging.warning(f"Language auto-detect failed ({e}); defaulting to English")
97
+ eng = [c for c in self.lang_code_to_id if c.lower().startswith("en")]
98
+ src = eng[0] if eng else list(self.lang_code_to_id)[0]
99
+ logging.info(f"Fallback src_lang={src}")
100
 
101
+ # ─── Perform translation call ────────────────────────────────────
102
  return self.pipeline(text, src_lang=src, tgt_lang=tgt)
103
 
104
  def get_info(self):
105
+ """Return model metadata for display in your sidebar."""
106
+ mdl = getattr(self.pipeline, "model", None)
107
+ q = getattr(mdl, "is_loaded_in_8bit", False)
108
+ device = getattr(mdl, "device", "auto")
 
 
109
  return {
110
  "model_name": self.model_name,
111
+ "quantized": q,
112
  "device": str(device),
113
+ "default_target": self.default_tgt,
114
  }