Spaces:
Sleeping
Sleeping
| from fairseq.models.transformer import TransformerModel | |
| import torch | |
| import re | |
| import string | |
| class Translator: | |
| def __init__(self, isFon:bool, device='cuda' if torch.cuda.is_available() else 'cpu'): | |
| # Charger le modèle pré-entraîné avec Fairseq | |
| inner = "fon_fr" if isFon else "fr_fon" | |
| self.model = TransformerModel.from_pretrained( | |
| f'./utils/checkpoints/{inner}', | |
| checkpoint_file = 'checkpoint_best.pt', | |
| data_name_or_path = f'utils/datas/data_prepared_{inner}/', | |
| source_lang='fon' if isFon else 'fr', | |
| target_lang='fr' if isFon else 'fon' | |
| ) | |
| # Définir le périphérique sur lequel exécuter le modèle (par défaut sur 'cuda' si disponible) | |
| self.model.to(device) | |
| # Mettre le modèle en mode évaluation (pas de mise à jour des poids) | |
| self.model.eval() | |
| def preprocess(self, data): | |
| print('Preprocessing...') | |
| # Convertir chaque lettre en minuscule | |
| text = data.lower().strip() | |
| # Supprimer les apostrophes des phrases | |
| text = re.sub("'", "", text) | |
| # Supprimer toute ponctuation | |
| exclude = set(string.punctuation) | |
| text = ''.join(ch for ch in text if ch not in exclude) | |
| # Supprimer les chiffres | |
| digit = str.maketrans('', '', string.digits) | |
| text = text.translate(digit) | |
| return text | |
| def translate(self, text): | |
| print(text) | |
| pre_traited = self.preprocess(text) | |
| print(pre_traited) | |
| # Encodage du texte en tokens | |
| tokens = self.model.encode(pre_traited) | |
| # Utilisation de la méthode generate avec le paramètre beam | |
| translations = self.model.generate(tokens, beam=5) | |
| print(type(translations)) | |
| print(translations[0]) | |
| best_translation_tokens = [translations[i]['tokens'].tolist() for i in range(5)] | |
| # Décodage des tokens en traduction | |
| translations = [self.model.decode(best_translation_tokens[i]) for i in range(5)] | |
| return "\n".join(translations) |