File size: 2,927 Bytes
6222cc6
 
 
 
 
 
 
 
 
 
 
ee5e504
6222cc6
 
 
 
 
 
ee5e504
6222cc6
 
 
 
 
 
 
 
 
 
 
ee5e504
 
6222cc6
 
 
 
ee5e504
6222cc6
 
 
ee5e504
6222cc6
 
 
 
 
 
ee5e504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6222cc6
 
 
 
 
 
 
 
 
 
ee5e504
 
 
6222cc6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Dialect-to-MSA (Modern Standard Arabic) conversion service.

Uses bayan10/dialect-to-msa-model (mT5 300M) to convert colloquial
Arabic dialects (Egyptian, Gulf, Levantine, Maghrebi) to formal MSA.

Singleton pattern — lazy-loads the model on first request to avoid
blocking server startup.
"""

import logging
import threading
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

logger = logging.getLogger(__name__)

# Lazily created, process-wide converter; populated by get_dialect_model().
_instance: "DialectConverter | None" = None
# Serializes first-time construction of _instance so concurrent first callers
# build the model only once.
_lock = threading.Lock()


class DialectConverter:
    """Convert dialectal Arabic text to Modern Standard Arabic (MSA).

    Wraps the seq2seq checkpoint named by ``REPO_ID`` together with its
    tokenizer. Every input is prefixed with the task instruction ``PREFIX``
    before generation.

    Attributes:
        device: ``"cuda"`` when torch reports a usable GPU, else ``"cpu"``.
        tokenizer: Tokenizer loaded from ``REPO_ID``.
        model: Seq2seq model loaded from ``REPO_ID``, moved to ``device`` and
            switched to eval mode.
    """

    # Task instruction prepended to every input ("Convert to Standard Arabic: ").
    PREFIX = "حوّل إلى الفصحى: "
    REPO_ID = "bayan10/dialect-to-msa-model"
    # Tokenized input (PREFIX included) is truncated to this many tokens.
    MAX_INPUT_LENGTH = 128
    # Passed to generate() as max_length.
    MAX_OUTPUT_LENGTH = 128

    def __init__(self):
        """Load the tokenizer and model for ``REPO_ID`` via ``from_pretrained``.

        The model uses float16 on CUDA and float32 on CPU.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): T5-family models have been reported to overflow in
        # float16; if GPU outputs come back empty or garbled, try bfloat16 or
        # float32 for this checkpoint — TODO confirm.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        logger.info("[DIALECT] Loading model from '%s'...", self.REPO_ID)

        self.tokenizer = AutoTokenizer.from_pretrained(self.REPO_ID)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.REPO_ID, torch_dtype=dtype
        ).to(self.device)
        # Inference only: disable dropout and similar training-time behavior.
        self.model.eval()

        logger.info("[DIALECT] Model loaded successfully (%s).", dtype)

    def convert(self, dialect_text: str, num_beams: int = 4) -> str:
        """Convert a single dialect sentence to MSA.

        Args:
            dialect_text: Input text. Leading/trailing whitespace is stripped
                before the text is sent to the model.
            num_beams: Beam width for generation; ``1`` means greedy decoding.

        Returns:
            The decoded MSA text. Falsy or whitespace-only input is returned
            unchanged. If tokenization or generation raises, the error is
            logged with its traceback and the original ``dialect_text`` is
            returned (best-effort fallback).
        """
        if not dialect_text or not dialect_text.strip():
            return dialect_text

        try:
            gen_kwargs = {
                "max_length": self.MAX_OUTPUT_LENGTH,
                "num_beams": num_beams,
                "no_repeat_ngram_size": 3,
            }
            # early_stopping only affects beam search; sending it with greedy
            # decoding (num_beams=1) just triggers a generation-config warning.
            if num_beams > 1:
                gen_kwargs["early_stopping"] = True

            inputs = self.tokenizer(
                self.PREFIX + dialect_text.strip(),
                return_tensors="pt",
                max_length=self.MAX_INPUT_LENGTH,
                truncation=True,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **gen_kwargs)

            # A single input yields a single sequence: decode the first row.
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            # Deliberate best-effort: never fail the caller, but keep the
            # traceback so failures are diagnosable.
            logger.warning("[DIALECT] Conversion failed: %s", e, exc_info=True)
            return dialect_text

    def is_ready(self) -> bool:
        """Return True once both the tokenizer and the model are loaded.

        Uses ``getattr`` so a partially constructed instance reports False
        instead of raising ``AttributeError``.
        """
        return (
            getattr(self, "model", None) is not None
            and getattr(self, "tokenizer", None) is not None
        )


def get_dialect_model() -> DialectConverter:
    """Return the shared DialectConverter, building it on first use.

    The unlocked ``None`` check keeps the common path cheap. The re-check
    under ``_lock`` ensures that concurrent first callers construct the
    model only once. If construction raises, ``_instance`` stays ``None``
    and a later call will try again.
    """
    global _instance
    if _instance is not None:
        return _instance
    with _lock:
        if _instance is None:
            _instance = DialectConverter()
        return _instance


def is_loaded() -> bool:
    """Report whether the shared converter exists and is ready.

    Unlike get_dialect_model(), this never constructs the converter.
    """
    current = _instance
    if current is None:
        return False
    return current.is_ready()