| import gradio as gr | |
| import torch | |
| import yaml | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Load generation settings (seed, top_p, slider bounds, etc.) from the local
# config file shipped alongside this script.
# NOTE(review): yaml.FullLoader is acceptable for a trusted local config, but
# switch to yaml.safe_load if this file could ever come from an untrusted source.
with open('config.yaml', "r") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)
# Fetch the fine-tuned DS9 dialogue model and its tokenizer from the
# Hugging Face Hub (downloaded on first run, cached afterwards).
tokenizer = AutoTokenizer.from_pretrained("Xibanya/DS9Bot")
model = AutoModelForCausalLM.from_pretrained("Xibanya/DS9Bot")
def split_dialogue(line):
    """Reformat a flat generated script so each speaker/scene marker starts a new line.

    A word is treated as a line-break marker when it is a bracketed tag
    (e.g. '[Ops]') or an ALL-CAPS speaker name ending in ':' (or followed
    by '[OC]:'). The '[OC]:' tag itself never starts a line.
    """
    words = line.split()
    last = len(words) - 1
    pieces = []
    for idx, word in enumerate(words):
        # NOTE: '[on ' can never occur inside a split() token (tokens hold no
        # spaces), so that guard is always true; kept to mirror the original.
        starts_line = (
            word != '[OC]:'
            and '[on ' not in word
            and ('[' in word
                 or (word.isupper()
                     and (':' in word
                          or (idx < last and words[idx + 1] == '[OC]:'))))
        )
        pieces.append('\n' + word if starts_line else word)
    return ' '.join(pieces).strip()
def generate(prompt: str = ' ', length: float = 250, temp: float = 0.7, top_k: float = 40):
    """Generate a DS9-style script continuation from *prompt*.

    Args:
        prompt: Seed text; when None or empty, falls back to '[Ops] SISKO: '.
        length: Number of new tokens to generate (cast to int).
        temp: Sampling temperature.
        top_k: Top-k sampling cutoff (cast to int, as required by generate()).

    Returns:
        The generated text, reformatted with one speaker per line.
    """
    # Seed the *global* torch RNG so sampling is reproducible when cfg['seed']
    # is set. (Bug fix: the original called manual_seed on a throwaway
    # torch.Generator() instance, which never affected model.generate().)
    torch.manual_seed(cfg['seed'] if cfg['seed'] is not None else torch.seed())
    # Bug fix: the original condition `prompt is None or ''` parsed as
    # `(prompt is None) or ''`, so empty-string prompts slipped through
    # without the default opener.
    if not prompt:
        prompt = '[Ops] SISKO: '
    end = "<|endoftext|>"
    prompt = end + prompt
    encoded_prompt = tokenizer(prompt, return_tensors="pt").input_ids
    encoded_prompt = encoded_prompt.to(model.device)
    # min_length == max_length forces exactly `length` new tokens.
    output = model.generate(
        input_ids=encoded_prompt,
        max_length=int(length) + len(encoded_prompt[0]),
        min_length=int(length) + len(encoded_prompt[0]),
        temperature=temp,
        top_p=cfg['top_p'],
        do_sample=True,
        top_k=int(top_k),  # sliders deliver floats; generate() expects an int
        early_stopping=True,
        num_return_sequences=1
    )
    text = tokenizer.batch_decode(
        output, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
    return split_dialogue(text)
title = 'Deep Space Nine Script Generator'
# Build and launch the web UI. The four inputs map positionally onto
# generate(prompt, length, temp, top_k); slider bounds come from config.yaml.
# NOTE(review): gr.inputs / gr.outputs is the legacy (pre-3.x) Gradio
# namespace — upgrading Gradio would require gr.Textbox / gr.Slider
# components and different kwargs (e.g. `value=` instead of `default=`).
iface = gr.Interface(fn=generate,
    inputs=[gr.inputs.Textbox(label="Prompt", placeholder='Enter a prompt to generate a script from'),
            gr.inputs.Slider(minimum=cfg['min_length'], maximum=cfg['max_length'], label="Output Length", default=500, step=1),
            gr.inputs.Slider(minimum=cfg['min_temp'], maximum=cfg['max_temp'], label="Temperature", default=0.7),
            gr.inputs.Slider(minimum=10, maximum=100, label="Top K", default=40, step=1)],
    outputs=[gr.outputs.Textbox(type="auto", label="Generated Script")],
    # Pre-filled example rows: (prompt, length, temperature, top_k).
    examples=[['[Promenade]', 500, 0.9, 40], ['QUARK:', 250, 1.2, 60], ["Commander's log, stardate 46924.5.", 300, 0.7, 60]],
    live=False,  # only run generate() on button press, not on every keystroke
    title=title,
    theme='dark-huggingface',
)
iface.launch()