jojonocode committed on
Commit
286dcb9
·
verified ·
1 Parent(s): 5c716ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -21
app.py CHANGED
@@ -4,49 +4,77 @@ import gradio as gr
4
 
5
  MODEL_NAME = "facebook/nllb-200-3.3B"
6
 
7
- # Pick device
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- # Load tokenizer + model (consider float16 for GPU to save memory)
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, src_lang="fra_Latn")
12
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
13
  model.to(device)
14
 
15
- # Language codes used by NLLB
16
  LANG_CODES = {
17
  "fr->ee": ("fra_Latn", "ewe_Latn"),
18
  "ee->fr": ("ewe_Latn", "fra_Latn"),
19
  }
20
 
21
  def translate(text: str, direction: str, max_length: int = 256) -> str:
22
- if not text:
23
  return ""
24
- src, tgt = LANG_CODES[direction]
25
 
26
- # Tokenize and move to device
27
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
28
 
29
- # Force target language token id as BOS so model generates target language
30
- forced_bos_token_id = tokenizer.lang_code_to_id[tgt]
 
 
 
 
 
 
31
 
32
- generated = model.generate(
 
 
 
 
33
  **inputs,
34
  forced_bos_token_id=forced_bos_token_id,
35
  max_length=max_length,
36
- num_beams=4,
37
  )
38
- return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
39
 
40
- # Gradio UI
 
 
 
 
 
41
  with gr.Blocks() as demo:
42
- gr.Markdown("## French ↔ Ewe translator (facebook/nllb-200-3.3B)")
43
  with gr.Row():
44
- inp = gr.Textbox(lines=6, placeholder="Enter text to translate...")
45
- out = gr.Textbox(lines=6, interactive=False)
46
- direction = gr.Radio(choices=["fr->ee", "ee->fr"], value="fr->ee", label="Direction")
47
- max_len = gr.Slider(minimum=32, maximum=1024, value=256, step=32, label="Max output tokens")
48
- translate_btn = gr.Button("Translate")
 
 
 
 
 
 
 
 
 
 
49
 
50
- translate_btn.click(lambda t, d, m: translate(t, d, m), inputs=[inp, direction, max_len], outputs=[out])
 
 
 
 
 
 
51
 
52
  demo.launch()
 
4
 
5
  MODEL_NAME = "facebook/nllb-200-3.3B"
6
 
7
+ # Sélection du device
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
+ # Chargement du modèle et du tokenizer
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
13
  model.to(device)
14
 
15
+ # Dictionnaire des langues supportées
16
  LANG_CODES = {
17
  "fr->ee": ("fra_Latn", "ewe_Latn"),
18
  "ee->fr": ("ewe_Latn", "fra_Latn"),
19
  }
20
 
21
  def translate(text: str, direction: str, max_length: int = 256) -> str:
22
+ if not text.strip():
23
  return ""
 
24
 
25
+ src_lang, tgt_lang = LANG_CODES[direction]
 
26
 
27
+ # Tokenization avec la langue source explicitement définie
28
+ inputs = tokenizer(
29
+ text,
30
+ return_tensors="pt",
31
+ padding=True,
32
+ truncation=True,
33
+ src_lang=src_lang
34
+ ).to(device)
35
 
36
+ # On force la génération dans la langue cible
37
+ forced_bos_token_id = tokenizer.lang_code_to_id[tgt_lang]
38
+
39
+ # Génération
40
+ generated_tokens = model.generate(
41
  **inputs,
42
  forced_bos_token_id=forced_bos_token_id,
43
  max_length=max_length,
44
+ num_beams=4
45
  )
 
46
 
47
+ # Décodage
48
+ translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
49
+ return translation.strip()
50
+
51
+
52
+ # === Interface Gradio ===
53
  with gr.Blocks() as demo:
54
+ gr.Markdown("## 🌍 French ↔ Ewe Translator (facebook/nllb-200-3.3B)")
55
  with gr.Row():
56
+ inp = gr.Textbox(lines=6, label="Texte à traduire", placeholder="Entrez le texte ici...")
57
+ out = gr.Textbox(lines=6, label="Traduction", interactive=False)
58
+
59
+ direction = gr.Radio(
60
+ choices=["fr->ee", "ee->fr"],
61
+ value="fr->ee",
62
+ label="Direction de traduction"
63
+ )
64
+ max_len = gr.Slider(
65
+ minimum=32,
66
+ maximum=1024,
67
+ value=256,
68
+ step=32,
69
+ label="Longueur maximale de sortie"
70
+ )
71
 
72
+ translate_btn = gr.Button("🔁 Traduire")
73
+
74
+ translate_btn.click(
75
+ fn=translate,
76
+ inputs=[inp, direction, max_len],
77
+ outputs=[out]
78
+ )
79
 
80
  demo.launch()