Spaces:
Build error
Build error
| import transformers | |
| import streamlit as st | |
| from transformers import GPT2LMHeadModel, GPT2TokenizerFast | |
| import torch | |
| st.title("Fine-tuned GPT-2 for New Language with Custom Tokenizer") | |
| # Слайдеры для управления температурой и длиной текста | |
| temperature = st.slider("Temperature", 0.1, 2.0, 1.0) # Для обеих моделей | |
| max_len = st.slider("Max Length", 40, 120, 70) # Для обеих моделей | |
| # Кеширование модели и токенизатора GPT-2 | |
| def load_gpt2(): | |
| model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2") | |
| tokenizer_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2") | |
| return model_gpt2, tokenizer_gpt2 | |
| # Кеширование кастомной модели и токенизатора | |
| def load_custom_model(): | |
| # Здесь замените путь на вашу кастомную модель | |
| model_custom = GPT2LMHeadModel.from_pretrained("./rus_gpt2_tuned", from_tf=False, use_safetensors=True) | |
| tokenizer_custom = GPT2TokenizerFast.from_pretrained("./rus_gpt2_tuned/tokenizer") | |
| return model_custom, tokenizer_custom | |
| # Функция для генерации текста | |
| def generate_text(model, tokenizer, prompt, max_len, temperature): | |
| input_ids = tokenizer.encode(prompt, return_tensors='pt') | |
| attention_mask = torch.ones_like(input_ids) | |
| # Генерация текста | |
| output = model.generate( | |
| input_ids, | |
| max_length=max_len, | |
| temperature=temperature, # Управление разнообразием текста | |
| top_k=50, # Ограничение топ-50 самых вероятных слов | |
| top_p=0.9, # Nucleus sampling (суммарная вероятность) | |
| repetition_penalty=1.2, # Штраф за повторение слов или фраз | |
| no_repeat_ngram_size=4, # Запрет на повторение n-грамм (например, биграмм) | |
| do_sample=True, # Включение сэмплинга для большей разнообразности | |
| attention_mask=attention_mask, | |
| pad_token_id=tokenizer.eos_token_id | |
| ) | |
| # Декодирование сгенерированных токенов в текст | |
| generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) | |
| return generated_text | |
| # Streamlit приложение | |
| def main(): | |
| model_gpt2, tokenizer_gpt2 = load_gpt2() # GPT-2 модель | |
| model_custom, tokenizer_custom = load_custom_model() # Кастомная модель | |
| #st.write("Fine-tuned GPT-2 for New Language with Custom Tokenizer") | |
| # # Блок для генерации текста с GPT-2 | |
| # st.subheader("GPT-2 Text Generation") | |
| # prompt_gpt2 = st.text_area("Введите фразу для GPT-2 генерации:", value="В средние века") | |
| # generate_button_gpt2 = st.button("Сгенерировать текст с GPT-2") | |
| # if generate_button_gpt2: | |
| # generated_text_gpt2 = generate_text(model_gpt2, tokenizer_gpt2, prompt_gpt2, max_len, temperature) | |
| # st.subheader("Результат генерации GPT-2:") | |
| # st.write(generated_text_gpt2) | |
| # Блок для генерации текста с кастомной моделью | |
| st.subheader("Custom Model Text Generation") | |
| prompt_custom = st.text_area("Enter a phrase to generate with the updated model:", value="Когда-то давно") | |
| generate_button_custom = st.button("Generate!") | |
| if generate_button_custom: | |
| generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature) | |
| st.subheader("Result:") | |
| st.write(generated_text_custom) | |
| if __name__ == "__main__": | |
| main() | |