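"""Streamlit demo: text generation with a local causal language model.

The sidebar takes a prompt, an output length, and a number of generations,
plus a choice between pure temperature sampling and top-k / top-p sampling.
"""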
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the model and tokenizer once; cached so Streamlit reruns don't reload them
@st.cache_resource
def load_model():
    model_name = "models/gpt"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
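
# Note: "models/gpt" is a local checkpoint directory; from_pretrained() also
# accepts a Hugging Face Hub model id (e.g. "gpt2") as an alternative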

def generate_text(model, tokenizer, prompt, gen_params):
    inputs = tokenizer(prompt, return_tensors="pt")
    # Drop parameters left as None (top_k/top_p in temperature mode), so
    # transformers falls back to its defaults instead of receiving None
    sampling = {key: value for key, value in gen_params.items() if value is not None}
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            **sampling,
        )
    generated = []
    for i, output in enumerate(outputs):
        text = tokenizer.decode(output, skip_special_tokens=True)
        generated.append(f"Generation {i + 1}:\n{text}\n{'-' * 50}")
    return generated
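
# Standalone usage sketch (assumes a checkpoint exists under models/gpt):
#   model, tokenizer = load_model()
#   texts = generate_text(model, tokenizer, "Once upon a time", {
#       'max_length': 60, 'temperature': 0.8, 'top_k': 50,
#       'top_p': 0.95, 'num_return_sequences': 2})
#   print(texts[0])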

def main():
    st.markdown(
        "<h1 style='text-align: center;'>Text Generator</h1>",
        unsafe_allow_html=True
    )
    st.markdown(
        "<h3 style='text-align: center;'>(well, almost)</h3>",
        unsafe_allow_html=True
    )
    st.markdown("---")

    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        st.image('images/scale_1200.png', width=500)

    # Load the model
    model, tokenizer = load_model()

    # Generation parameters
    with st.sidebar:
        st.header("Generation settings")
        prompt = st.text_area("Enter a starting prompt:", height=100)
        max_length = st.slider("Maximum length:", 50, 500, 100)
        num_return_sequences = st.slider("Number of generations:", 1, 5, 1)
        st.subheader("Sampling parameters:")
        sampling_method = st.radio("Method:", ["Temperature", "Top-k & Top-p"])
        if sampling_method == "Temperature":
            temperature = st.slider("Temperature:", 0.1, 2.0, 1.0, 0.1)
            top_k = None
            top_p = None
        else:
            temperature = 1.0
            top_k = st.slider("Top-k:", 1, 100, 50)
            top_p = st.slider("Top-p:", 0.1, 1.0, 0.9, 0.05)
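
    # Temperature rescales the logits (higher = more random); top-k keeps only
    # the k most likely tokens; top-p keeps the smallest token set whose
    # cumulative probability exceeds p. Parameters left as None are dropped
    # inside generate_text()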

    # Generate button
    if st.sidebar.button("Generate text"):
        if not prompt:
            st.warning("Please enter a starting prompt!")
            return
        gen_params = {
            'max_length': max_length,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'num_return_sequences': num_return_sequences
        }
        with st.spinner("Boozing away..."):
            generated = generate_text(model, tokenizer, prompt, gen_params)
        st.markdown("---")
        st.subheader("Results:")
        for i, text in enumerate(generated):
            # A unique key and a hidden label avoid Streamlit's empty-label
            # warning and duplicate-widget errors
            st.text_area(
                label="Result",
                value=text,
                height=200,
                key=f"result_{i}",
                label_visibility="collapsed"
            )

if __name__ == "__main__":
    main()
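
# To run locally (assuming this file is saved as app.py and the model files
# and images/scale_1200.png are in place):
#   pip install streamlit transformers torch
#   streamlit run app.py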