jojonocode commited on
Commit
bc1d457
·
verified ·
1 Parent(s): cb7837b

maj distil et pipe

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -1,25 +1,20 @@
1
  import torch
2
- from transformers import pipeline
3
  import gradio as gr
 
4
 
5
  # --------------------------------------------------
6
- # Chargement du pipeline NLLB
7
  # --------------------------------------------------
8
  MODEL_NAME = "facebook/nllb-200-distilled-1.3B"
9
 
10
- device = 0 if torch.cuda.is_available() else -1
11
- print(f"🚀 Chargement du modèle {MODEL_NAME} sur {'GPU' if device == 0 else 'CPU'}...")
12
 
13
- translator = pipeline(
14
- "translation",
15
- model=MODEL_NAME,
16
- device=device,
17
- src_lang="fra_Latn",
18
- tgt_lang="ewe_Latn"
19
- )
20
 
21
  # --------------------------------------------------
22
- # Dictionnaire de langues (tu peux en ajouter)
23
  # --------------------------------------------------
24
  LANGUAGES = {
25
  "Français": "fra_Latn",
@@ -41,13 +36,23 @@ def translate(text, src_lang, tgt_lang):
41
  return "⚠️ Veuillez entrer un texte à traduire."
42
 
43
  try:
44
- result = translator(
45
- text,
46
- src_lang=LANGUAGES[src_lang],
47
- tgt_lang=LANGUAGES[tgt_lang],
 
 
 
 
 
 
 
 
48
  max_length=512
49
  )
50
- return result[0]["translation_text"]
 
 
51
 
52
  except Exception as e:
53
  return f"❌ Erreur : {str(e)}"
 
1
  import torch
 
2
  import gradio as gr
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
 
5
  # --------------------------------------------------
6
+ # Chargement du modèle NLLB
7
  # --------------------------------------------------
8
  MODEL_NAME = "facebook/nllb-200-distilled-1.3B"
9
 
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ print(f"🚀 Chargement du modèle {MODEL_NAME} sur {device}...")
12
 
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
 
 
 
 
 
15
 
16
  # --------------------------------------------------
17
+ # Dictionnaire de langues
18
  # --------------------------------------------------
19
  LANGUAGES = {
20
  "Français": "fra_Latn",
 
36
  return "⚠️ Veuillez entrer un texte à traduire."
37
 
38
  try:
39
+ # Configuration des langues
40
+ src_code = LANGUAGES[src_lang]
41
+ tgt_code = LANGUAGES[tgt_lang]
42
+
43
+ # Préparation de l'entrée
44
+ tokenizer.src_lang = src_code
45
+ inputs = tokenizer(text, return_tensors="pt").to(device)
46
+
47
+ # Génération
48
+ translated_tokens = model.generate(
49
+ **inputs,
50
+ forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
51
  max_length=512
52
  )
53
+
54
+ # Décodage
55
+ return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
56
 
57
  except Exception as e:
58
  return f"❌ Erreur : {str(e)}"