from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Language detector
lang_detector = pipeline(
    "text-classification",
    model="papluca/xlm-roberta-base-language-detection",
    device=-1
)
# Map detected ISO codes to NLLB language codes
NLLB_LANGS = {
    "fr": "fra_Latn",
    "en": "eng_Latn",
    "es": "spa_Latn",
    "de": "deu_Latn",
    "ar": "arb_Arab"
}
def detect_language(text):
    # Classify only the first 512 characters to keep inference fast
    result = lang_detector(text[:512])[0]
    return result["label"], round(result["score"], 4)
# NLLB model
model_nllb_name = "facebook/nllb-200-distilled-600M"
tokenizer_nllb = AutoTokenizer.from_pretrained(model_nllb_name)
model_nllb = AutoModelForSeq2SeqLM.from_pretrained(model_nllb_name)
# Translation
def translate_text(text, src_lang, tgt_lang, max_length=512):
    tokenizer_nllb.src_lang = src_lang
    inputs = tokenizer_nllb(text, return_tensors="pt", max_length=max_length, truncation=True)
    # Some transformers versions no longer expose lang_code_to_id;
    # convert_tokens_to_ids works across versions because the NLLB
    # language codes are registered as regular tokens.
    forced_bos_token_id = tokenizer_nllb.convert_tokens_to_ids(tgt_lang)
    generated_tokens = model_nllb.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=max_length
    )
    return tokenizer_nllb.batch_decode(generated_tokens, skip_special_tokens=True)[0]
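
# Minimal usage sketch: detect the language of an input string, map the
# detected ISO code to an NLLB code via NLLB_LANGS, then translate into
# French. The sample sentence and the "fr" target are illustrative
# assumptions, not part of the pipeline above.
if __name__ == "__main__":
    sample = "The weather is beautiful today."
    iso_code, confidence = detect_language(sample)      # e.g. ("en", 0.99)
    src_code = NLLB_LANGS.get(iso_code, "eng_Latn")     # fall back to English if unmapped
    print(iso_code, confidence)
    print(translate_text(sample, src_code, NLLB_LANGS["fr"]))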