Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,12 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from transformers import MarianMTModel, MarianTokenizer
|
| 3 |
-
|
| 4 |
|
| 5 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
LANG_NAMES = {
|
| 7 |
"fr": "Français",
|
| 8 |
"en": "Anglais",
|
|
@@ -16,19 +20,27 @@ LANG_NAMES = {
|
|
| 16 |
"zh": "Chinois"
|
| 17 |
}
|
| 18 |
|
| 19 |
-
# Liste des modèles MarianMT
|
| 20 |
LANG_MODELS = {}
|
| 21 |
for src in LANG_NAMES.keys():
|
| 22 |
for tgt in LANG_NAMES.keys():
|
| 23 |
if src != tgt:
|
| 24 |
-
|
| 25 |
-
LANG_MODELS[(src, tgt)] = model_name
|
| 26 |
|
| 27 |
-
# Cache
|
| 28 |
model_cache = {}
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def get_model(src, tgt):
|
| 31 |
-
"""Charge le modèle
|
| 32 |
if (src, tgt) not in LANG_MODELS:
|
| 33 |
return None, None
|
| 34 |
model_name = LANG_MODELS[(src, tgt)]
|
|
@@ -42,15 +54,11 @@ def get_model(src, tgt):
|
|
| 42 |
return model_cache.get(model_name, (None, None))
|
| 43 |
|
| 44 |
def translate(text, target_lang_name):
|
| 45 |
-
"""Traduit le texte vers la langue cible"""
|
| 46 |
# Trouver code ISO de la langue cible
|
| 47 |
target_lang = [code for code, name in LANG_NAMES.items() if name == target_lang_name][0]
|
| 48 |
|
| 49 |
-
# Détecter langue source
|
| 50 |
-
|
| 51 |
-
source_lang = detect(text)
|
| 52 |
-
except:
|
| 53 |
-
return "Impossible de détecter la langue."
|
| 54 |
|
| 55 |
if source_lang not in LANG_NAMES:
|
| 56 |
return f"Langue source '{source_lang}' non supportée."
|
|
@@ -58,12 +66,12 @@ def translate(text, target_lang_name):
|
|
| 58 |
if source_lang == target_lang:
|
| 59 |
return "La langue source et cible sont identiques."
|
| 60 |
|
| 61 |
-
# Charger le
|
| 62 |
tokenizer, model = get_model(source_lang, target_lang)
|
| 63 |
if tokenizer is None or model is None:
|
| 64 |
return f"Traduction {LANG_NAMES[source_lang]} → {LANG_NAMES[target_lang]} non supportée."
|
| 65 |
|
| 66 |
-
#
|
| 67 |
batch = tokenizer([text], return_tensors="pt", padding=True)
|
| 68 |
gen = model.generate(**batch)
|
| 69 |
translated = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
|
|
@@ -78,7 +86,7 @@ iface = gr.Interface(
|
|
| 78 |
],
|
| 79 |
outputs="text",
|
| 80 |
title="MyTranslator 🌍",
|
| 81 |
-
description="Traducteur multi-langues avec détection automatique
|
| 82 |
)
|
| 83 |
|
| 84 |
iface.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from transformers import MarianMTModel, MarianTokenizer, AutoTokenizer, AutoModelForSequenceClassification
|
| 3 |
+
import torch
|
| 4 |
|
| 5 |
+
# Load the AI language-detection model: XLM-RoBERTa fine-tuned for language
# identification. Both tokenizer and classifier are fetched from the Hugging
# Face Hub (cached locally after the first download at startup).
lang_detect_tokenizer = AutoTokenizer.from_pretrained("papluca/xlm-roberta-base-language-detection")
lang_detect_model = AutoModelForSequenceClassification.from_pretrained("papluca/xlm-roberta-base-language-detection")
|
| 8 |
+
|
| 9 |
+
# Mapping code ISO → Nom complet
|
| 10 |
LANG_NAMES = {
|
| 11 |
"fr": "Français",
|
| 12 |
"en": "Anglais",
|
|
|
|
| 20 |
"zh": "Chinois"
|
| 21 |
}
|
| 22 |
|
| 23 |
+
# MarianMT checkpoints available for every ordered language pair (both
# directions): (src, tgt) -> Hugging Face model name.
LANG_MODELS = {
    (src, tgt): f"Helsinki-NLP/opus-mt-{src}-{tgt}"
    for src in LANG_NAMES
    for tgt in LANG_NAMES
    if src != tgt
}
|
|
|
|
| 29 |
|
| 30 |
+
# Cache of already-loaded translation models, keyed by checkpoint name,
# so each MarianMT model is instantiated at most once per process.
model_cache = {}
|
| 32 |
|
| 33 |
+
def detect_language_ai(text):
    """Return the ISO language code predicted for *text*.

    Runs the XLM-R language-identification classifier and maps the
    highest-scoring class index back to its label via the model config.
    """
    encoded = lang_detect_tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        scores = lang_detect_model(**encoded).logits
    best_id = int(torch.argmax(scores, dim=1))
    return lang_detect_model.config.id2label[best_id]
|
| 41 |
+
|
| 42 |
def get_model(src, tgt):
|
| 43 |
+
"""Charge ou récupère le modèle MarianMT"""
|
| 44 |
if (src, tgt) not in LANG_MODELS:
|
| 45 |
return None, None
|
| 46 |
model_name = LANG_MODELS[(src, tgt)]
|
|
|
|
| 54 |
return model_cache.get(model_name, (None, None))
|
| 55 |
|
| 56 |
def translate(text, target_lang_name):
|
|
|
|
| 57 |
# Trouver code ISO de la langue cible
|
| 58 |
target_lang = [code for code, name in LANG_NAMES.items() if name == target_lang_name][0]
|
| 59 |
|
| 60 |
+
# Détecter langue source avec IA
|
| 61 |
+
source_lang = detect_language_ai(text)
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
if source_lang not in LANG_NAMES:
|
| 64 |
return f"Langue source '{source_lang}' non supportée."
|
|
|
|
| 66 |
if source_lang == target_lang:
|
| 67 |
return "La langue source et cible sont identiques."
|
| 68 |
|
| 69 |
+
# Charger le modèle de traduction
|
| 70 |
tokenizer, model = get_model(source_lang, target_lang)
|
| 71 |
if tokenizer is None or model is None:
|
| 72 |
return f"Traduction {LANG_NAMES[source_lang]} → {LANG_NAMES[target_lang]} non supportée."
|
| 73 |
|
| 74 |
+
# Traduire
|
| 75 |
batch = tokenizer([text], return_tensors="pt", padding=True)
|
| 76 |
gen = model.generate(**batch)
|
| 77 |
translated = tokenizer.batch_decode(gen, skip_special_tokens=True)[0]
|
|
|
|
| 86 |
],
|
| 87 |
outputs="text",
|
| 88 |
title="MyTranslator 🌍",
|
| 89 |
+
description="Traducteur multi-langues avec détection automatique IA."
|
| 90 |
)
|
| 91 |
|
| 92 |
iface.launch()
|