presgpt / app.py
edwjin's picture
Update app.py
beda507 verified
import pickle
import tiktoken
import torch
import torch.nn.functional as F
import gradio as gr
from PresGPT2 import GPTConfig, PresGPT2
# Load tokenizer and model configuration
with open("pres_tokenizer.pkl", "rb") as f:
pres_enc: tiktoken.Encoding = pickle.load(f)
config: GPTConfig = GPTConfig(
1024,
len(pres_enc._mergeable_ranks) + len(pres_enc._special_tokens),
12,
12,
768
)
checkpoint = torch.load('./checkpoint.pt', map_location=torch.device("cpu")) # Adjust device if needed
# Initialize model
model: PresGPT2 = PresGPT2(config)
model.load_state_dict(checkpoint)
model.eval()
president_names = ["Abraham Lincoln", "Barack Obama", "Bill Clinton", "Donald Trump", "George Washington",
"Harry S. Truman", "Joe Biden", "John F. Kennedy", "Lyndon B. Johnson", "Richard Nixon"]
# Function to generate text based on input
def generate_text(president: str, input_text: str, generation_len: int, top_k: int, temperature: float):
prompt_text = f"<President: {president}> {input_text}"
# Tokenize the input text
text_tokens = torch.tensor(pres_enc.encode(prompt_text, allowed_special='all'))
text_tokens = torch.unsqueeze(text_tokens, 0)
# Generate text
for i in range(generation_len):
logits, _ = model(text_tokens)
logits = logits[:, -1, :] / temperature # B x T x vocab_size -> B x vocab_size
values, _ = torch.topk(logits, top_k) # values is descending
logits[logits < values[:, [-1]]] = -float('Inf') # for all possible logits for a sequence, if less than smallest topk set to negative inf
# Softmax over logits to get probabilities
probs = F.softmax(logits, dim=1) # dim=1 means we compute softmax over COLUMNS IN A ROW!!
# Introduce variability in the generated text
next_idx = torch.multinomial(probs, num_samples=1)
text_tokens = torch.cat((text_tokens, next_idx), dim=1)
# Decode the generated tokens into text
generated_text = pres_enc.decode(text_tokens.squeeze().tolist())
return generated_text
# Create Gradio interface
iface = gr.Interface(fn=generate_text,
inputs=[
gr.Dropdown(choices=president_names, label="Select a President"),
gr.Textbox(label="Enter a prompt", lines=2),
gr.Slider(minimum=10, maximum=200, value=100, label="Generation Length"),
gr.Slider(minimum=1, maximum=500, value=200, label="Top-k"),
gr.Slider(minimum=0.1, maximum=1.5, step=0.1, value=0.8, label="Temperature")
],
outputs=gr.Textbox(label="Generated Text", lines=10),
title="PresGPT2 Text Generator",
description="Enter a prompt to generate something a president would say!")
if __name__ == "__main__":
iface.launch()