Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model import BigramLanguageModel, ModelConfig | |
# Read the whole training corpus; the character vocabulary is derived from it.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between a character and its integer id (its index in `chars`).
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, output a list of integer ids.

    Raises KeyError if `s` contains a character that is not in the corpus.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids, output the string they spell."""
    return ''.join(itos[i] for i in l)
# Load the trained checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file, so ckpt.pt must come from a
# trusted source.
model_spec = torch.load("ckpt.pt", map_location=torch.device('cpu'))
model_args = model_spec['model_args']  # keyword arguments for ModelConfig
model_weights = model_spec['model']    # weights passed to load_state_dict

# Rebuild the network from the saved configuration, then restore its weights.
modelconf = ModelConfig(**model_args)
trained_model = BigramLanguageModel(modelconf)
trained_model.load_state_dict(model_weights)

# The app only runs inference: switch the module to evaluation mode so that
# training-only layers (e.g. dropout, if the model has any) do not perturb
# the generated text.
trained_model.eval()
def generate_text(seed_text, max_new_tokens, confidence_val):
    """Generate text with the trained model, starting from ``seed_text``.

    Args:
        seed_text: Prompt typed by the user. ``None`` or an empty string is
            treated as a single space; a trailing space is appended when the
            prompt does not already end with one.
        max_new_tokens: Number of tokens to generate. May arrive as a float
            (or ``None`` when the field is cleared) from the UI, so it is
            coerced to a non-negative int.
        confidence_val: Value forwarded to the model as the sampling
            temperature. ``None`` or values <= 0 are replaced by ``1e-5``.

    Returns:
        The decoded first sequence returned by ``trained_model.generate``.

    Raises:
        gr.Error: If the prompt contains characters that are not in the
            model's vocabulary.
    """
    prompt = seed_text or ""
    if not prompt.endswith(" "):
        prompt += " "

    # The vocabulary only contains characters seen in the training corpus;
    # report anything else to the user instead of failing with a raw KeyError.
    unknown = sorted({ch for ch in prompt if ch not in stoi})
    if unknown:
        raise gr.Error(
            f"Unsupported character(s) in the initial text: {''.join(unknown)!r}"
        )

    # A number field can deliver 100.0 rather than 100 (or None when empty);
    # normalise to a non-negative integer token count.
    token_count = 0 if max_new_tokens is None else max(0, int(max_new_tokens))

    # Keep the temperature strictly positive.
    if confidence_val is None or confidence_val <= 0:
        temperature = 1e-5
    else:
        temperature = confidence_val

    # Shape (1, len(prompt)): unsqueeze adds a leading dimension of size one.
    context = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0)

    # Gradients are never needed here; skip building the autograd graph.
    with torch.no_grad():
        generated = trained_model.generate(
            context, temperature=temperature, max_new_tokens=token_count
        )
    return decode(generated[0].tolist())
# Build the web UI: a prompt box on top, then a row with the generation
# controls in one column and the generated text in the other.
with gr.Blocks() as demo:
    gr.HTML("<h1 align = 'center'> Simple GPT from scratch using tiny Shakespeare </h1>")
    content = gr.Textbox(label="Initial text to generate content")
    with gr.Row(equal_height=True):
        with gr.Column():
            # precision=0 makes the component deliver an integer token count
            # to the handler instead of a float.
            max_tokens = gr.Number(
                label="Maximum number of tokens", value=100, precision=0
            )
            # Passed as the third input of generate_text, which forwards it
            # to the model as the sampling temperature.
            confidence_val = gr.Slider(
                label="Confidence", minimum=0.0, maximum=1.0, value=0.7
            )
            generate_btn = gr.Button(value='Generate Text')
        with gr.Column():
            outputs = [gr.TextArea(label="Generated result", lines=8)]

    # Order must match the parameter order of generate_text.
    inputs = [
        content,
        max_tokens,
        confidence_val,
    ]
    generate_btn.click(fn=generate_text, inputs=inputs, outputs=outputs)

if __name__ == '__main__':
    demo.launch()