| import typing as tp | |
| from collections import namedtuple | |
| from functools import partial | |
| import torch | |
| from transformers import pipeline | |
def get_translator():
    """Build the Hugging Face translation pipeline used by ``TranslationModel``.

    The checkpoint ``Helsinki-NLP/opus-mt-ru-en`` translates Russian into
    English. The task id is therefore declared as ``translation_ru_to_en`` so
    it agrees with the model's actual direction.

    Returns:
        A ``transformers`` translation pipeline. It is placed on ``"cuda"``
        when ``torch.cuda.is_available()`` and on ``"cpu"`` otherwise, and is
        built with ``torch_dtype="auto"``.
    """
    return pipeline(
        "translation_ru_to_en",
        model="Helsinki-NLP/opus-mt-ru-en",
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype="auto",
    )
class TranslationModel:
    """Run a wrapped model on machine-translated versions of its inputs.

    Inputs are first passed through the pipeline returned by
    ``get_translator()``. The resulting ``"translation_text"`` strings are
    then handed, as one list, to the model built by ``get_model()``.
    """

    def __init__(self, get_model):
        """Load the translator and build the wrapped model.

        Args:
            get_model: zero-argument factory returning the model to wrap. The
                model is called with a list of strings and must return a
                sequence supporting ``len`` and indexing (see ``__call__``).
        """
        self.translator = get_translator()
        self.model = get_model()

    @staticmethod
    def _to_text(item):
        """Join a record dict's authors/abstract/title with spaces; pass other items through."""
        if isinstance(item, dict):
            return item["authors"] + " " + item["abstract"] + " " + item["title"]
        return item

    def __call__(self, input, **kwargs):
        """Translate ``input`` and run the wrapped model on the translations.

        Args:
            input: a single item or an iterable of items. An item is a string,
                or a dict with ``"authors"``, ``"abstract"`` and ``"title"``
                string values (joined in that order, space-separated).
            **kwargs: accepted for call-site compatibility; currently ignored.

        Returns:
            The wrapped model's output. If that output has exactly one
            element, the element itself is returned. This also happens for a
            one-element batch.

        Raises:
            KeyError: if a dict item lacks one of the required keys.
        """
        # Strings and dicts are iterable but each represents ONE item: wrap
        # them rather than iterating over their characters / keys.
        if isinstance(input, (str, dict)) or not isinstance(input, tp.Iterable):
            input = [input]
        texts = [self._to_text(item) for item in input]
        translations = self.translator(texts)
        translated = [t["translation_text"] for t in translations]
        out = self.model(translated)
        if len(out) == 1:
            return out[0]
        return out
def create_translation_models(models):
    """Create translation-backed variants of the given model factories.

    Args:
        models: mapping of model name to a zero-argument model factory.

    Returns:
        A new dict, in the same order as ``models``. Each key is the original
        name plus the suffix " (С помощью перевода)" ("using translation").
        Each value is a ``TranslationModel`` constructor with ``get_model``
        pre-bound. Nothing is loaded until that constructor is called.
    """
    wrapped = {}
    for model_name, factory in models.items():
        label = f"{model_name} (С помощью перевода)"
        # partial binds ``factory`` now, so each entry keeps its own factory.
        wrapped[label] = partial(TranslationModel, get_model=factory)
    return wrapped