# pages/ruGPT.py — last updated by svilanovich (commit d15958c, verified)
import streamlit as st
from transformers import (GPT2LMHeadModel,
GPT2Tokenizer,
Trainer,
TrainingArguments,
TextDataset,
DataCollatorForLanguageModeling)
import torch
import time
# Prefer the GPU when one is visible to torch; model weights and input
# tensors below are all moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
def generate_text(text: str,
                  temperature: float,
                  top_k: int,
                  top_p: float,
                  num_beams: int,
                  max_length: int) -> tuple:
    """Generate a continuation of *text* with the module-level GPT-2 model.

    Args:
        text: Prompt to continue.
        temperature: Softmax temperature (higher = more random).
        top_k: Keep only the k most probable tokens at each step.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        num_beams: Number of beams for beam search.
        max_length: Maximum length of the generated sequence, in tokens.

    Returns:
        A ``(generated_text, response_time_seconds)`` tuple.
    """
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    time_start = time.time()
    # Inference only: no_grad() skips building the autograd graph,
    # reducing memory use and speeding up generation.
    with torch.no_grad():
        nuc_out = model.generate(input_ids,
                                 do_sample=True,
                                 temperature=temperature,
                                 top_k=top_k,
                                 top_p=top_p,
                                 num_beams=num_beams,
                                 max_length=max_length)
    response_time = time.time() - time_start
    # Decode only the first returned sequence instead of decoding the
    # whole batch and throwing away everything but element 0.
    nuc_generated_text = tokenizer.decode(nuc_out[0])
    return nuc_generated_text, response_time
# The tokenizer comes from the base ruGPT-3 checkpoint on the Hub; the
# weights are a locally fine-tuned copy stored under data/finetuned/.
model_name_or_path = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
finetuned_dir = 'data/finetuned/'
model = GPT2LMHeadModel.from_pretrained(finetuned_dir).to(device)
# Inject page-wide custom CSS (title font/color, sidebar gradient,
# button gradient, image border).  Raw HTML/CSS in markdown requires
# unsafe_allow_html=True.
st.markdown("""
<style>
.title {
font-family: 'Arial Black', sans-serif;
font-size: 36px;
color: #1E90FF;
text-align: center;
margin-bottom: 30px;
}
.sidebar .sidebar-content {
background-image: linear-gradient(#2e7bcf,#2e7bcf);
color: white;
}
.stButton>button {
color: white;
background: linear-gradient(to right, #1fa2ff, #12d8fa, #a6ffcb);
}
.uploaded-image {
border: 5px solid #1E90FF;
border-radius: 10px;
margin-bottom: 15px;
}
</style>
""", unsafe_allow_html=True)
# Page title, styled by the injected ".title" CSS class.
st.markdown('<h1 class="title">Вещий Олег</h1>', unsafe_allow_html=True)

# User prompt and sampling hyper-parameters.  st.slider already returns a
# float/int matching its numeric arguments, so explicit casts are redundant.
text = st.text_input(label='Введите текст')
temperature = st.slider('Температура', min_value=1.0, max_value=3.0, value=1.5, step=0.1)
top_k = st.slider('Топ k токенов', min_value=5, max_value=500, value=200, step=5)
top_p = st.slider('Минимальная вероятность токенов', min_value=0.1, max_value=1.0, value=0.6, step=0.05)
# Replaced the garbled transliterated label "НУМ БИМС))" with a clear one.
num_beams = st.slider('Число лучей (num beams)', min_value=1, max_value=20, value=5, step=1)
max_length = st.slider('Длина ответа в токенах', min_value=1, max_value=200, value=100, step=1)
# Generate only once the user has entered a non-empty prompt, then show
# the model output and how long generation took (seconds).
if text:
    sampling_kwargs = {
        'temperature': temperature,
        'top_k': top_k,
        'top_p': top_p,
        'num_beams': num_beams,
        'max_length': max_length,
    }
    generated, elapsed = generate_text(text=text, **sampling_kwargs)
    st.write(generated)
    st.write(elapsed)