File size: 2,779 Bytes
9e5cadc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import functools

# Curated subset of the languages NLLB-200 can translate, keyed by the display
# name shown to users and mapped to the FLORES-200 code the model expects.
# NLLB covers 200+ languages, but a dropdown listing all of them is unwieldy,
# so only widely used ones are exposed here (insertion order = display order).
# Code reference: https://github.com/facebookresearch/flores/blob/main/flores200/README.md
LANGUAGE_CODES = dict(
    [
        ("English", "eng_Latn"),
        ("French", "fra_Latn"),
        ("Spanish", "spa_Latn"),
        ("German", "deu_Latn"),
        ("Chinese (Simplified)", "zho_Hans"),
        ("Chinese (Traditional)", "zho_Hant"),
        ("Hindi", "hin_Deva"),
        ("Arabic", "arb_Arab"),
        ("Russian", "rus_Cyrl"),
        ("Portuguese", "por_Latn"),
        ("Japanese", "jpn_Jpan"),
        ("Korean", "kor_Hang"),
        ("Italian", "ita_Latn"),
        ("Dutch", "nld_Latn"),
        ("Turkish", "tur_Latn"),
        ("Vietnamese", "vie_Latn"),
        ("Indonesian", "ind_Latn"),
        ("Persian", "pes_Arab"),
        ("Polish", "pol_Latn"),
        ("Ukrainian", "ukr_Cyrl"),
        ("Swahili", "swh_Latn"),
        ("Urdu", "urd_Arab"),
        ("Bengali", "ben_Beng"),
        ("Tamil", "tam_Taml"),
    ]
)

# Hugging Face Hub checkpoint that load_model() fetches.
MODEL_NAME = "facebook/nllb-200-distilled-600M"

# Lazily populated module-level singletons; load_model() fills both on its
# first call and returns the cached objects afterwards.
_model = None
_tokenizer = None

def get_device():
    """
    Return the name of the best available torch device.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        str: ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in torch >= 1.12; look it up defensively
    # so older installs fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

def load_model():
    """
    Return the shared ``(model, tokenizer)`` pair, loading it on first use.

    The first call loads ``MODEL_NAME`` with ``AutoTokenizer`` and
    ``AutoModelForSeq2SeqLM``, moves the model to the device chosen by
    ``get_device()``, and stores both in module globals. Later calls return
    the cached objects without reloading.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    global _model, _tokenizer

    # Fast path: already loaded.
    if _model is not None:
        return _model, _tokenizer

    print(f"Loading {MODEL_NAME}...")
    target_device = get_device()
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(target_device)
    print("Model loaded successfully.")
    return _model, _tokenizer

def translate_text(text, src_lang_name, tgt_lang_name, *, max_length=200):
    """
    Translate ``text`` between two languages using NLLB.

    Args:
        text: Source text. Empty, ``None`` or whitespace-only input returns
            ``""`` immediately without loading the model.
        src_lang_name: Display name of the source language (a key of
            ``LANGUAGE_CODES``). Unknown names fall back to English.
        tgt_lang_name: Display name of the target language (a key of
            ``LANGUAGE_CODES``). Unknown names fall back to French.
        max_length: Maximum length, in tokens, of the generated sequence
            passed to ``model.generate``. Defaults to 200.

    Returns:
        str: The translated text. If anything fails, returns
        ``"Error during translation: <message>"`` instead of raising, so a
        UI can show it; the full traceback is printed to stderr.
    """
    # Nothing to translate: skip loading/running the model entirely.
    if not text or not str(text).strip():
        return ""

    try:
        model, tokenizer = load_model()
        device = model.device

        # Map display names to NLLB (FLORES-200) language codes.
        src_code = LANGUAGE_CODES.get(src_lang_name, "eng_Latn")
        tgt_code = LANGUAGE_CODES.get(tgt_lang_name, "fra_Latn")

        # The NLLB tokenizer encodes the source-language token from src_lang.
        tokenizer.src_lang = src_code
        inputs = tokenizer(text, return_tensors="pt").to(device)

        # forced_bos_token_id makes the decoder start in the target language.
        # convert_tokens_to_ids works across transformers versions, whereas the
        # older `lang_code_to_id` mapping is missing from recent
        # NllbTokenizerFast releases and raised AttributeError.
        tgt_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
        generated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tgt_token_id,
            max_length=max_length,
        )

        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    except Exception as e:
        # Deliberate UI boundary: report the error as text, but keep the
        # traceback visible for debugging instead of discarding it.
        import traceback

        traceback.print_exc()
        return f"Error during translation: {e}"