Spaces:
Sleeping
Sleeping
| import logging | |
| from typing import Union, List | |
| from langdetect import detect, LangDetectException | |
| from models.model_loader import ModelLoader | |
| from models.model_selector import ModelSelector | |
# One logger per module, named after the module so it slots into the
# standard dotted logging hierarchy.
logger = logging.getLogger(name=__name__)
class ModelManager:
    """
    Orchestrates model selection, loading, and auto-language detection.

    Exposes:
      - translate(text, src_lang=None, tgt_lang=None)
      - get_info()

    Language codes are taken from the loaded tokenizer, so their format
    depends on the model family (``tur_Latn``-style, ``tr_TR``-style or a
    bare ``tr``). ISO 639-1 tags produced by langdetect are mapped onto that
    format by ``_resolve_lang_code``.
    """

    # langdetect ISO 639-1 tag -> ISO 639-3 code used as the language part of
    # "xxx_Script" tokenizer codes such as "tur_Latn". Plain prefix matching
    # is not enough: "es" would hit "est_Latn" (Estonian), "ja" would hit
    # "jav_Latn" (Javanese) and "tr" would never hit "tur_Latn".
    # NOTE(review): macrolanguage entries (ar, fa, lv, ne, no, sq, sw) use the
    # individual-language codes listed by NLLB-200; confirm for other models.
    _ISO639_1_TO_3 = {
        "af": "afr", "ar": "arb", "bg": "bul", "bn": "ben", "ca": "cat",
        "cs": "ces", "cy": "cym", "da": "dan", "de": "deu", "el": "ell",
        "en": "eng", "es": "spa", "et": "est", "fa": "pes", "fi": "fin",
        "fr": "fra", "gu": "guj", "he": "heb", "hi": "hin", "hr": "hrv",
        "hu": "hun", "id": "ind", "it": "ita", "ja": "jpn", "kn": "kan",
        "ko": "kor", "lt": "lit", "lv": "lvs", "mk": "mkd", "ml": "mal",
        "mr": "mar", "ne": "npi", "nl": "nld", "no": "nob", "pa": "pan",
        "pl": "pol", "pt": "por", "ro": "ron", "ru": "rus", "sk": "slk",
        "sl": "slv", "so": "som", "sq": "als", "sv": "swe", "sw": "swh",
        "ta": "tam", "te": "tel", "th": "tha", "tl": "tgl", "tr": "tur",
        "uk": "ukr", "ur": "urd", "vi": "vie", "zh": "zho",
    }

    # Region-qualified tags (after normalization) that imply a script suffix.
    _TAG_TO_SCRIPT = {"zh_cn": "hans", "zh_tw": "hant"}

    def __init__(
        self,
        candidates: Union[List[str], None] = None,
        quantize: bool = True,
        default_tgt: Union[str, None] = None,
    ):
        """
        Build the selector/loader and immediately load the selected model.

        Args:
            candidates: forwarded to ``ModelSelector`` (presumably the model
                names to choose from -- confirm against ModelSelector).
            quantize: forwarded to both ``ModelSelector`` and ``ModelLoader``.
            default_tgt: target code used when ``translate`` gets no
                ``tgt_lang``; if falsy, a Turkish code is picked from the
                loaded tokenizer.

        Raises:
            ValueError: if no ``default_tgt`` is given and the loaded
                tokenizer exposes no Turkish code.
        """
        self.selector = ModelSelector(candidates, quantize)
        self.loader = ModelLoader(quantize)
        self.model_name = None
        self.tokenizer = None
        self.pipeline = None
        self.lang_codes = []
        self.default_tgt = default_tgt  # e.g. "tur_Latn"
        self._load_best_model()

    def _load_best_model(self):
        """Select a model, load it, and populate tokenizer/pipeline/codes."""
        model_name = self.selector.select()
        tok, pipe = self.loader.load(model_name)
        self.model_name = model_name
        self.tokenizer = tok
        self.pipeline = pipe
        self.lang_codes = self._extract_lang_codes(tok)

        # Pick a Turkish code if not explicitly set.
        if not self.default_tgt:
            turkish = self._resolve_lang_code("tr")
            if turkish is None:
                raise ValueError(f"No Turkish code found in {model_name}")
            self.default_tgt = turkish
        logger.info("Default target language: %s", self.default_tgt)

    @staticmethod
    def _extract_lang_codes(tok):
        """Return the language codes exposed by ``tok`` as a list."""
        mapping = getattr(tok, "lang_code_to_id", None)
        if mapping:
            return list(mapping.keys())
        # NOTE(review): newer tokenizer releases may no longer expose
        # ``lang_code_to_id``; NLLB/mBART-style tokenizers register their
        # language codes as additional special tokens -- confirm for any new
        # model family.
        return list(getattr(tok, "additional_special_tokens", None) or [])

    @staticmethod
    def _normalize_tag(code):
        """Lower-case a language code and unify '-' / '_' separators."""
        return code.lower().replace("-", "_")

    def _resolve_lang_code(self, iso):
        """
        Map a language tag (e.g. "es", "zh-cn", "tr") onto ``self.lang_codes``.

        Preference order:
          1. a code equal to the tag (case- and separator-insensitive);
          2. codes whose language part (text before the first "_") equals the
             tag's base or its ISO 639-3 equivalent, preferring the script
             implied by the tag (zh-cn -> Hans, zh-tw -> Hant), else the
             first such code.

        Returns:
            The matching tokenizer code, or ``None`` if nothing matches.
        """
        tag = self._normalize_tag(iso)
        for code in self.lang_codes:
            if self._normalize_tag(code) == tag:
                return code

        base = tag.split("_", 1)[0]
        wanted = {base, self._ISO639_1_TO_3.get(base, base)}
        matches = [
            code
            for code in self.lang_codes
            if self._normalize_tag(code).split("_", 1)[0] in wanted
        ]
        if not matches:
            return None

        script = self._TAG_TO_SCRIPT.get(tag)
        if script:
            for code in matches:
                if self._normalize_tag(code).endswith("_" + script):
                    return code
        return matches[0]

    def _fallback_src_lang(self):
        """Source code used when detection fails: English, else the first code."""
        return self._resolve_lang_code("en") or self.lang_codes[0]

    def _detect_src_lang(self, text):
        """
        Detect the source language of ``text`` and map it to a tokenizer code.

        For a list only the first item is inspected. Any failure (empty or
        non-string sample, langdetect error, unmapped tag) falls back to
        ``_fallback_src_lang`` with a warning instead of raising.
        """
        if isinstance(text, list):
            sample = text[0] if text else ""
        else:
            sample = text

        if not isinstance(sample, str):
            logger.warning(
                "Auto-detect failed (non-string input %s); defaulting to English",
                type(sample).__name__,
            )
            return self._fallback_src_lang()

        try:
            iso = detect(sample)
        except LangDetectException as exc:
            logger.warning("Auto-detect failed (%s); defaulting to English", exc)
            return self._fallback_src_lang()

        src = self._resolve_lang_code(iso)
        if src is None:
            logger.warning(
                "Auto-detect failed (no mapping for ISO '%s'); defaulting to English",
                iso,
            )
            return self._fallback_src_lang()

        logger.info("Detected src_lang=%s", src)
        return src

    def translate(
        self,
        text: Union[str, List[str]],
        src_lang: Union[str, None] = None,
        tgt_lang: Union[str, None] = None,
    ):
        """
        Translate ``text`` with the loaded pipeline.

        Args:
            text: a string or a list of strings. For a list only the first
                item is used for language detection, so all items are
                treated as sharing one source language.
            src_lang: tokenizer language code; auto-detected when falsy.
            tgt_lang: tokenizer language code; ``self.default_tgt`` when falsy.

        Returns:
            Whatever the underlying pipeline returns for ``text``.
        """
        tgt = tgt_lang or self.default_tgt
        src = src_lang or self._detect_src_lang(text)
        return self.pipeline(text, src_lang=src, tgt_lang=tgt)

    def get_info(self):
        """
        Returns a dict for your sidebar:
            { model, quantized, device, default_tgt }

        ``model`` falls back to the selected model name when the pipeline's
        model does not expose ``name_or_path``; ``quantized`` is True when
        the model reports 8-bit or 4-bit loading.
        """
        mdl = getattr(self.pipeline, "model", None)
        quantized = getattr(mdl, "is_loaded_in_8bit", False) or getattr(
            mdl, "is_loaded_in_4bit", False
        )
        return {
            "model": getattr(mdl, "name_or_path", None)
            or getattr(self, "model_name", None),
            "quantized": bool(quantized),
            "device": str(getattr(mdl, "device", "auto")),
            "default_tgt": self.default_tgt,
        }