# File size: 3,035 Bytes
# 5ccd1db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import logging
from typing import Union, List
from langdetect import detect, LangDetectException

from models.model_loader import ModelLoader
from models.model_selector import ModelSelector

logger = logging.getLogger(__name__)

class ModelManager:
    """
    Orchestrates model selection, loading, and auto-language detection.

    Exposes:
      - translate(text, src_lang=None, tgt_lang=None)
      - get_info()

    The best candidate model is selected and loaded eagerly in ``__init__``.
    """

    # ISO 639-1 codes (what langdetect reports) -> ISO 639-3 stems used by
    # NLLB/FLORES-style tokenizer codes such as "spa_Latn". Several entries
    # list both the individual-language and the macrolanguage code because
    # tokenizers differ in which one they use. A plain prefix match is NOT a
    # substitute for this table: "es" would hit "est_Latn" (Estonian) and
    # "ja" would hit "jav_Latn" (Javanese).
    _ISO1_TO_ISO3 = {
        "af": ("afr",), "ar": ("arb", "ara"), "bg": ("bul",), "bn": ("ben",),
        "ca": ("cat",), "cs": ("ces",), "cy": ("cym",), "da": ("dan",),
        "de": ("deu",), "el": ("ell",), "en": ("eng",), "es": ("spa",),
        "et": ("est",), "fa": ("pes", "fas"), "fi": ("fin",), "fr": ("fra",),
        "gu": ("guj",), "he": ("heb",), "hi": ("hin",), "hr": ("hrv",),
        "hu": ("hun",), "id": ("ind",), "it": ("ita",), "ja": ("jpn",),
        "kn": ("kan",), "ko": ("kor",), "lt": ("lit",), "lv": ("lvs", "lav"),
        "mk": ("mkd",), "ml": ("mal",), "mr": ("mar",), "ne": ("npi", "nep"),
        "nl": ("nld",), "no": ("nob", "nor"), "pa": ("pan",), "pl": ("pol",),
        "pt": ("por",), "ro": ("ron",), "ru": ("rus",), "sk": ("slk",),
        "sl": ("slv",), "so": ("som",), "sq": ("als", "sqi"), "sv": ("swe",),
        "sw": ("swh", "swa"), "ta": ("tam",), "te": ("tel",), "th": ("tha",),
        "tl": ("tgl",), "tr": ("tur",), "uk": ("ukr",), "ur": ("urd",),
        "vi": ("vie",), "zh": ("zho",),
    }

    # Region subtag -> script suffix, used to choose between e.g. "zho_Hans"
    # and "zho_Hant" when the detector reports "zh-cn" / "zh-tw".
    _REGION_TO_SCRIPT = {"cn": "hans", "tw": "hant"}

    def __init__(
        self,
        candidates: Union[List[str], None] = None,
        quantize: bool = True,
        default_tgt: Union[str, None] = None,
    ):
        """
        Build the selector/loader pair and immediately load the chosen model.

        Args:
            candidates: Model names forwarded to ``ModelSelector``.
            quantize: Forwarded to both ``ModelSelector`` and ``ModelLoader``.
            default_tgt: Target language code used when ``translate`` is called
                without ``tgt_lang`` (e.g. "tur_Latn"). When falsy, a Turkish
                code is looked up among the tokenizer's language codes.

        Raises:
            ValueError: If ``default_tgt`` is not given and the loaded model
                exposes no Turkish language code.
        """
        self.selector = ModelSelector(candidates, quantize)
        self.loader = ModelLoader(quantize)
        self.tokenizer = None
        self.pipeline = None
        self.lang_codes = []
        self.default_tgt = default_tgt  # e.g. "tur_Latn"
        self._load_best_model()

    def _load_best_model(self):
        """Select and load a model, then cache its language codes and default target."""
        model_name = self.selector.select()
        tok, pipe = self.loader.load(model_name)
        self.tokenizer = tok
        self.pipeline = pipe

        code_map = getattr(tok, "lang_code_to_id", None)
        if code_map is not None:
            self.lang_codes = list(code_map)
        else:
            # NOTE(review): some tokenizer versions appear to expose language
            # codes only via `additional_special_tokens` -- confirm against
            # the tokenizer classes ModelLoader actually returns.
            self.lang_codes = list(getattr(tok, "additional_special_tokens", None) or [])

        # Pick a Turkish code if not explicitly set. The structured lookup
        # handles "tr", "tr_TR" and "tur_Latn"; the prefix scan is kept only
        # as a backward-compatible last resort.
        if not self.default_tgt:
            turkish = self._match_lang_code("tr") or self._first_with_prefix("tr")
            if not turkish:
                raise ValueError(f"No Turkish code found in {model_name}")
            self.default_tgt = turkish
        logger.info("Default target language: %s", self.default_tgt)

    def _first_with_prefix(self, prefix: str):
        """Return the first language code starting with *prefix* (case-insensitive), or None."""
        return next(
            (code for code in self.lang_codes if code.lower().startswith(prefix)),
            None,
        )

    def _match_lang_code(self, iso: str):
        """
        Map a detector tag such as "es" or "zh-cn" to one of ``self.lang_codes``.

        Lookup order: whole-code match ("zh-cn" -> "zh_CN"), then a match on
        the language stem before the first underscore, trying the two-letter
        code first ("es" -> "es_XX") and its ISO 639-3 equivalents second
        ("es" -> "spa_Latn"). Returns None when nothing matches.
        """
        tag = iso.lower().replace("-", "_")

        for code in self.lang_codes:
            if code.lower() == tag:
                return code

        base, _, region = tag.partition("_")
        script = self._REGION_TO_SCRIPT.get(region)
        for stem in (base, *self._ISO1_TO_ISO3.get(base, ())):
            matches = [
                code for code in self.lang_codes
                if code.lower().partition("_")[0] == stem
            ]
            if not matches:
                continue
            if script:
                for code in matches:
                    if code.lower().endswith("_" + script):
                        return code
            return matches[0]
        return None

    def _detect_src_lang(self, text):
        """
        Guess the source language code for *text*, falling back to English.

        For a list/tuple only the first item is sampled. Blank, empty or
        non-string samples skip detection entirely. If no English code exists
        either, the first known language code is returned.
        """
        sample = text[0] if isinstance(text, (list, tuple)) and text else text

        src = None
        if not isinstance(sample, str) or not sample.strip():
            reason = "no text to sample"
        else:
            try:
                iso = detect(sample)
            except LangDetectException as e:
                reason = e
            else:
                src = self._match_lang_code(iso)
                reason = f"No mapping for ISO '{iso.lower()}'"

        if src is not None:
            logger.info("Detected src_lang=%s", src)
            return src

        logger.warning("Auto-detect failed (%s); defaulting to English", reason)
        english = self._match_lang_code("en") or self._first_with_prefix("en")
        return english or self.lang_codes[0]

    def translate(
        self,
        text: Union[str, List[str]],
        src_lang: Union[str, None] = None,
        tgt_lang: Union[str, None] = None,
    ):
        """
        Translate *text* with the loaded pipeline.

        Args:
            text: A string or a list of strings.
            src_lang: Source language code; auto-detected when falsy.
            tgt_lang: Target language code; ``self.default_tgt`` when falsy.

        Returns:
            Whatever the underlying pipeline returns for this call.
        """
        tgt = tgt_lang or self.default_tgt
        src = src_lang or self._detect_src_lang(text)
        return self.pipeline(text, src_lang=src, tgt_lang=tgt)

    def get_info(self) -> dict:
        """
        Returns a dict for your sidebar:
          { model, quantized, device, default_tgt }

        Values are read from ``self.pipeline.model`` when present; otherwise
        the getattr defaults (None / False / "auto") are reported.
        """
        mdl = getattr(self.pipeline, "model", None)
        return {
            "model": getattr(mdl, "name_or_path", None),
            "quantized": getattr(mdl, "is_loaded_in_8bit", False),
            "device": str(getattr(mdl, "device", "auto")),
            "default_tgt": self.default_tgt,
        }