# project3/app.py — Gradio app for fine-tuning MarianMT / T5-small
# translation models (ru↔en) and translating text with the saved checkpoints.
import gradio as gr
import os
import json
from datasets import Dataset
from transformers import (
MarianMTModel, MarianTokenizer,
T5ForConditionalGeneration, T5Tokenizer,
DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
)
import torch
# Make sure the checkpoint directory exists. makedirs(exist_ok=True) is
# atomic with respect to "already exists" and avoids the original's bare
# `except: pass`, which silently swallowed every error (permissions, etc.).
os.makedirs("models", exist_ok=True)
# ----------- LOAD MODELS -----------
# Maps the UI display name to the Hugging Face Hub model id used as the
# fine-tuning starting point. NOTE: both T5 entries share the same base
# checkpoint ("t5-small"); the two directions are distinguished later by
# the task prefix added during preprocessing and by the save path derived
# from the display name.
BASE_MODELS = {
    "MarianMT ru→en": "Helsinki-NLP/opus-mt-ru-en",
    "MarianMT en→ru": "Helsinki-NLP/opus-mt-en-ru",
    "T5-small ru→en": "t5-small",
    "T5-small en→ru": "t5-small"
}
def load_model(model_id):
    """Load a pretrained seq2seq model and its tokenizer from the Hub.

    Args:
        model_id: Hugging Face Hub id, e.g. "Helsinki-NLP/opus-mt-ru-en"
            or "t5-small".

    Returns:
        Tuple of (model, tokenizer).
    """
    # BUG FIX: the original tested `"Marian" in model_id`, but Marian hub ids
    # look like "Helsinki-NLP/opus-mt-ru-en" and never contain "Marian", so
    # the Marian branch was unreachable and the T5 classes were used for
    # every model. Detect Marian checkpoints by their actual id pattern.
    if "opus-mt" in model_id or "Helsinki-NLP" in model_id:
        tokenizer = MarianTokenizer.from_pretrained(model_id)
        model = MarianMTModel.from_pretrained(model_id)
    else:
        tokenizer = T5Tokenizer.from_pretrained(model_id)
        model = T5ForConditionalGeneration.from_pretrained(model_id)
    return model, tokenizer
# ----------- TRAINING FUNCTION -----------
def train_model(base_model_name, train_file, num_epochs, batch_size):
    """Fine-tune a base translation model on an uploaded TSV dataset.

    Args:
        base_model_name: display key from BASE_MODELS (e.g. "MarianMT ru→en").
        train_file: uploaded training data — raw bytes, a file path, or a
            Gradio file object with a `.name` attribute; each line is
            "src<TAB>tgt".
        num_epochs: number of training epochs.
        batch_size: per-device train batch size.

    Returns:
        Status message (Russian, shown in the UI textbox) with the save path.
    """
    # --- read the uploaded dataset -------------------------------------
    # Depending on the Gradio version / component config, `train_file` may
    # arrive as raw bytes, a temp-file path, or a file wrapper. The original
    # only handled bytes (`.decode`) and crashed otherwise.
    if isinstance(train_file, bytes):
        raw = train_file.decode("utf-8")
    elif isinstance(train_file, str):
        with open(train_file, encoding="utf-8") as f:
            raw = f.read()
    else:
        # Gradio file object exposing the temp path via `.name`.
        with open(train_file.name, encoding="utf-8") as f:
            raw = f.read()

    # maxsplit=1 keeps any extra tabs inside the target text instead of
    # silently dropping everything after the second tab.
    pairs = [line.split("\t", 1) for line in raw.split("\n") if "\t" in line]
    if not pairs:
        return "Датасет пуст или имеет неверный формат (ожидается src<TAB>tgt)."

    ds = Dataset.from_dict({
        "src": [p[0] for p in pairs],
        "trg": [p[1] for p in pairs],
    })

    # --- load pretrained base ------------------------------------------
    model_id = BASE_MODELS[base_model_name]
    model, tokenizer = load_model(model_id)

    # T5 needs an explicit task prefix; Marian does not.
    if "T5" in base_model_name:
        prefix = ("translate Russian to English: "
                  if "ru→en" in base_model_name
                  else "translate English to Russian: ")
    else:
        prefix = ""

    def preprocess(batch):
        # With batched=True, batch["src"] is a LIST of strings. The original
        # T5 path did `prefix + batch["src"]` (str + list) -> TypeError.
        sources = [prefix + s for s in batch["src"]]
        inputs = tokenizer(sources, truncation=True,
                           padding="max_length", max_length=128)
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(batch["trg"], truncation=True,
                               padding="max_length", max_length=128)
        inputs["labels"] = labels["input_ids"]
        return inputs

    tokenized = ds.map(preprocess, batched=True)

    # --- training -------------------------------------------------------
    # metric_for_best_model was removed: it only matters with evaluation and
    # load_best_model_at_end, neither of which is configured here
    # (save_strategy="no", no eval_dataset).
    args = Seq2SeqTrainingArguments(
        output_dir="models",
        save_strategy="no",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=2e-4,
        logging_steps=5,
        report_to="none",
    )
    collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        data_collator=collator,
    )
    trainer.train()

    # --- save -----------------------------------------------------------
    save_path = f"models/{base_model_name.replace(' ', '_')}"
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    return f"Модель сохранена в {save_path}"
# ----------- TRANSLATION -----------
def translate(text, model_name):
    """Translate *text* with a previously fine-tuned model.

    Args:
        text: input text to translate.
        model_name: display key from BASE_MODELS identifying the saved model.

    Returns:
        The translation, or a Russian prompt to train the model first when
        no saved checkpoint exists under models/.
    """
    checkpoint_dir = f"models/{model_name.replace(' ', '_')}"
    if not os.path.exists(checkpoint_dir):
        return "Сначала обучите модель."

    # Pick the matching model/tokenizer classes for the saved checkpoint.
    if "Marian" in model_name:
        tokenizer = MarianTokenizer.from_pretrained(checkpoint_dir)
        model = MarianMTModel.from_pretrained(checkpoint_dir)
    else:
        tokenizer = T5Tokenizer.from_pretrained(checkpoint_dir)
        model = T5ForConditionalGeneration.from_pretrained(checkpoint_dir)

    if "T5-small" in model_name:
        # T5 requires the same task prefix that was used during training.
        task = ("translate Russian to English: "
                if "ru→en" in model_name
                else "translate English to Russian: ")
        token_ids = tokenizer(task + text, return_tensors="pt").input_ids
        generated = model.generate(token_ids, max_length=200)
    else:
        # Marian: plain encoding, default generation settings.
        encoded = tokenizer([text], return_tensors="pt")
        generated = model.generate(**encoded)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# ----------- GRADIO UI -----------
# Two tabs: one for fine-tuning a selected base model on an uploaded TSV
# dataset, one for translating text with a previously saved checkpoint.
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Обучение переводчика (MarianMT / T5-small)")
    with gr.Tab("Обучение"):
        base_model = gr.Dropdown(list(BASE_MODELS.keys()), label="Выберите модель")
        train_data = gr.File(label="Загрузите тренировочный датасет (формат: src<TAB>tgt)")
        epochs = gr.Slider(1, 5, value=1, step=1, label="Эпохи")
        batch = gr.Slider(1, 16, value=4, step=1, label="Батч")
        train_button = gr.Button("Начать обучение")
        train_output = gr.Textbox(label="Логи")
        train_button.click(
            train_model,
            inputs=[base_model, train_data, epochs, batch],
            outputs=train_output,
        )
    with gr.Tab("Перевод"):
        model_choice = gr.Dropdown(list(BASE_MODELS.keys()), label="Выберите обученную модель")
        text = gr.Textbox(lines=5, label="Введите текст")
        translate_button = gr.Button("Перевести")
        translation_result = gr.Textbox(label="Перевод")
        # BUG FIX: the original passed inputs=[model_choice, text], but the
        # handler signature is translate(text, model_name) — the dropdown
        # label was being translated and the user's text was treated as the
        # model name. The input order is corrected here.
        translate_button.click(translate, [text, model_choice], translation_result)
demo.launch()