|
|
import pickle |
|
|
import tiktoken |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import gradio as gr |
|
|
|
|
|
from PresGPT2 import GPTConfig, PresGPT2 |
|
|
|
|
|
|
|
|
# Load the custom tiktoken encoding that was persisted alongside the model.
# NOTE(review): pickle.load executes arbitrary code from the file — this is
# acceptable only because pres_tokenizer.pkl ships with the app; never point
# this at an untrusted path.
with open("pres_tokenizer.pkl", "rb") as f:


    pres_enc: tiktoken.Encoding = pickle.load(f)
|
|
|
|
|
# Model hyperparameters — must match the values the checkpoint was trained
# with. Arguments are positional; by nanoGPT-style convention they are
# presumably (block_size, vocab_size, n_layer, n_head, n_embd) — TODO:
# confirm against GPTConfig's definition in PresGPT2.
config: GPTConfig = GPTConfig(


    1024,  # presumably the maximum context length — confirm in GPTConfig


    # Vocabulary size: BPE merge table plus special tokens. Relies on
    # private tiktoken attributes (_mergeable_ranks / _special_tokens);
    # verify these still exist when upgrading tiktoken.
    len(pres_enc._mergeable_ranks) + len(pres_enc._special_tokens),


    12,


    12,


    768
|
|
|
|
|
# Load the trained weights onto CPU so the demo runs without a GPU.
# NOTE(review): torch.load unpickles the file; since the result is passed
# straight to load_state_dict (so it is a plain tensor dict), consider
# weights_only=True if the installed torch version supports it.
checkpoint = torch.load('./checkpoint.pt', map_location=torch.device("cpu"))








model: PresGPT2 = PresGPT2(config)


# The checkpoint file holds the raw state_dict itself (not a wrapper dict).
model.load_state_dict(checkpoint)


# Inference mode: disables dropout and other train-only behavior.
model.eval()
|
|
|
|
|
# Choices for the UI dropdown; the selected name is interpolated into the
# "<President: ...>" control tag that conditions generation.
president_names = [
    "Abraham Lincoln",
    "Barack Obama",
    "Bill Clinton",
    "Donald Trump",
    "George Washington",
    "Harry S. Truman",
    "Joe Biden",
    "John F. Kennedy",
    "Lyndon B. Johnson",
    "Richard Nixon",
]
|
|
|
|
|
|
|
|
def generate_text(president: str, input_text: str, generation_len: int, top_k: int, temperature: float) -> str:
    """Autoregressively sample a presidential-style continuation of a prompt.

    Args:
        president: Name interpolated into the "<President: ...>" control tag.
        input_text: User prompt appended after the control tag.
        generation_len: Number of tokens to sample.
        top_k: Sample only among the k most likely next tokens.
        temperature: Softmax temperature; lower values are greedier.

    Returns:
        The decoded prompt plus generated continuation as a single string.
    """
    # Gradio sliders may deliver floats; range() and torch.topk need ints.
    generation_len = int(generation_len)
    top_k = int(top_k)

    prompt_text = f"<President: {president}> {input_text}"

    # allowed_special='all' keeps the control tag encoded as its special token.
    text_tokens = torch.tensor(pres_enc.encode(prompt_text, allowed_special='all'))
    text_tokens = torch.unsqueeze(text_tokens, 0)  # add batch dimension

    # NOTE(review): no context cropping — a prompt plus generation_len that
    # exceeds the model's block size will fail inside the model; current
    # slider limits (<= 200 tokens) make this unlikely. Confirm upstream.

    # Inference only: no_grad avoids building an autograd graph each step.
    with torch.no_grad():
        for _ in range(generation_len):
            logits, _ = model(text_tokens)

            # Keep only the last position's logits and apply temperature.
            logits = logits[:, -1, :] / temperature

            # Top-k filtering: clamp k so topk never exceeds the vocab size
            # (the UI slider allows values up to 500), then mask everything
            # below the k-th largest logit.
            k = min(top_k, logits.size(-1))
            values, _ = torch.topk(logits, k)
            logits[logits < values[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)

            # Sample the next token and append it to the running sequence.
            next_idx = torch.multinomial(probs, num_samples=1)
            text_tokens = torch.cat((text_tokens, next_idx), dim=1)

    generated_text = pres_enc.decode(text_tokens.squeeze().tolist())
    return generated_text
|
|
|
|
|
|
|
|
# Gradio UI: a dropdown, a prompt box, and three sampling sliders feed
# generate_text; its return value is shown in a read-only textbox.
president_dropdown = gr.Dropdown(choices=president_names, label="Select a President")
prompt_box = gr.Textbox(label="Enter a prompt", lines=2)
length_slider = gr.Slider(minimum=10, maximum=200, value=100, label="Generation Length")
topk_slider = gr.Slider(minimum=1, maximum=500, value=200, label="Top-k")
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, step=0.1, value=0.8, label="Temperature")

iface = gr.Interface(
    fn=generate_text,
    inputs=[president_dropdown, prompt_box, length_slider, topk_slider, temperature_slider],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="PresGPT2 Text Generator",
    description="Enter a prompt to generate something a president would say!",
)
|
|
|
|
|
|
|
|
# Launch the Gradio web server only when run as a script (not on import).
if __name__ == "__main__":


    iface.launch()