import pickle
import tiktoken
import torch
import torch.nn.functional as F
import gradio as gr

from PresGPT2 import GPTConfig, PresGPT2

# Load the custom tiktoken tokenizer that was serialized at training time.
# NOTE(review): pickle.load executes arbitrary code — safe only because
# pres_tokenizer.pkl is a trusted local artifact produced by this project.
with open("pres_tokenizer.pkl", "rb") as f:
    pres_enc: tiktoken.Encoding = pickle.load(f)

# Build the model config. Arguments are positional; based on the values they
# appear to be (block_size=1024, vocab_size, n_layer=12, n_head=12,
# n_embd=768) — i.e. GPT-2 "small" dimensions. TODO: confirm against
# GPTConfig's definition in PresGPT2.py.
# Vocab size = BPE merge table + special tokens (e.g. the <President: ...> tags).
config: GPTConfig = GPTConfig(
    1024,
    len(pres_enc._mergeable_ranks) + len(pres_enc._special_tokens),
    12,
    12,
    768
)

# Load trained weights onto CPU so the demo runs without a GPU.
# NOTE(review): torch.load unpickles — fine for a trusted local checkpoint;
# consider weights_only=True on newer PyTorch for defense in depth.
checkpoint = torch.load('./checkpoint.pt', map_location=torch.device("cpu"))  # Adjust device if needed

# Initialize model and switch to inference mode (disables dropout etc.).
model: PresGPT2 = PresGPT2(config)
model.load_state_dict(checkpoint)
model.eval()

# Presidents the model was conditioned on; each maps to a special prompt tag.
president_names = ["Abraham Lincoln", "Barack Obama", "Bill Clinton", "Donald Trump", "George Washington", 
                   "Harry S. Truman", "Joe Biden", "John F. Kennedy", "Lyndon B. Johnson", "Richard Nixon"]

# Function to generate text based on input
def generate_text(president: str, input_text: str, generation_len: int, top_k: int, temperature: float) -> str:
    """Autoregressively generate presidential-style text from a prompt.

    The prompt is prefixed with a ``<President: ...>`` conditioning tag, then
    tokens are sampled one at a time with temperature + top-k filtering.

    Args:
        president: Name used in the conditioning tag (one of president_names).
        input_text: User-supplied prompt text.
        generation_len: Number of new tokens to sample.
        top_k: Keep only the k most likely tokens at each step.
        temperature: Logit divisor; <1 sharpens, >1 flattens the distribution.

    Returns:
        The decoded prompt plus generated continuation.
    """
    prompt_text = f"<President: {president}> {input_text}"

    # Gradio sliders may deliver floats; the sampling code needs ints.
    generation_len = int(generation_len)
    top_k = int(top_k)

    # Tokenize the input text; allowed_special='all' lets the <President: ...>
    # tag encode as its special token. Shape becomes 1 x T (batch of one).
    text_tokens = torch.tensor(pres_enc.encode(prompt_text, allowed_special='all'))
    text_tokens = torch.unsqueeze(text_tokens, 0)

    # Model context window — matches GPTConfig's first positional argument.
    max_context = 1024

    # Generate text. no_grad: pure inference, so skip autograd bookkeeping
    # (the original built a graph per step, wasting memory).
    with torch.no_grad():
        for _ in range(generation_len):
            # Crop to the last max_context tokens so long generations never
            # exceed the model's positional-embedding table.
            context = text_tokens[:, -max_context:]
            logits, _ = model(context)

            logits = logits[:, -1, :] / temperature  # B x T x vocab_size -> B x vocab_size

            # Clamp k: torch.topk raises if k exceeds the vocab dimension.
            k = min(top_k, logits.size(-1))
            values, _ = torch.topk(logits, k)  # values is descending
            logits[logits < values[:, [-1]]] = -float('Inf')  # zero out everything below the k-th best logit

            # Softmax over the vocab dimension to get a sampling distribution.
            probs = F.softmax(logits, dim=1)

            # Sample (rather than argmax) to introduce variability.
            next_idx = torch.multinomial(probs, num_samples=1)

            text_tokens = torch.cat((text_tokens, next_idx), dim=1)

    # Decode the generated tokens into text
    generated_text = pres_enc.decode(text_tokens.squeeze().tolist())
    return generated_text

# Assemble the Gradio UI: conditioning dropdown, prompt box, and the three
# sampling controls all feed generate_text; output is a read-only textbox.
president_dropdown = gr.Dropdown(choices=president_names, label="Select a President")
prompt_box = gr.Textbox(label="Enter a prompt", lines=2)
length_slider = gr.Slider(minimum=10, maximum=200, value=100, label="Generation Length")
topk_slider = gr.Slider(minimum=1, maximum=500, value=200, label="Top-k")
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, step=0.1, value=0.8, label="Temperature")
output_box = gr.Textbox(label="Generated Text", lines=10)

iface = gr.Interface(
    fn=generate_text,
    inputs=[president_dropdown, prompt_box, length_slider, topk_slider, temperature_slider],
    outputs=output_box,
    title="PresGPT2 Text Generator",
    description="Enter a prompt to generate something a president would say!",
)


if __name__ == "__main__":
    iface.launch()