Spaces:
Sleeping
Sleeping
Update models/translation_loader.py
Browse files- models/translation_loader.py +62 -52
models/translation_loader.py
CHANGED
|
@@ -10,51 +10,59 @@ class TranslationLoader:
|
|
| 10 |
self,
|
| 11 |
model_name: str = "facebook/nllb-200-distilled-600M",
|
| 12 |
quantize: bool = True,
|
| 13 |
-
tgt_lang: str =
|
| 14 |
):
|
| 15 |
self.model_name = model_name
|
| 16 |
self.quantize = quantize
|
| 17 |
-
self.default_tgt = tgt_lang
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
self._load_pipeline()
|
| 21 |
-
|
| 22 |
-
# 2) Separately load AutoTokenizer so we can access lang_code_to_id
|
| 23 |
-
try:
|
| 24 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 25 |
-
# This mapping is used in the HF NLLB examples:
|
| 26 |
-
# tokenizer.lang_code_to_id["fra_Latn"] β token ID :contentReference[oaicite:1]{index=1}
|
| 27 |
-
self.lang_code_to_id = self.tokenizer.lang_code_to_id
|
| 28 |
-
logging.info("Loaded tokenizer.lang_code_to_id mapping")
|
| 29 |
-
except (AttributeError, ValueError):
|
| 30 |
-
# Fallback: some pipelines don't expose it, but the model config does
|
| 31 |
-
self.lang_code_to_id = self.pipeline.model.config.lang_code_to_id
|
| 32 |
-
logging.info("Using model.config.lang_code_to_id mapping")
|
| 33 |
-
|
| 34 |
-
# Precompute list of supported codes
|
| 35 |
-
self.lang_codes = list(self.lang_code_to_id.keys())
|
| 36 |
-
logging.info(f"Supported language codes (sample): {self.lang_codes[:5]}...")
|
| 37 |
-
|
| 38 |
-
def _load_pipeline(self):
|
| 39 |
try:
|
| 40 |
-
|
| 41 |
self.pipeline = pipeline(
|
| 42 |
"translation",
|
| 43 |
model=self.model_name,
|
| 44 |
tokenizer=self.model_name,
|
| 45 |
device_map="auto",
|
| 46 |
-
quantization_config=
|
| 47 |
)
|
| 48 |
-
logging.info(f"Loaded {self.model_name}
|
| 49 |
except Exception as e:
|
| 50 |
-
logging.warning(f"8-bit
|
| 51 |
self.pipeline = pipeline(
|
| 52 |
"translation",
|
| 53 |
model=self.model_name,
|
| 54 |
tokenizer=self.model_name,
|
| 55 |
device_map="auto",
|
| 56 |
)
|
| 57 |
-
logging.info(f"Loaded {self.model_name} in full precision")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
def translate(
|
| 60 |
self,
|
|
@@ -63,42 +71,44 @@ class TranslationLoader:
|
|
| 63 |
tgt_lang: str = None,
|
| 64 |
):
|
| 65 |
"""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
"""
|
| 70 |
tgt = tgt_lang or self.default_tgt
|
| 71 |
|
| 72 |
-
#
|
| 73 |
-
if src_lang
|
|
|
|
|
|
|
| 74 |
sample = text[0] if isinstance(text, list) else text
|
| 75 |
try:
|
| 76 |
iso = detect(sample).lower()
|
| 77 |
-
#
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
logging.
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
#
|
| 90 |
return self.pipeline(text, src_lang=src, tgt_lang=tgt)
|
| 91 |
|
| 92 |
def get_info(self):
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
""
|
| 96 |
-
|
| 97 |
-
quantized = getattr(model, "is_loaded_in_8bit", False)
|
| 98 |
-
device = getattr(model, "device", "auto")
|
| 99 |
return {
|
| 100 |
"model_name": self.model_name,
|
| 101 |
-
"quantized":
|
| 102 |
"device": str(device),
|
| 103 |
-
"
|
| 104 |
}
|
|
|
|
def __init__(
    self,
    model_name: str = "facebook/nllb-200-distilled-600M",
    quantize: bool = True,
    tgt_lang: str = None,  # if None, we pick the Turkish code automatically
):
    """Load a multilingual translation pipeline and its language-code mapping.

    Args:
        model_name: HF Hub id of a multilingual translation model (NLLB/M2M100).
        quantize: try 8-bit (bitsandbytes) loading first; falls back to full
            precision if that fails.
        tgt_lang: default target language code; when None, a Turkish code is
            auto-selected from the tokenizer's mapping.

    Raises:
        ValueError: tokenizer cannot be loaded, or no Turkish code exists.
        AttributeError: the tokenizer exposes no language-code mapping at all.
    """
    self.model_name = model_name
    self.quantize = quantize
    self.default_tgt = tgt_lang  # may be None until resolved below

    # --- Load the translation pipeline ------------------------------------
    try:
        bnb_cfg = BitsAndBytesConfig(load_in_8bit=self.quantize)
        self.pipeline = pipeline(
            "translation",
            model=self.model_name,
            tokenizer=self.model_name,
            device_map="auto",
            quantization_config=bnb_cfg,
        )
        logging.info(f"Loaded `{self.model_name}` with 8-bit={self.quantize}")
    except Exception as e:
        # bitsandbytes is unavailable on CPU-only hosts; degrade gracefully.
        logging.warning(f"8-bit load failed ({e}); falling back to full-precision")
        self.pipeline = pipeline(
            "translation",
            model=self.model_name,
            tokenizer=self.model_name,
            device_map="auto",
        )
        logging.info(f"Loaded `{self.model_name}` in full precision")

    # --- Load tokenizer & grab the lang_code_to_id mapping -----------------
    try:
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        logging.info(f"Tokenizer loaded for {self.model_name}")
    except Exception as e:
        logging.error(f"Cannot load tokenizer for {self.model_name}: {e}")
        raise ValueError(f"Failed to load tokenizer: {e}") from e

    if hasattr(self.tokenizer, "lang_code_to_id"):
        self.lang_code_to_id = self.tokenizer.lang_code_to_id
        logging.info("Using tokenizer.lang_code_to_id mapping")
    else:
        # BUGFIX: tokenizers have no `.config` attribute, so the old
        # `self.tokenizer.config.to_dict()` raised the wrong AttributeError.
        # Newer transformers drop `lang_code_to_id`; rebuild the mapping from
        # the language codes kept as additional special tokens instead.
        specials = getattr(self.tokenizer, "additional_special_tokens", None) or []
        rebuilt = {t: self.tokenizer.convert_tokens_to_ids(t) for t in specials}
        if rebuilt:
            self.lang_code_to_id = rebuilt
            logging.info("Rebuilt lang_code_to_id from additional_special_tokens")
        else:
            raise AttributeError(
                f"Model `{self.model_name}`'s tokenizer has no `lang_code_to_id`. "
                "Use a model like NLLB-200 or M2M100 that supports language codes."
            )

    # --- Auto-pick the Turkish target code if none was provided ------------
    if self.default_tgt is None:
        # BUGFIX: NLLB's Turkish code is `tur_Latn`, which does NOT start with
        # "tr" — the old prefix test could never match it. Accept both the
        # ISO-639-1 ("tr") and ISO-639-3 ("tur") spellings.
        tur = [
            c for c in self.lang_code_to_id
            if c.lower().startswith(("tr", "tur"))
        ]
        if not tur:
            raise ValueError(f"No Turkish code found in mapping for {self.model_name}")
        self.default_tgt = tur[0]
        logging.info(f"Default target set to `{self.default_tgt}`")
| 67 |
def translate(
    self,
    text,
    src_lang: str = None,
    tgt_lang: str = None,
):
    """Translate `text` (a str or list of str) and return the pipeline output.

    - Auto-detects src_lang via langdetect if not given
    - Uses default_tgt if tgt_lang is not passed
    - Returns pipeline output (list of dicts with 'translation_text')
    """
    tgt = tgt_lang or self.default_tgt

    # --- Source-language auto-detection ------------------------------------
    if src_lang:
        src = src_lang
    else:
        sample = text[0] if isinstance(text, list) else text
        try:
            iso = detect(sample).lower()
            # find codes starting with that ISO (e.g. "en" -> ["en", "eng_Latn", ...])
            cand = [c for c in self.lang_code_to_id if c.lower().startswith(iso)]
            if not cand:
                # NOTE(review): langdetect's LangDetectException takes
                # (code, message) — a single arg likely raises TypeError here,
                # which the same `except Exception` below still catches, so the
                # English fallback is reached either way. Confirm intent.
                raise LangDetectException(f"No mapping for ISO '{iso}'")
            # prefer exact match, else first candidate
            exact = [c for c in cand if c.lower() == iso]
            src = exact[0] if exact else cand[0]
            logging.info(f"Detected src_lang={src} from ISO='{iso}'")
        except Exception as e:
            logging.warning(f"Language auto-detect failed ({e}); defaulting to English")
            eng = [c for c in self.lang_code_to_id if c.lower().startswith("en")]
            src = eng[0] if eng else list(self.lang_code_to_id)[0]
            logging.info(f"Fallback src_lang={src}")

    # --- Perform translation call -------------------------------------------
    return self.pipeline(text, src_lang=src, tgt_lang=tgt)
| 104 |
def get_info(self):
    """Return model metadata (name, quantization flag, device, default target)
    for display in the sidebar.

    Attribute lookups are defensive (`getattr` with defaults) because the
    pipeline may have loaded in either 8-bit or full precision.
    """
    mdl = getattr(self.pipeline, "model", None)
    quantized = getattr(mdl, "is_loaded_in_8bit", False)
    device = getattr(mdl, "device", "auto")
    return {
        "model_name": self.model_name,
        "quantized": quantized,
        "device": str(device),
        "default_target": self.default_tgt,
    }