Spaces:
Sleeping
Sleeping
| import torch | |
| import streamlit as st | |
| from colorama import Fore | |
| from core.models.gpt import GPTLanguageModel | |
| from core.tokenizers.tokenizer import Tokenizer | |
| from core.utils.gptutils import hyperparameters, load_data | |
# Streamlit page settings; set_page_config must run before any other st.* output.
_PAGE_CONFIG = dict(
    layout='wide',
    page_title='QuillGPT',
    page_icon='🪶',
    initial_sidebar_state='expanded',
)
st.set_page_config(**_PAGE_CONFIG)
def decode_text(input, model: GPTLanguageModel, max_tokens, temperature, max_seq_length=50):
    """Stream text generated by *model*, yielding the newly decoded token at each step.

    Uses the module-level ``tokenizer`` to turn token ids back into text.

    Args:
        input: prompt token ids; the caller builds it as
            ``torch.tensor([tokenizer.encode(prompt)])``, i.e. a batch of one.
        model: the language model whose ``generate`` yields the running
            token sequence after each sampling step.
        max_tokens: number of new tokens to generate (``max_new_tokens``).
        temperature: sampling temperature forwarded to ``model.generate``.
        max_seq_length: forwarded to ``model.generate``; defaults to 50, the
            value that was previously hard-coded.

    Yields:
        str: the text of the most recently generated token.
    """
    for idx in model.generate(idx=input, max_new_tokens=max_tokens,
                              max_seq_length=max_seq_length, temperature=temperature):
        # Decode only the newest token id instead of re-decoding the whole
        # growing sequence each step. This avoids quadratic work and keeps a
        # token intact even if it decodes to more than one character.
        newest_id = idx[0][-1].item()
        yield tokenizer.decode([newest_id])
# Selectable models: display name -> path of the trained weights file.
models = {
    "Shakespearean GPT": './weights/GPT_model_char.pt',
}

st.sidebar.header('QuillGPT')
st.sidebar.write("This app generates text using a GPT model trained on either the Harpoon corpus or Shakespearean plays.")

# Let the user pick one of the registered models.
model_name = st.sidebar.selectbox('Select a model:', list(models))

# NOTE(review): `models` has no "GPT" key, so the Harpoon branch is currently
# unreachable; it only applies if a "GPT" entry is added to `models`.
_is_harpoon = model_name == "GPT"
st.title('GPT From Scratch' if _is_harpoon else 'Shakespearean GPT')
st.write(
    "This model was trained on the Harpoon corpus."
    if _is_harpoon
    else "This model was trained on Shakespearean plays."
)
# Weights file for the selected model (the selectbox only offers keys of `models`).
path = models[model_name]

# Per-model settings: display name -> (config file, training-data file, model name).
# One table replaces the duplicated per-branch loading code. A model that is
# registered in `models` but missing here now stops the app with a clear message
# instead of failing later with a NameError on `config_path`.
_MODEL_SETTINGS = {
    "GPT": ('./config/harpoon_config.json', './data/corpus.txt', "Harpoon GPT"),
    "Shakespearean GPT": ('./config/shakespearean_config.json', './data/input.txt', "Shakespearean GPT"),
}
if model_name not in _MODEL_SETTINGS:
    st.error(f"No configuration is registered for model '{model_name}'.")
    st.stop()
config_path, data_path, name = _MODEL_SETTINGS[model_name]

# The tokenizer and the architecture/training hyperparameters both come from the config file.
tokenizer: Tokenizer = Tokenizer()
tokenizer.from_pretrained(config_path)
vocab_size = tokenizer.vocab_size
(batch_size, block_size, max_iters, eval_interval, learning_rate, device,
 eval_iters, n_embd, n_head, n_layer, dropout) = hyperparameters(config_path=config_path)
# Example prompt shown in the text box, chosen to suit the selected model.
_DEFAULT_PROMPTS = {"GPT": 'And then Ted said, "'}
input_text = st.text_area(
    'Enter a prompt:',
    _DEFAULT_PROMPTS.get(model_name, 'Write a scene about ROMEO arguing with JULIET. \nROMEO:'),
)

# Sampling controls in the sidebar.
temperature = st.sidebar.slider(
    'Temperature:', min_value=0.1, max_value=1.0, value=0.5, step=0.1
)
max_tokens = st.sidebar.slider(
    'Max Tokens:', min_value=250, max_value=1000, value=500, step=50
)
def load_model(path):
    """Build the GPT model and load its trained weights from *path*.

    The architecture comes from the module-level ``vocab_size``, ``n_embd``,
    ``block_size``, ``n_head``, ``n_layer``, ``dropout`` and ``name``, which are
    read from the selected model's config.

    Returns:
        ``(model, device)`` on success. The model is on ``device`` (CUDA when
        available, else CPU) and is in evaluation mode. Returns ``(None, None)``
        if the weights file does not exist; in that case an error is shown in
        the UI.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = GPTLanguageModel(
        vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name=name
    ).to(device)
    # Only reading the weights file can hit a missing file; keep the try narrow.
    try:
        state_dict = torch.load(path, map_location=device)
    except FileNotFoundError:
        st.error("Don't forget to download the model weights from the link in the README.md file.")
        return None, None
    model.load_state_dict(state_dict)
    # The model is only used for generation here. Switch off dropout so sampling
    # is not perturbed by training-time noise.
    model.eval()
    return model, device
# `model` is None when the weights file is missing; load_model has already reported it.
model, device = load_model(path)
if model is not None:
    if st.button('Generate Text'):
        prompt = input_text
        if not prompt:
            # An empty prompt would hand the model a zero-length context to sample from.
            st.warning("Please enter a prompt before generating text.")
        else:
            st.subheader(model.name)
            # Batch of one sequence containing the prompt's token ids.
            input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
            st.write(f":green[{prompt}]")
            # Pure inference: skip autograd bookkeeping while the token stream is consumed.
            with torch.no_grad():
                st.write_stream(decode_text(input_ids, model, max_tokens, temperature))