File size: 3,137 Bytes
78ba245
21b9fb8
 
 
78ba245
21b9fb8
63e0f06
21b9fb8
 
 
 
 
52ec8a5
8f6e3cb
 
 
 
 
52ec8a5
8f6e3cb
 
52ec8a5
8f6e3cb
 
 
9c97528
 
 
 
 
 
8f6e3cb
78ba245
6bcd01d
d55ce43
21b9fb8
78ba245
d55ce43
 
78ba245
d55ce43
 
 
 
 
21b9fb8
 
bef4bfe
21b9fb8
 
 
 
 
 
d55ce43
21b9fb8
 
fcd0983
21b9fb8
d55ce43
21b9fb8
d55ce43
78ba245
 
6bcd01d
 
d01f8b2
6bcd01d
444c9fa
43c46bf
 
21b9fb8
3da7285
af335bf
78ba245
e81430e
6bcd01d
78ba245
ef7f4c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import NllbTokenizer
import ctranslate2
from huggingface_hub import snapshot_download


MODEL_ID = "Tamazight-NLP/NLLB-200-600M-Tamazight-All-Data-3-epoch-ct2-int8"
# Local directory the snapshot is downloaded into: the repo-name part of MODEL_ID.
MODEL_DIR = MODEL_ID.split("/")[-1]

# Download the CTranslate2-converted model repo from the Hugging Face Hub into
# MODEL_DIR, then load both the translator and its tokenizer from that local copy.
snapshot_download(MODEL_ID, local_dir=MODEL_DIR)

translator = ctranslate2.Translator(MODEL_DIR)
nllb_tokenizer = NllbTokenizer.from_pretrained(MODEL_DIR)

# UI display name -> NLLB language code. The code is used both as the tokenizer's
# src_lang and as the forced first target token (target_prefix) in translate().
# NOTE(review): "taq" and "kab" are the ISO 639-3 codes for Tamasheq and Kabyle;
# presumably this fine-tuned model reuses those language tokens for
# Tachelhit/Central Atlas Tamazight and Tarifit — confirm against the model card.
NLLB_LANG_MAPPING = {
    "English": "eng_Latn",
    "Standard Moroccan Tamazight": "tzm_Tfng",
    "Tachelhit/Central Atlas Tamazight": "taq_Tfng",
    "Tachelhit/Central Atlas Tamazight (Latin)": "taq_Latn",
    "Tarifit": "kab_Tfng",
    "Tarifit (Latin)": "kab_Latn",
    "Moroccan Darija": "ary_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Catalan": "cat_Latn",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Dutch": "nld_Latn",
    "Russian": "rus_Cyrl",
    "Italian": "ita_Latn",
    "Turkish": "tur_Latn",
    "Esperanto": "epo_Latn",
}


def translate(text, source_lang="English", target_lang="Tachelhit/Central Atlas Tamazight",
              max_length=237, num_beams=4, repetition_penalty=1.0):
    """
    Translate multi-line text while preserving line breaks.

    The input is split on newline characters. Each non-blank line is
    translated independently, and blank or whitespace-only lines come out as
    empty strings, so the output has exactly as many lines as the input. All
    non-blank lines are sent to the translator in a single batched call.

    Args:
        text: Text to translate, possibly spanning several lines. An empty
            string or None yields "".
        source_lang: Display name of the source language (a key of
            NLLB_LANG_MAPPING).
        target_lang: Display name of the target language (a key of
            NLLB_LANG_MAPPING).
        max_length: Maximum decoding length per line, in tokens.
        num_beams: Beam size used for decoding.
        repetition_penalty: Repetition penalty passed to the decoder
            (1.0 leaves scores unchanged).

    Returns:
        The translated lines joined with newline characters.

    Raises:
        KeyError: If a language name is not in NLLB_LANG_MAPPING. This is
            only checked when there is at least one non-blank line.
    """
    if not text:
        return ""

    lines = text.split("\n")
    translations = [""] * len(lines)
    # Positions of lines that actually need translating; blank lines stay "".
    pending = [i for i, line in enumerate(lines) if line.strip()]
    if not pending:
        return "\n".join(translations)

    # The source language is the same for every line, so configure the
    # tokenizer once per call rather than once per line.
    nllb_tokenizer.src_lang = NLLB_LANG_MAPPING[source_lang]
    target_code = NLLB_LANG_MAPPING[target_lang]
    sources = [
        nllb_tokenizer.convert_ids_to_tokens(nllb_tokenizer.encode(lines[i]))
        for i in pending
    ]
    # One batched decode instead of one translate_batch call per line.
    # Slider values may arrive as floats; CTranslate2's integer options need ints.
    results = translator.translate_batch(
        sources,
        target_prefix=[[target_code] for _ in sources],
        max_decoding_length=int(max_length),
        beam_size=int(num_beams),
        repetition_penalty=repetition_penalty,
    )
    for i, result in zip(pending, results):
        # Skip the first token: the target-language code forced via target_prefix.
        target = result.hypotheses[0][1:]
        translations[i] = nllb_tokenizer.decode(
            nllb_tokenizer.convert_tokens_to_ids(target), skip_special_tokens=True
        )

    return "\n".join(translations)


def _build_interface():
    """Assemble the Gradio front end that feeds the widgets into translate()."""
    text_box = gr.components.Textbox(
        label="Text",
        lines=4,
        placeholder="ⵙⵙⴽⵛⵎ ⴰⴹⵕⵉⵚ...\nEnter text to translate...",
    )
    source_choice = gr.components.Dropdown(
        label="Source Language",
        choices=list(NLLB_LANG_MAPPING.keys()),
        value="English",
    )
    target_choice = gr.components.Dropdown(
        label="Target Language",
        choices=list(NLLB_LANG_MAPPING.keys()),
        value="Standard Moroccan Tamazight",
    )
    # The sliders line up positionally with translate()'s max_length,
    # num_beams and repetition_penalty parameters.
    length_slider = gr.components.Slider(
        1, 400, value=237, step=10,
        label="Max Length (in tokens). Increase in case the output looks truncated.",
    )
    beam_slider = gr.components.Slider(
        1, 25, value=4, step=1,
        label="Number of beams. Higher values might improve translation accuracy at the cost of speed.",
    )
    penalty_slider = gr.components.Slider(
        1, 10, value=1.0, step=0.1,
        label="Repetition penalty.",
    )
    widgets = [text_box, source_choice, target_choice, length_slider, beam_slider, penalty_slider]
    result_box = gr.components.Textbox(label="Translated text", lines=4)
    return gr.Interface(
        fn=translate,
        title="NLLB Tamazight Translation Demo",
        inputs=widgets,
        outputs=result_box,
    )


gradio_ui = _build_interface()
gradio_ui.launch()