student2222333051 commited on
Commit
59206fe
·
verified ·
1 Parent(s): 055e8ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py CHANGED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ from datasets import Dataset
5
+ from transformers import (
6
+ MarianMTModel, MarianTokenizer,
7
+ T5ForConditionalGeneration, T5Tokenizer,
8
+ DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
9
+ )
10
+ import torch
11
+
12
+ os.makedirs("models", exist_ok=True)
13
+
14
+ # ----------- LOAD MODELS -----------
15
+
16
+ BASE_MODELS = {
17
+ "MarianMT ru→en": "Helsinki-NLP/opus-mt-ru-en",
18
+ "MarianMT en→ru": "Helsinki-NLP/opus-mt-en-ru",
19
+ "T5-small ru→en": "t5-small",
20
+ "T5-small en→ru": "t5-small"
21
+ }
22
+
23
+ def load_model(model_id):
24
+ if "Marian" in model_id:
25
+ tokenizer = MarianTokenizer.from_pretrained(model_id)
26
+ model = MarianMTModel.from_pretrained(model_id)
27
+ else:
28
+ tokenizer = T5Tokenizer.from_pretrained(model_id)
29
+ model = T5ForConditionalGeneration.from_pretrained(model_id)
30
+ return model, tokenizer
31
+
32
+ # ----------- TRAINING FUNCTION -----------
33
+
34
+ def train_model(base_model_name, train_file, num_epochs, batch_size):
35
+
36
+ # load dataset
37
+ data = train_file.decode("utf-8").split("\n")
38
+ pairs = [l.split("\t") for l in data if "\t" in l]
39
+
40
+ ds = Dataset.from_dict({
41
+ "src": [p[0] for p in pairs],
42
+ "trg": [p[1] for p in pairs]
43
+ })
44
+
45
+ # load pretrained
46
+ model_id = BASE_MODELS[base_model_name]
47
+ model, tokenizer = load_model(model_id)
48
+
49
+ # preprocess function
50
+ def preprocess(batch):
51
+ if "Marian" in base_model_name:
52
+ inputs = tokenizer(batch["src"], truncation=True, padding="max_length", max_length=128)
53
+ with tokenizer.as_target_tokenizer():
54
+ labels = tokenizer(batch["trg"], truncation=True, padding="max_length", max_length=128)
55
+ inputs["labels"] = labels["input_ids"]
56
+ return inputs
57
+ else: # T5
58
+ prefix = "translate Russian to English: " if "ru→en" in base_model_name else "translate English to Russian: "
59
+ inputs = tokenizer(prefix + batch["src"], truncation=True, padding="max_length", max_length=128)
60
+ with tokenizer.as_target_tokenizer():
61
+ labels = tokenizer(batch["trg"], truncation=True, padding="max_length", max_length=128)
62
+ inputs["labels"] = labels["input_ids"]
63
+ return inputs
64
+
65
+ tokenized = ds.map(preprocess, batched=True)
66
+
67
+ # training args
68
+ args = Seq2SeqTrainingArguments(
69
+ output_dir="models",
70
+ metric_for_best_model="loss",
71
+ save_strategy="no",
72
+ num_train_epochs=num_epochs,
73
+ per_device_train_batch_size=batch_size,
74
+ learning_rate=2e-4,
75
+ logging_steps=5,
76
+ report_to="none",
77
+ )
78
+
79
+ collator = DataCollatorForSeq2Seq(tokenizer, model=model)
80
+ trainer = Seq2SeqTrainer(
81
+ model=model,
82
+ args=args,
83
+ train_dataset=tokenized,
84
+ data_collator=collator,
85
+ )
86
+
87
+ trainer.train()
88
+
89
+ # SAVE
90
+ save_path = f"models/{base_model_name.replace(' ', '_')}"
91
+ model.save_pretrained(save_path)
92
+ tokenizer.save_pretrained(save_path)
93
+
94
+ return f"Модель сохранена в {save_path}"
95
+
96
+ # ----------- TRANSLATION -----------
97
+
98
+ def translate(text, model_name):
99
+ model_path = f"models/{model_name.replace(' ', '_')}"
100
+ if not os.path.exists(model_path):
101
+ return "Сначала обучите модель."
102
+
103
+ if "Marian" in model_name:
104
+ tokenizer = MarianTokenizer.from_pretrained(model_path)
105
+ model = MarianMTModel.from_pretrained(model_path)
106
+ else:
107
+ tokenizer = T5Tokenizer.from_pretrained(model_path)
108
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
109
+
110
+ if "T5-small" in model_name:
111
+ prefix = "translate Russian to English: " if "ru→en" in model_name else "translate English to Russian: "
112
+ input_ids = tokenizer(prefix + text, return_tensors="pt").input_ids
113
+ out = model.generate(input_ids, max_length=200)
114
+ return tokenizer.decode(out[0], skip_special_tokens=True)
115
+
116
+ else: # Marian
117
+ enc = tokenizer([text], return_tensors="pt")
118
+ out = model.generate(**enc)
119
+ return tokenizer.decode(out[0], skip_special_tokens=True)
120
+
121
+
122
+ # ----------- GRADIO UI -----------
123
+
124
+ with gr.Blocks() as demo:
125
+
126
+ gr.Markdown("# 🚀 Обучение переводчика (MarianMT / T5-small)")
127
+
128
+ with gr.Tab("Обучение"):
129
+ base_model = gr.Dropdown(list(BASE_MODELS.keys()), label="Выберите модель")
130
+ train_data = gr.File(label="Загрузите тренировочный датасет (формат: src<TAB>tgt)")
131
+ epochs = gr.Slider(1, 5, value=1, step=1, label="Эпохи")
132
+ batch = gr.Slider(1, 16, value=4, step=1, label="Батч")
133
+
134
+ train_button = gr.Button("Начать обучение")
135
+ train_output = gr.Textbox(label="Логи")
136
+
137
+ train_button.click(
138
+ train_model,
139
+ inputs=[base_model, train_data, epochs, batch],
140
+ outputs=train_output
141
+ )
142
+
143
+ with gr.Tab("Перевод"):
144
+ model_choice = gr.Dropdown(list(BASE_MODELS.keys()), label="Выберите обученную модель")
145
+ text = gr.Textbox(lines=5, label="Введите текст")
146
+ translate_button = gr.Button("Перевести")
147
+ translation_result = gr.Textbox(label="Перевод")
148
+
149
+ translate_button.click(translate, [model_choice, text], translation_result)
150
+
151
+ demo.launch()