import torch import gradio as gr from transformers import AutoModelForSeq2SeqLM, AutoTokenizer # -------------------------------------------------- # Chargement du modèle NLLB # -------------------------------------------------- MODEL_NAME = "facebook/nllb-200-distilled-1.3B" device = "cuda" if torch.cuda.is_available() else "cpu" print(f"🚀 Chargement du modèle {MODEL_NAME} sur {device}...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device) # -------------------------------------------------- # Dictionnaire de langues # -------------------------------------------------- LANGUAGES = { "Français": "fra_Latn", "Ewe": "ewe_Latn", "Fon": "fon_Latn", "Anglais": "eng_Latn", "Espagnol": "spa_Latn", "Allemand": "deu_Latn", "Swahili": "swh_Latn", "Lingala": "lin_Latn", "Portugais": "por_Latn" } # -------------------------------------------------- # Fonction de traduction # -------------------------------------------------- def translate(text, src_lang, tgt_lang="Ewe"): if not text.strip(): return "⚠️ Veuillez entrer un texte à traduire." try: # Configuration des langues (Source dynamique, Cible forcée sur Ewe) src_code = LANGUAGES.get(src_lang, "fra_Latn") tgt_code = LANGUAGES.get("Ewe", "ewe_Latn") # Tokenization avec spécification précise de la langue source # Note: Passer src_lang au tokenizer est crucial pour NLLB-200 inputs = tokenizer(text, return_tensors="pt", padding=True, src_lang=src_code).to(device) # Génération avec paramètres optimisés pour éviter les répétitions with torch.no_grad(): translated_tokens = model.generate( **inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code), max_length=512, num_beams=5, no_repeat_ngram_size=3, repetition_penalty=1.5, # Augmenté pour éviter "etudiant etudiant" early_stopping=True, length_penalty=1.0 ) # Décodage result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] # Sécurité pour les sorties vides ou identiques à l'entrée if not result.strip() or result.strip().lower() == text.strip().lower(): # Si le modèle échoue, on tente une génération plus simple sans pénalités agressives with torch.no_grad(): translated_tokens = model.generate( **inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code), max_length=512, num_beams=2 ) result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] if not result.strip(): return "⚠️ Traduction impossible pour ce texte." return result except Exception as e: return f"❌ Erreur : {str(e)}" # -------------------------------------------------- # Interface Gradio # -------------------------------------------------- with gr.Blocks(title="🌍 Traduction EWE") as demo: gr.Markdown( """
Traduction haute performance vers l'Ewe, le Fon et plus encore.