import os

import gradio as gr
import torch
from datasets import Dataset
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

# Make sure the output directory exists (idempotent, no bare except needed).
os.makedirs("models", exist_ok=True)

# ----------- BASE MODELS -----------
# Display name -> Hugging Face model id. Both T5 entries start from the same
# "t5-small" checkpoint; they diverge only after fine-tuning, because each is
# saved under its own display-name-derived path.
BASE_MODELS = {
    "MarianMT ru→en": "Helsinki-NLP/opus-mt-ru-en",
    "MarianMT en→ru": "Helsinki-NLP/opus-mt-en-ru",
    "T5-small ru→en": "t5-small",
    "T5-small en→ru": "t5-small",
}


def load_model(model_id):
    """Load a (model, tokenizer) pair for either a Marian or a T5 checkpoint.

    The architecture is chosen by inspecting the id/path string, matching how
    the rest of this app distinguishes the two families.
    """
    if "Marian" in model_id:
        tokenizer = MarianTokenizer.from_pretrained(model_id)
        model = MarianMTModel.from_pretrained(model_id)
    else:
        tokenizer = T5Tokenizer.from_pretrained(model_id)
        model = T5ForConditionalGeneration.from_pretrained(model_id)
    return model, tokenizer


def _t5_prefix(model_name):
    """Return the T5 task prefix for the chosen translation direction."""
    if "ru→en" in model_name:
        return "translate Russian to English: "
    return "translate English to Russian: "


def _read_training_text(train_file):
    """Return the uploaded dataset as text.

    Older Gradio versions (gr.File with type="binary") hand the callback raw
    bytes; newer ones pass a filepath string or a tempfile wrapper with a
    .name attribute. Support all three so the app works across versions.
    """
    if isinstance(train_file, bytes):
        return train_file.decode("utf-8")
    path = getattr(train_file, "name", train_file)
    with open(path, encoding="utf-8") as fh:
        return fh.read()


def train_model(base_model_name, train_file, num_epochs, batch_size):
    """Fine-tune the selected base model on a tab-separated src/trg dataset.

    Args:
        base_model_name: key into BASE_MODELS (display name from the UI).
        train_file: uploaded file (bytes or path, see _read_training_text);
            one "source<TAB>target" pair per line.
        num_epochs: number of training epochs.
        batch_size: per-device train batch size.

    Returns:
        A status string with the save location.
    """
    # Parse the dataset: keep only lines that actually contain a tab;
    # maxsplit=1 tolerates stray tabs inside the target text.
    lines = _read_training_text(train_file).split("\n")
    pairs = [line.split("\t", 1) for line in lines if "\t" in line]
    ds = Dataset.from_dict({
        "src": [p[0] for p in pairs],
        "trg": [p[1] for p in pairs],
    })

    model_id = BASE_MODELS[base_model_name]
    model, tokenizer = load_model(model_id)

    def preprocess(batch):
        # batched=True => batch["src"] / batch["trg"] are LISTS of strings.
        sources = batch["src"]
        if "Marian" not in base_model_name:
            # T5 needs its task prefix on every source sentence.
            # (The original concatenated the prefix with the whole list,
            # which raises TypeError under batched mapping.)
            prefix = _t5_prefix(base_model_name)
            sources = [prefix + s for s in sources]
        inputs = tokenizer(
            sources, truncation=True, padding="max_length", max_length=128
        )
        # text_target= replaces the deprecated as_target_tokenizer() context.
        labels = tokenizer(
            text_target=batch["trg"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        inputs["labels"] = labels["input_ids"]
        return inputs

    tokenized = ds.map(preprocess, batched=True)

    args = Seq2SeqTrainingArguments(
        output_dir="models",
        save_strategy="no",  # we save explicitly below
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=2e-4,
        logging_steps=5,
        report_to="none",
    )
    collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        data_collator=collator,
    )
    trainer.train()

    # Save under a path derived from the display name so each direction
    # (and each architecture) gets its own folder.
    save_path = f"models/{base_model_name.replace(' ', '_')}"
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    return f"Модель сохранена в {save_path}"


def translate(text, model_name):
    """Translate `text` with a previously fine-tuned model.

    Args:
        text: input sentence(s) to translate.
        model_name: key into BASE_MODELS; must match a trained/saved model.

    Returns:
        The translated string, or a hint to train the model first.
    """
    model_path = f"models/{model_name.replace(' ', '_')}"
    if not os.path.exists(model_path):
        return "Сначала обучите модель."

    if "Marian" in model_name:
        tokenizer = MarianTokenizer.from_pretrained(model_path)
        model = MarianMTModel.from_pretrained(model_path)
        enc = tokenizer([text], return_tensors="pt")
        out = model.generate(**enc)
    else:
        tokenizer = T5Tokenizer.from_pretrained(model_path)
        model = T5ForConditionalGeneration.from_pretrained(model_path)
        input_ids = tokenizer(
            _t5_prefix(model_name) + text, return_tensors="pt"
        ).input_ids
        out = model.generate(input_ids, max_length=200)
    return tokenizer.decode(out[0], skip_special_tokens=True)


# ----------- GRADIO UI -----------
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Обучение переводчика (MarianMT / T5-small)")

    with gr.Tab("Обучение"):
        base_model = gr.Dropdown(list(BASE_MODELS.keys()), label="Выберите модель")
        train_data = gr.File(
            label="Загрузите тренировочный датасет (формат: src<TAB>trg)"
        )
        epochs = gr.Slider(1, 5, value=1, step=1, label="Эпохи")
        batch = gr.Slider(1, 16, value=4, step=1, label="Батч")
        train_button = gr.Button("Начать обучение")
        train_output = gr.Textbox(label="Логи")
        train_button.click(
            train_model,
            inputs=[base_model, train_data, epochs, batch],
            outputs=train_output,
        )

    with gr.Tab("Перевод"):
        model_choice = gr.Dropdown(
            list(BASE_MODELS.keys()), label="Выберите обученную модель"
        )
        text = gr.Textbox(lines=5, label="Введите текст")
        translate_button = gr.Button("Перевести")
        translation_result = gr.Textbox(label="Перевод")
        # BUGFIX: inputs must follow the translate(text, model_name) signature;
        # the original passed [model_choice, text], swapping the arguments.
        translate_button.click(
            translate, [text, model_choice], translation_result
        )

if __name__ == "__main__":
    demo.launch()