File size: 2,184 Bytes
bfa7230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51bfd98
bfa7230
 
 
 
 
 
890df21
bfa7230
 
 
890df21
74a7e2a
bfa7230
 
d579a9f
bfa7230
d579a9f
bfa7230
 
d579a9f
890df21
bfa7230
 
d579a9f
bfa7230
 
 
890df21
bfa7230
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from model import BigramLanguageModel, ModelConfig

# Build a character-level vocabulary from the training corpus.
with open('input.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Each distinct character of the corpus (in sorted order) is one token.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables: token id -> character and character -> token id.
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


def encode(s):
    """Return the list of token ids for string *s*.

    Raises KeyError for any character that does not occur in input.txt.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the sequence of token ids *l*."""
    return ''.join(itos[i] for i in l)

# Load the checkpoint onto the CPU and rebuild the model from the
# hyper-parameters stored next to the weights.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model_spec = torch.load("ckpt.pt", map_location=torch.device('cpu'))
model_args = model_spec['model_args']
model_weights = model_spec['model']
modelconf = ModelConfig(**model_args)
trained_model = BigramLanguageModel(modelconf)
trained_model.load_state_dict(model_weights)
# nn.Module instances start in training mode; switch to inference mode so
# layers such as dropout (if the model has any) are disabled while generating.
trained_model.eval()

def generate_text(seed_text, max_new_tokens, confidence_val):
    """Continue ``seed_text`` with characters sampled from ``trained_model``.

    Args:
        seed_text: Prompt typed by the user; ``None`` is treated as an empty
            prompt. A trailing space is appended when missing, so the context
            always ends with ``" "``.
        max_new_tokens: Number of characters to generate. Gradio's number
            widget may deliver a float (e.g. ``100.0``) or ``None``, so the
            value is coerced to a non-negative ``int``.
        confidence_val: Sampling temperature forwarded to
            ``trained_model.generate``. Missing or non-positive values are
            replaced by ``1e-5`` (presumably because ``generate`` divides the
            logits by it -- TODO confirm against model.py).

    Returns:
        The decoded prompt followed by the generated continuation.

    Raises:
        gr.Error: If the prompt contains characters absent from the training
            corpus, shown to the user instead of a bare ``KeyError``.
    """
    text = seed_text if seed_text is not None else ""
    if not text.endswith(" "):
        text += " "

    # encode() only knows characters seen in input.txt.
    unknown = sorted({c for c in text if c not in stoi})
    if unknown:
        raise gr.Error(f"Unsupported characters in prompt: {''.join(unknown)!r}")

    context = torch.tensor(encode(text), dtype=torch.long).unsqueeze(0)
    n_tokens = max(int(max_new_tokens or 0), 0)
    temperature = confidence_val if confidence_val and confidence_val > 0 else 1e-5

    # No gradients are needed for sampling (harmless if generate already disables them).
    with torch.no_grad():
        generated = trained_model.generate(
            context, temperature=temperature, max_new_tokens=n_tokens
        )
    return decode(generated[0].tolist())

# Gradio UI: a prompt box, generation controls, and an output area, wired to
# generate_text through the "Generate Text" button.
with gr.Blocks() as demo:
    gr.HTML("<h1 align = 'center'> Simple GPT from scratch using tiny Shakespeare </h1>")

    content = gr.Textbox(label="Initial text to generate content")
    with gr.Row(equal_height=True):
        with gr.Column():
            # precision=0 makes Gradio pass an integer count instead of a float.
            max_tokens = gr.Number(label="Maximum number of tokens", value=100, precision=0)
            # Used as the sampling temperature; generate_text maps 0 to 1e-5.
            confidence_val = gr.Slider(label="Confidence", minimum=0.0, maximum=1.0, value=0.7)
            generate_btn = gr.Button(value='Generate Text')
        with gr.Column():
            outputs = [gr.TextArea(label="Generated result", lines=8)]
    # Order must match generate_text(seed_text, max_new_tokens, confidence_val).
    inputs = [content, max_tokens, confidence_val]
    generate_btn.click(fn=generate_text, inputs=inputs, outputs=outputs)

if __name__ == '__main__':
    demo.launch()