jojonocode commited on
Commit
c3639fb
·
verified ·
1 Parent(s): 40f9d6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -5,7 +5,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
  # --------------------------------------------------
6
  # Chargement du modèle NLLB
7
  # --------------------------------------------------
8
- MODEL_NAME = "facebook/nllb-200-3.3B"
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  print(f"🚀 Chargement du modèle {MODEL_NAME} sur {device}...")
@@ -37,23 +37,31 @@ def translate(text, src_lang, tgt_lang):
37
 
38
  try:
39
  # Configuration des langues
40
- src_code = LANGUAGES[src_lang]
41
- tgt_code = LANGUAGES[tgt_lang]
42
 
43
- # Préparation de l'entrée
44
  tokenizer.src_lang = src_code
45
- inputs = tokenizer(text, return_tensors="pt").to(device)
 
 
46
 
47
  # Génération
48
- translated_tokens = model.generate(
49
- **inputs,
50
- forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
51
- max_length=512
52
- )
 
53
 
54
  # Décodage
55
- return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
56
-
 
 
 
 
 
57
  except Exception as e:
58
  return f"❌ Erreur : {str(e)}"
59
 
 
5
  # --------------------------------------------------
6
  # Chargement du modèle NLLB
7
  # --------------------------------------------------
8
+ MODEL_NAME = "facebook/nllb-200-distilled-1.3B"
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  print(f"🚀 Chargement du modèle {MODEL_NAME} sur {device}...")
 
37
 
38
  try:
39
  # Configuration des langues
40
+ src_code = LANGUAGES.get(src_lang, "fra_Latn")
41
+ tgt_code = LANGUAGES.get(tgt_lang, "ewe_Latn")
42
 
43
+ # Indispensable pour NLLB : définir la langue source dans le tokenizer
44
  tokenizer.src_lang = src_code
45
+
46
+ # Tokenization
47
+ inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
48
 
49
  # Génération
50
+ with torch.no_grad():
51
+ translated_tokens = model.generate(
52
+ **inputs,
53
+ forced_bos_token_id=tokenizer.lang_code_to_id[tgt_code],
54
+ max_length=512
55
+ )
56
 
57
  # Décodage
58
+ result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
59
+
60
+ # Nettoyage si le modèle renvoie du texte vide ou des espaces
61
+ if not result.strip():
62
+ return "⚠️ Le modèle n'a pas pu générer de traduction. Essayez une phrase plus simple."
63
+
64
+ return result
65
  except Exception as e:
66
  return f"❌ Erreur : {str(e)}"
67