# Hugging Face Spaces app — "Вещий Олег": Streamlit UI for a fine-tuned ruGPT-3 text generator.
| import streamlit as st | |
| from transformers import (GPT2LMHeadModel, | |
| GPT2Tokenizer, | |
| Trainer, | |
| TrainingArguments, | |
| TextDataset, | |
| DataCollatorForLanguageModeling) | |
| import torch | |
| import time | |
# Run inference on the GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def generate_text(text: str,
                  temperature: float,
                  top_k: int,
                  top_p: float,
                  num_beams: int,
                  max_length: int) -> tuple:
    """Generate a continuation of *text* with the fine-tuned GPT-2 model.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        text: Prompt to continue.
        temperature: Softmax temperature for sampling.
        top_k: Keep only the top-k most likely tokens at each step.
        top_p: Nucleus-sampling cumulative-probability threshold.
        num_beams: Number of beams for beam search.
        max_length: Maximum total sequence length in tokens.

    Returns:
        Tuple of (generated text, generation time in seconds).
    """
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    time_start = time.time()
    # inference_mode() disables autograd bookkeeping — generation needs no
    # gradients, so this saves memory and time.
    with torch.inference_mode():
        nuc_out = model.generate(input_ids,
                                 do_sample=True,
                                 temperature=temperature,
                                 top_k=top_k,
                                 top_p=top_p,
                                 num_beams=num_beams,
                                 max_length=max_length)
    response_time = time.time() - time_start
    # Decode only the first returned sequence (generate() returns a batch;
    # num_return_sequences defaults to 1).
    nuc_generated_text = tokenizer.decode(nuc_out[0])
    return nuc_generated_text, response_time
# Weights of the fine-tuned model are stored locally; the tokenizer comes from
# the base ruGPT-3 checkpoint it was fine-tuned from.
finetuned_path = 'data/finetuned/'  # fixed typo: was `paath`
model = GPT2LMHeadModel.from_pretrained(finetuned_path).to(device)
model_name_or_path = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
# Page-wide CSS: title typography, sidebar gradient, gradient buttons and a
# framed style for uploaded images. Injected once via an HTML-enabled markdown call.
_PAGE_CSS = """
<style>
.title {
font-family: 'Arial Black', sans-serif;
font-size: 36px;
color: #1E90FF;
text-align: center;
margin-bottom: 30px;
}
.sidebar .sidebar-content {
background-image: linear-gradient(#2e7bcf,#2e7bcf);
color: white;
}
.stButton>button {
color: white;
background: linear-gradient(to right, #1fa2ff, #12d8fa, #a6ffcb);
}
.uploaded-image {
border: 5px solid #1E90FF;
border-radius: 10px;
margin-bottom: 15px;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Page title (styled via the .title CSS class injected earlier in this file).
st.markdown('<h1 class="title">Вещий Олег</h1>', unsafe_allow_html=True)
text = st.text_input(label='Введите текст')
# Sampling hyperparameters passed straight through to model.generate().
temperature = float(st.slider('Температура', min_value=1.0, max_value=3.0, value=1.5, step=0.1))
top_k = int(st.slider('Топ k токенов', min_value=5, max_value=500, value=200, step=5))
top_p = float(st.slider('Минимальная вероятность токенов', min_value=0.1, max_value=1.0, value=0.6, step=0.05))
# Replaced placeholder label 'НУМ БИМС))' with a real user-facing label.
num_beams = int(st.slider('Количество лучей (num beams)', min_value=1, max_value=20, value=5, step=1))
max_length = int(st.slider('Длина ответа в токенах', min_value=1, max_value=200, value=100, step=1))
# Generate only once the user has entered a non-empty prompt.
if text:
    gen_t, response_time = generate_text(text=text,
                                         temperature=temperature,
                                         top_k=top_k,
                                         top_p=top_p,
                                         num_beams=num_beams,
                                         max_length=max_length)
    st.write(gen_t)
    # Label and round the timing instead of dumping a raw float into the UI.
    st.write(f'Время генерации: {response_time:.2f} с')