Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import json | |
| from datasets import Dataset | |
| from transformers import ( | |
| MarianMTModel, MarianTokenizer, | |
| T5ForConditionalGeneration, T5Tokenizer, | |
| DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer | |
| ) | |
| import torch | |
# Create the output folder up front. This is best effort: if it cannot be
# created (e.g. a read-only filesystem) the app still starts, and the error
# surfaces later when a trained model is saved.
try:
    os.makedirs("models", exist_ok=True)
except OSError as err:
    print(f"Could not create 'models' directory: {err}")
# ----------- LOAD MODELS -----------
# UI display name -> Hugging Face Hub checkpoint id.
# Dict order is the order the choices appear in the dropdowns.
# Both T5 entries share one checkpoint; the translation direction is selected
# by the task prefix that train_model/translate prepend to the source text.
# NOTE(review): t5-small's tokenizer was presumably not trained on Cyrillic
# text, so the T5 variants may tokenize Russian poorly. Confirm before relying on them.
BASE_MODELS = {
    "MarianMT ru→en": "Helsinki-NLP/opus-mt-ru-en",
    "MarianMT en→ru": "Helsinki-NLP/opus-mt-en-ru",
    "T5-small ru→en": "t5-small",
    "T5-small en→ru": "t5-small",
}
def load_model(model_id):
    """Load a seq2seq translation model together with its tokenizer.

    Args:
        model_id: Hugging Face Hub id (e.g. "Helsinki-NLP/opus-mt-ru-en",
            "t5-small") or a local directory produced by ``save_pretrained``
            (e.g. "models/MarianMT_ru→en").

    Returns:
        A ``(model, tokenizer)`` tuple. Marian classes are used for Marian
        checkpoints, T5 classes for everything else.
    """
    # Hub ids of Marian checkpoints follow the "Helsinki-NLP/opus-mt-*" scheme
    # and never contain the word "Marian", so match "opus-mt" as well.
    # Local save paths are named after the UI label ("MarianMT ...").
    lowered = model_id.lower()
    is_marian = "marian" in lowered or "opus-mt" in lowered

    if is_marian:
        tokenizer = MarianTokenizer.from_pretrained(model_id)
        model = MarianMTModel.from_pretrained(model_id)
    else:
        tokenizer = T5Tokenizer.from_pretrained(model_id)
        model = T5ForConditionalGeneration.from_pretrained(model_id)
    return model, tokenizer
# ----------- TRAINING FUNCTION -----------
def _read_training_pairs(train_file):
    """Parse an uploaded tab-separated file into (source, target) string pairs.

    ``train_file`` may be raw bytes, a path (str / os.PathLike), or a
    tempfile-like object exposing ``.name``. Which form arrives depends on
    the Gradio version and ``gr.File(type=...)``.
    Only the first two tab-separated columns are used. Any further columns
    (e.g. attribution) are ignored. Lines without a tab, or with an empty
    source/target, are skipped.
    """
    if isinstance(train_file, (bytes, bytearray)):
        raw = bytes(train_file)
    else:
        path = train_file if isinstance(train_file, (str, os.PathLike)) else train_file.name
        with open(path, "rb") as fh:
            raw = fh.read()

    # utf-8-sig drops a leading BOM; splitlines() handles both \n and \r\n.
    pairs = []
    for line in raw.decode("utf-8-sig").splitlines():
        if "\t" not in line:
            continue
        parts = line.split("\t")
        src, trg = parts[0].strip(), parts[1].strip()
        if src and trg:
            pairs.append((src, trg))
    return pairs


def _make_preprocess(tokenizer, base_model_name, max_length=128):
    """Build the batched ``Dataset.map`` function that tokenizes sources and targets."""
    if "Marian" in base_model_name:
        prefix = ""
    elif "ru→en" in base_model_name:
        prefix = "translate Russian to English: "
    else:
        prefix = "translate English to Russian: "

    def preprocess(batch):
        # With batched=True every column is a list of strings, so the task
        # prefix must be prepended element-wise.
        sources = [prefix + s for s in batch["src"]]
        # text_target tokenizes the labels with the target-side tokenizer
        # (this matters for Marian) and stores them under "labels". No padding
        # is done here: DataCollatorForSeq2Seq pads per batch and pads labels
        # with -100, so padding is excluded from the loss.
        return tokenizer(
            sources,
            text_target=batch["trg"],
            truncation=True,
            max_length=max_length,
        )

    return preprocess


def train_model(base_model_name, train_file, num_epochs, batch_size):
    """Fine-tune one of BASE_MODELS on an uploaded src<TAB>tgt file and save it.

    Args:
        base_model_name: A key of BASE_MODELS, e.g. "MarianMT ru→en".
        train_file: The value of the ``gr.File`` upload (bytes, a path, or a
            file-like object with ``.name``), or None if nothing was uploaded.
        num_epochs: Number of training epochs.
        batch_size: Per-device batch size. Cast to int because Gradio sliders
            may deliver floats.

    Returns:
        A status message (Russian, for the UI): the save location on success,
        or a hint describing what is missing.
    """
    if train_file is None:
        return "Загрузите файл с данными."
    pairs = _read_training_pairs(train_file)
    if not pairs:
        return "В файле нет строк формата src<TAB>tgt."
    if base_model_name not in BASE_MODELS:
        return "Выберите модель."

    ds = Dataset.from_dict({
        "src": [src for src, _ in pairs],
        "trg": [trg for _, trg in pairs],
    })

    model, tokenizer = load_model(BASE_MODELS[base_model_name])
    tokenized = ds.map(
        _make_preprocess(tokenizer, base_model_name),
        batched=True,
        remove_columns=["src", "trg"],  # raw strings are not model inputs
    )

    args = Seq2SeqTrainingArguments(
        output_dir="models",
        save_strategy="no",  # only the final model is saved, below
        num_train_epochs=num_epochs,
        per_device_train_batch_size=int(batch_size),
        learning_rate=2e-4,
        logging_steps=5,
        report_to="none",
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    )
    trainer.train()

    # translate() derives the same path from the UI model name.
    save_path = f"models/{base_model_name.replace(' ', '_')}"
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    return f"Модель сохранена в {save_path}"
# ----------- TRANSLATION -----------
# Fine-tuned models already loaded by translate(), keyed by save directory.
# Each value is (config.json mtime, model, tokenizer). The mtime lets a model
# re-trained into the same directory be reloaded instead of served stale.
_LOADED_MODELS = {}


def _get_trained_model(model_path):
    """Return ``(model, tokenizer)`` saved in *model_path*, reloading only when it changed."""
    stamp = os.path.getmtime(os.path.join(model_path, "config.json"))
    cached = _LOADED_MODELS.get(model_path)
    if cached is None or cached[0] != stamp:
        model, tokenizer = load_model(model_path)
        cached = (stamp, model, tokenizer)
        _LOADED_MODELS[model_path] = cached
    return cached[1], cached[2]


def translate(text, model_name):
    """Translate *text* with a model previously fine-tuned by train_model.

    Args:
        text: Source text to translate.
        model_name: A BASE_MODELS display name, e.g. "MarianMT ru→en".

    Returns:
        The translation, "" for blank input, or a hint for the user (Russian)
        when no model is selected or the selected one has not been trained.
    """
    if not model_name:
        return "Выберите модель."
    model_path = f"models/{model_name.replace(' ', '_')}"
    # save_pretrained writes config.json; without it there is nothing to load.
    if not os.path.isfile(os.path.join(model_path, "config.json")):
        return "Сначала обучите модель."
    if not text or not text.strip():
        return ""

    model, tokenizer = _get_trained_model(model_path)

    if "T5-small" in model_name:
        prefix = "translate Russian to English: " if "ru→en" in model_name else "translate English to Russian: "
        gen_kwargs = {"max_length": 200}
    else:  # Marian: keep the checkpoint's own generation length limit
        prefix = ""
        gen_kwargs = {}

    # truncation=True keeps over-long input within the model's position limit.
    enc = tokenizer([prefix + text], return_tensors="pt", truncation=True)
    out = model.generate(**enc, **gen_kwargs)
    return tokenizer.decode(out[0], skip_special_tokens=True)
# ----------- GRADIO UI -----------
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Обучение переводчика (MarianMT / T5-small)")

    with gr.Tab("Обучение"):
        # Preselect the first model so train_model never receives None.
        base_model = gr.Dropdown(
            list(BASE_MODELS.keys()),
            value=next(iter(BASE_MODELS)),
            label="Выберите модель",
        )
        train_data = gr.File(label="Загрузите тренировочный датасет (формат: src<TAB>tgt)")
        epochs = gr.Slider(1, 5, value=1, step=1, label="Эпохи")
        batch = gr.Slider(1, 16, value=4, step=1, label="Батч")
        train_button = gr.Button("Начать обучение")
        train_output = gr.Textbox(label="Логи")
        # Input order must match train_model(base_model_name, train_file, num_epochs, batch_size).
        train_button.click(
            train_model,
            inputs=[base_model, train_data, epochs, batch],
            outputs=train_output,
        )

    with gr.Tab("Перевод"):
        model_choice = gr.Dropdown(
            list(BASE_MODELS.keys()),
            value=next(iter(BASE_MODELS)),
            label="Выберите обученную модель",
        )
        text = gr.Textbox(lines=5, label="Введите текст")
        translate_button = gr.Button("Перевести")
        translation_result = gr.Textbox(label="Перевод")
        # translate(text, model_name): the text box comes first. The previous
        # wiring passed the dropdown value as the text and vice versa.
        translate_button.click(
            translate,
            inputs=[text, model_choice],
            outputs=translation_result,
        )

if __name__ == "__main__":
    demo.launch()