Spaces:
Running
Running
import streamlit as st
import torch
import os

# Project module: provides the GPTLanguageModel class (needed so torch.load can
# unpickle the model) and the shared `hyperparams` config object.
from GPTLanguageModelClass import *

# Unpack hyperparameters from the shared config so names match the training code.
block_size = hyperparams.block_size
batch_size = hyperparams.batch_size
max_iters = hyperparams.max_iters
learning_rate = hyperparams.learning_rate
eval_every = hyperparams.eval_every
n_embd = hyperparams.n_embd
n_head = hyperparams.n_head
n_layer = hyperparams.n_layer
dropout = hyperparams.dropout
device = hyperparams.device

st.title('LLM from scratch Demo')
st.write(f"Using device: {device}")

# The character vocabulary is produced by extract.py; fail fast if it is missing.
if not os.path.exists("./vocab.txt"):
    raise Exception("Please run extract.py first")

with open("./vocab.txt", 'r', encoding='utf-8') as f:
    text = f.read()
chars = sorted(set(text))

# Character-level codec: map each vocabulary character to its integer id and back.
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [string_to_int[ch] for ch in s]
decode = lambda x: ''.join(int_to_string[i] for i in x)

model_pickle_path = './model.pt'
st.write('loading model parameters...')
# NOTE(security): torch.load unpickles arbitrary objects and can execute code —
# only load checkpoints from a trusted source.
with open(model_pickle_path, 'rb') as f:
    model = torch.load(f, map_location=device)
model.eval()  # disable dropout etc. for inference (config shows dropout > 0 is possible)
st.write('model loaded successfully!')

# Cap the prompt so at least one slot of the context window remains for generation.
prompt = st.text_area('Prompt:', height=100, max_chars=block_size - 1, key='prompt')

if prompt:
    # Reject characters the model has never seen instead of crashing with a KeyError.
    unknown = sorted(set(prompt) - set(string_to_int))
    if unknown:
        st.error(f"Prompt contains characters outside the model vocabulary: {unknown}")
    else:
        context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
        max_new_tokens = block_size - len(prompt)
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            out = model.generate(context.unsqueeze(0), max_new_tokens=max_new_tokens)
        generated_chars = decode(out[0].tolist())
        st.write('Generated text:')
        st.write(generated_chars)