STTR User commited on
Commit
448a6e3
·
1 Parent(s): 917f588

Add NLLB-200 Translation API

Browse files
Files changed (2) hide show
  1. app.py +58 -3
  2. requirements.txt +5 -0
app.py CHANGED
@@ -1,7 +1,62 @@
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import torch
4
 
5
+ # Load NLLB-200 (distilled for speed)
6
+ MODEL_NAME = "facebook/nllb-200-distilled-600M"
7
+ print(f"Loading {MODEL_NAME}...")
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model = model.to(device)
12
+ print(f"Model loaded on {device}")
13
+
14
+ # Language codes
15
+ LANGS = {
16
+ "English": "eng_Latn",
17
+ "French": "fra_Latn",
18
+ "Arabic": "arb_Arab",
19
+ "Moroccan Arabic": "ary_Arab",
20
+ "Spanish": "spa_Latn",
21
+ "German": "deu_Latn",
22
+ "Italian": "ita_Latn",
23
+ "Portuguese": "por_Latn",
24
+ "Chinese": "zho_Hans",
25
+ "Japanese": "jpn_Jpan",
26
+ "Korean": "kor_Hang",
27
+ "Russian": "rus_Cyrl",
28
+ "Turkish": "tur_Latn",
29
+ "Dutch": "nld_Latn",
30
+ "Hindi": "hin_Deva",
31
+ }
32
+
33
+ def translate(text, src_lang, tgt_lang):
34
+ if not text.strip():
35
+ return ""
36
+
37
+ src_code = LANGS.get(src_lang, "eng_Latn")
38
+ tgt_code = LANGS.get(tgt_lang, "fra_Latn")
39
+
40
+ tokenizer.src_lang = src_code
41
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
42
+
43
+ forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
44
+
45
+ with torch.no_grad():
46
+ outputs = model.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=512, num_beams=5)
47
+
48
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
49
+
50
+ demo = gr.Interface(
51
+ fn=translate,
52
+ inputs=[
53
+ gr.Textbox(label="Text to translate", lines=3),
54
+ gr.Dropdown(list(LANGS.keys()), label="Source Language", value="English"),
55
+ gr.Dropdown(list(LANGS.keys()), label="Target Language", value="French"),
56
+ ],
57
+ outputs=gr.Textbox(label="Translation", lines=3),
58
+ title="NLLB-200 Translation API",
59
+ description="200 languages including Moroccan Arabic!",
60
+ )
61
 
 
62
  demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers>=4.30.0
2
+ torch>=2.0.0
3
+ sentencepiece
4
+ protobuf
5
+ gradio