# DS9Bot / app.py
# Snapshot of Xibanya's DS9Bot Space, commit 7de017e
# ("trying to force a rebuild").
import gradio as gr
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
# Runtime settings read by generate() and the UI below: seed, top_p, and the
# min/max bounds for the length and temperature sliders.
# safe_load is sufficient for a plain key/value config and refuses
# Python-specific YAML tags; the explicit encoding avoids depending on the
# platform's locale default.
with open('config.yaml', "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
# The tokenizer and the causal language model are both loaded from the same
# Hugging Face Hub repository id.
MODEL_REPO = "Xibanya/DS9Bot"
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)
def split_dialogue(line):
    """Reflow generated script text so each scene/speaker cue starts a new line.

    The text is split on whitespace, and a newline is inserted before a token
    when it looks like:

    * a scene marker: the token contains ``[`` (e.g. ``[Promenade]``), or
    * a speaker cue: an all-caps token that contains ``:`` (``QUARK:``) or is
      immediately followed by an ``[OC]:`` / ``[on`` annotation
      (``DAX [OC]:``, ``SISKO [on viewscreen]:``).

    The annotation tokens themselves never start a line, so a speaker and its
    annotation stay together.

    Args:
        line: Raw decoded text. Every whitespace run, including existing
            newlines, is collapsed to a single space before reflowing.

    Returns:
        The reflowed text with leading and trailing whitespace removed.
    """
    # Tokens that qualify a preceding speaker name rather than start a line.
    # NOTE: the original check was `'[on ' not in token`, which can never match
    # a whitespace-split token, so "[on" wrongly began a new line and separated
    # the speaker from its cue. Matching the bare "[on" token restores the
    # evident intent.
    annotations = ('[OC]:', '[on')
    tokens = line.split()
    pieces = []
    for index, token in enumerate(tokens):
        following = tokens[index + 1] if index + 1 < len(tokens) else None
        is_scene = '[' in token
        is_speaker = token.isupper() and (':' in token or following in annotations)
        if token not in annotations and (is_scene or is_speaker):
            token = '\n' + token
        pieces.append(token)
    return ' '.join(pieces).strip()
def generate(prompt: str = ' ', length: float = 250, temp: float = 0.7, top_k: float = 40):
    """Sample a script continuation from the DS9Bot model.

    Args:
        prompt: Text to continue. ``None``, empty, or whitespace-only input falls
            back to ``'[Ops] SISKO: '``.
        length: Number of new tokens to produce. ``min_length`` and
            ``max_length`` are both pinned to ``prompt tokens + length``.
        temp: Sampling temperature.
        top_k: Top-k sampling cutoff. It is cast to ``int`` because UI sliders
            may deliver floats, while ``generate`` requires an integer ``top_k``.

    Returns:
        The decoded sequence (prompt plus continuation, special tokens removed),
        reflowed by ``split_dialogue`` so each scene/speaker cue starts a new line.
    """
    # Sampling in model.generate draws from torch's global RNG, so that is the
    # RNG to seed. The previous code seeded a throwaway torch.Generator(), which
    # left the configured seed with no effect. torch.seed() reseeds the global
    # RNG non-deterministically when no seed is configured.
    seed = cfg.get('seed')
    if seed is not None:
        torch.manual_seed(seed)
    else:
        torch.seed()

    # Previously `prompt is None or ''` parsed as `(prompt is None) or ''`, so an
    # empty prompt never received this default.
    if prompt is None or not prompt.strip():
        prompt = '[Ops] SISKO: '

    # Prefix the end-of-text marker before the prompt. Presumably this mirrors how
    # documents were delimited during fine-tuning; confirm against training data.
    end = "<|endoftext|>"
    encoded = tokenizer(end + prompt, return_tensors="pt").to(model.device)
    input_ids = encoded.input_ids
    total_tokens = input_ids.shape[-1] + int(length)

    output = model.generate(
        input_ids=input_ids,
        attention_mask=encoded.get('attention_mask'),
        max_length=total_tokens,
        min_length=total_tokens,
        temperature=temp,
        top_p=cfg['top_p'],
        do_sample=True,
        top_k=int(top_k),
        early_stopping=True,
        num_return_sequences=1,
    )
    text = tokenizer.batch_decode(
        output, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
    return split_dialogue(text)
title = 'Deep Space Nine Script Generator'

# Input controls, listed in the same order as generate()'s parameters
# (prompt, length, temp, top_k); Gradio passes their values positionally.
prompt_input = gr.inputs.Textbox(
    label="Prompt", placeholder='Enter a prompt to generate a script from')
length_input = gr.inputs.Slider(
    minimum=cfg['min_length'], maximum=cfg['max_length'],
    label="Output Length", default=500, step=1)
temperature_input = gr.inputs.Slider(
    minimum=cfg['min_temp'], maximum=cfg['max_temp'],
    label="Temperature", default=0.7)
top_k_input = gr.inputs.Slider(
    minimum=10, maximum=100, label="Top K", default=40, step=1)
script_output = gr.outputs.Textbox(type="auto", label="Generated Script")

# One row per example; each row holds one value per input control, in order.
example_rows = [
    ['[Promenade]', 500, 0.9, 40],
    ['QUARK:', 250, 1.2, 60],
    ["Commander's log, stardate 46924.5.", 300, 0.7, 60],
]

iface = gr.Interface(
    fn=generate,
    inputs=[prompt_input, length_input, temperature_input, top_k_input],
    outputs=[script_output],
    examples=example_rows,
    live=False,
    title=title,
    theme='dark-huggingface',
)
iface.launch()