# Author: Daenox
# Commit: "Update app.py" (1f4ced1, verified)
from transformers import TextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import gradio as gr
import re
import torch
model_path = "malteos/gpt2-uk"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the base model and tokenizer used for joke generation.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='right')
# Reuse EOS as the padding token (GPT-2-style tokenizers typically define none).
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
model.config.pad_token_id = tokenizer.pad_token_id

# Text-generation pipeline that create_joke() calls by default. Without this
# module-level name every generation request failed with a NameError.
joke_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
# Fine-tuning configuration (GPT-2).
# Starting checkpoint directory and training corpus.
train_directory = "model"
train_file_path = "anekdoty.txt"
model_name = train_directory
# Output directory for the fine-tuned model and its checkpoints.
output_dir = "model_tuned"
overwrite_output_dir = False
learning_rate = 4e-4
per_device_train_batch_size = 8
num_train_epochs = 3
# Plain scalars. A trailing comma here would make each value a 1-tuple
# such as (200,).
# NOTE(review): learning_rate and the values below are not currently passed
# to train(); wire them into TrainingArguments if they are meant to apply.
gradient_accumulation_steps = 1
warmup_steps = 200
weight_decay = 0.01
fp16 = True
report_to = "none"
optim = "adafactor"
lr_scheduler_type = "linear"
dataloader_num_workers = 2
# Assigned once. It used to be set to 1 and then overwritten with 2.
save_total_limit = 2
save_steps = 10000
def clean_text(text):
    """Normalise raw text: lowercase it, strip HTML tags, drop punctuation.

    Letters and digits of any script are kept. The previous ASCII-only
    pattern deleted every Cyrillic character, which emptied Ukrainian text.
    Underscores are still removed, as before. Runs of whitespace collapse
    to a single space and the result is trimmed.
    """
    text = text.lower()
    text = re.sub(r'<[^>]*>', '', text)  # remove HTML tags
    # \w is Unicode-aware for str patterns; "_" is matched separately
    # because \w includes it but the original cleaner removed it.
    text = re.sub(r'[^\w\s]|_', '', text)
    text = re.sub(r'\s+', ' ', text)  # collapse whitespace
    return text.strip()
def load_dataset(file_path, tokenizer, block_size=128):
    """Return a transformers ``TextDataset`` read from the text file *file_path*.

    *tokenizer* is handed to the dataset to tokenize the file, and
    *block_size* (default 128) is passed through as the dataset's block size.

    NOTE(review): transformers marks ``TextDataset`` as deprecated in favour
    of the ``datasets`` library; consider migrating.
    """
    return TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=block_size)
def load_data_collator(tokenizer, mlm=False):
    """Return a ``DataCollatorForLanguageModeling`` bound to *tokenizer*.

    The *mlm* flag is forwarded unchanged. With the default ``False`` the
    collator does not mask tokens, i.e. it is set up for causal
    (GPT-style) language modelling.
    """
    return DataCollatorForLanguageModeling(mlm=mlm, tokenizer=tokenizer)
def train(train_file_path, model_name,
          output_dir,
          overwrite_output_dir,
          per_device_train_batch_size,
          num_train_epochs,
          save_steps, resume_from_checkpoint):
    """Fine-tune the causal LM at *model_name* on the corpus *train_file_path*.

    Steps:
    1. Write the tokenizer and the starting weights to *output_dir*.
    2. Train for *num_train_epochs*, writing a checkpoint every *save_steps*
       steps.
    3. ``trainer.save_model()`` stores the tuned weights in *output_dir*.
    4. Write the tuned tokenizer and weights to the module-level
       ``train_directory``. When that equals *model_name* (as this module
       configures it), the next run starts from the tuned weights.

    Args:
        train_file_path: Plain-text training corpus.
        model_name: Local directory or hub id of the starting checkpoint.
        output_dir: Directory for checkpoints and the tuned model.
        overwrite_output_dir: Forwarded to ``TrainingArguments``.
        per_device_train_batch_size: Batch size per device.
        num_train_epochs: Number of passes over the dataset.
        save_steps: Steps between checkpoints. This was previously accepted
            but ignored.
        resume_from_checkpoint: Forwarded to ``Trainer.train``. ``False``
            starts fresh; ``True`` resumes from the latest checkpoint.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    train_dataset = load_dataset(train_file_path, tokenizer)
    data_collator = load_data_collator(tokenizer)

    # The Trainer below is not given the tokenizer, so save it explicitly.
    tokenizer.save_pretrained(output_dir)
    # Auto class keeps model loading consistent with AutoTokenizer above.
    # It resolves to GPT2LMHeadModel for GPT-2 checkpoints.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.save_pretrained(output_dir)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=overwrite_output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        save_steps=save_steps,
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    trainer.save_model()

    tokenizer.save_pretrained(train_directory)
    model.save_pretrained(train_directory)
# Fine-tune at import time; the Gradio UI below is only built after this
# call returns.
# resume_from_checkpoint: False for the first run, True to resume from the
# point where training stopped.
train(
    model_name=model_name,
    train_file_path=train_file_path,
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    save_steps=save_steps,
    resume_from_checkpoint=False,
)
def create_joke(prompt, length=20, creativity=0.3, repeat_penalty=1.5, generator=None):
    """Generate a short joke continuation for *prompt*.

    Args:
        prompt: Opening phrase typed by the user.
        length: Maximum number of new tokens, clamped to [10, 200].
        creativity: Sampling temperature, clamped to [0.1, 1.0].
        repeat_penalty: Repetition penalty, clamped to [1.0, 2.0].
        generator: Text-generation callable. Defaults to the module-level
            ``joke_generator``. It is exposed so a different generator can
            be supplied, e.g. in tests. Gradio passes only the first four
            arguments.

    Returns:
        One of:
        - the generated text with the leading prompt removed, cut after the
          earliest sentence terminator (``.``, ``!``, ``?`` or newline) and
          limited to 200 characters;
        - "Не вдалося згенерувати текст" if the generator returned nothing;
        - "Сталася помилка: <error>" if anything raised. This is a UI
          boundary, so errors are shown instead of propagated.
    """
    try:
        # Clamp UI values to the same ranges as the sliders.
        length = min(max(int(length), 10), 200)
        creativity = min(max(float(creativity), 0.1), 1.0)
        repeat_penalty = min(max(float(repeat_penalty), 1.0), 2.0)

        if generator is None:
            generator = joke_generator

        # Pad with the generator's own EOS id when it exposes one. 50256
        # (GPT-2's EOS id) is kept only as a fallback.
        pad_token_id = getattr(getattr(generator, "tokenizer", None), "eos_token_id", None)
        if pad_token_id is None:
            pad_token_id = 50256

        # Text generation. early_stopping is not passed: it only affects
        # beam search, not sampling.
        result = generator(
            prompt,
            max_new_tokens=length,
            temperature=creativity,
            repetition_penalty=repeat_penalty,
            do_sample=True,
            top_p=0.9,
            no_repeat_ngram_size=2,
            pad_token_id=pad_token_id,
        )

        if not result or not isinstance(result, list):
            return "Не вдалося згенерувати текст"

        joke = result[0]['generated_text']
        # Remove only the echoed prompt at the start. str.replace would also
        # delete repetitions of the prompt inside the generated text.
        if joke.startswith(prompt):
            joke = joke[len(prompt):]
        joke = joke.strip()

        # Cut after the earliest terminator in the text. The old loop picked
        # the first terminator in list order, so "Hi! Bye." was not cut.
        match = re.search(r'[.!?\n]', joke)
        if match:
            joke = joke[:match.end()].rstrip()
        return joke[:200]  # additional length cap
    except Exception as e:
        return f"Сталася помилка: {str(e)}"
# User interface: controls on the left, generated text on the right.
with gr.Blocks(title="Генератор анекдотів") as app:
    gr.Markdown("## Генератор анекдотів на основі GPT-2")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(value="", label="Початкова фраза")
            # Slider ranges mirror the clamping done in create_joke().
            length_slider = gr.Slider(
                minimum=10, maximum=200, value=20, step=1,
                label="Довжина відповіді",
            )
            temp_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.3, step=0.1,
                label="Температура",
            )
            penalty_slider = gr.Slider(
                minimum=1.0, maximum=2.0, value=1.5, step=0.1,
                label="Штраф за повторення",
            )
            generate_button = gr.Button("Згенерувати")
        with gr.Column():
            output_box = gr.Textbox(lines=4, label="Результат")
    # Positional order of the inputs matches create_joke's parameters.
    generate_button.click(
        create_joke,
        inputs=[text_input, length_slider, temp_slider, penalty_slider],
        outputs=output_box,
    )

app.launch()