import gradio as gr
from modeling import global_config, ToyTransformer, AttentionBackend
import torch
from tokenizers import TRIETokenizer
from threading import Thread
import bisect

# use the GPU when one is available, otherwise fall back to CPU
if torch.cuda.is_available():
    g_device = torch.device('cuda')
else:
    g_device = torch.device('cpu')

global_config['attn_backend'] = AttentionBackend.Naive

g_SEQ_LEN = 1024
g_HIDDEN_SIZE = 768
g_NUM_HEADS = 12
g_NUM_LAYERS = 12
g_DTYPE = torch.float32

g_tokenizer = TRIETokenizer('llama_vocab_pruned_32k.json')
g_model = ToyTransformer(g_tokenizer.get_vocab_size(), g_NUM_LAYERS, g_NUM_HEADS, g_HIDDEN_SIZE, g_SEQ_LEN, g_device, g_DTYPE)
g_model.load_state_dict(torch.load('model.pt', map_location='cpu'))
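
# Optional sanity-check sketch (assumption: ToyTransformer.forward returns
# (logits, kv_cache) with logits shaped [batch, seq, vocab], as relied on in
# generate() below; adjust if the local modeling module differs):
#
#   with torch.no_grad():
#       _logits, _ = g_model.forward(torch.tensor([g_tokenizer.encode('<s>Hi')]).to(g_device), kv_cache=None)
#       assert _logits.shape[-1] == g_tokenizer.get_vocab_size()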


def generate(model, tokenizer, prompt, temperature, top_p, rep_penalty,
             max_new_tokens=20, total_tokens=None,
             end_tokens=None,
             enable_kv_cache=True):
    """Sample tokens autoregressively, yielding the newly generated ids one at a time."""
    model.eval()
    # accept either a raw string or an already-tokenized list of ids
    feed_tokens = tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))
    with torch.no_grad():
        kv_cache = None
        for _ in range(max_new_tokens):
            # with the KV cache only the new tokens are fed; otherwise re-feed the full sequence
            logits, kv_cache = model.forward(
                torch.tensor([feed_tokens if enable_kv_cache else all_tokens]).to(model.device),
                kv_cache=kv_cache)
            logits = logits[0][-1].cpu()
            if not enable_kv_cache:
                kv_cache = None
            # apply repetition penalty: push down the logits of tokens already seen
            logits_rep = torch.gather(logits, 0, torch.tensor(all_tokens))
            logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
            logits.scatter_(0, torch.tensor(all_tokens), logits_rep)
            # apply temperature
            logits /= max(temperature, 1e-6)
            probs = torch.softmax(logits, dim=0)
            # apply top-p: keep the smallest prefix of the sorted tokens whose cumulative mass covers top_p
            ordered_probs, ordered_indices = torch.sort(probs, descending=True)
            cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
            top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
            ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
            sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()
            all_tokens.append(sampled_index)
            feed_tokens = [sampled_index]
            # stop (without yielding) once an end-of-sequence token is sampled
            if end_tokens is not None and sampled_index in end_tokens:
                break
            yield feed_tokens
    return
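
# Usage sketch (illustrative only, not part of the app): stream tokens from
# `generate` directly, assuming `model.pt` and the pruned vocabulary above are
# present. Only names defined in this file are used.
#
#   prompt = '<s>A chat between User and Assistant.\nUser:Hello\nAssistant:'
#   for new_tokens in generate(g_model, g_tokenizer, prompt,
#                              temperature=0.5, top_p=0.7, rep_penalty=1.1,
#                              max_new_tokens=50,
#                              end_tokens=g_tokenizer.encode('</s>')):
#       print(g_tokenizer.decode(new_tokens), end='', flush=True)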


def predict(user_input, history, max_length, top_p, temperature, rep_penalty, retry):
    # on retry, drop the last assistant reply and regenerate it for the same user input
    if retry and len(history) == 0:
        yield []
        return
    elif retry:
        user_input = history[-1][0]
        history = history[:-1]
    history.append((user_input, ""))
    encoded_inputs = [(g_tokenizer.encode('User:' + h[0]), g_tokenizer.encode('Assistant:' + h[1])) for h in history]
    # take as many of the most recent rounds as fit into the context window,
    # reserving max_length tokens for the new response
    taken_rounds, taken_rounds_length = [], 0
    while len(taken_rounds) < len(encoded_inputs):
        round_pair = encoded_inputs[len(encoded_inputs) - 1 - len(taken_rounds)]
        if len(round_pair[0]) + len(round_pair[1]) + taken_rounds_length >= g_SEQ_LEN - max_length:
            break
        taken_rounds.append(round_pair)
        taken_rounds_length += len(round_pair[0]) + len(round_pair[1])
    taken_rounds = taken_rounds[::-1]
    input_tokens = g_tokenizer.encode('<s>A chat between User and Assistant.')
    for round_pair in taken_rounds:
        input_tokens += g_tokenizer.encode('\n') + round_pair[0] + g_tokenizer.encode('\n') + round_pair[1]
    # print(taken_rounds, g_tokenizer.decode(input_tokens))
    # stream the decoded tokens into the last history entry as they are generated
    for response in generate(g_model, g_tokenizer, input_tokens, temperature, top_p, rep_penalty, max_length, end_tokens=g_tokenizer.encode('</s>')):
        history[-1] = (history[-1][0], history[-1][1] + g_tokenizer.decode(response))
        yield history
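
# For reference, with one completed round plus the current turn, the prompt fed
# to `generate` above decodes to something like this (illustrative example):
#
#   <s>A chat between User and Assistant.
#   User:Hi
#   Assistant:Hello!
#   User:What can you do?
#   Assistant: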


def main():
    css = '''
    .contain {max-width:50}
    #chatbot {min-height:500px}
    '''
    with gr.Blocks(css=css) as demo:
        gr.HTML('<h1 align="center">ToyTransformer</h1><h5 align="center">(Note: Please refresh if the page is not responsive.)</h5>')
        chatbot = gr.Chatbot(elem_id='chatbot')
        with gr.Column():
            user_input = gr.Textbox(show_label=False, placeholder="Input", lines=1, container=False)
            with gr.Row():
                submitBtn = gr.Button("Send", variant="primary")
                retryBtn = gr.Button("Retry")
                cancelBtn = gr.Button('Undo')
                emptyBtn = gr.Button("Clear")
            with gr.Row():
                max_length = gr.Slider(0, 512, value=200, step=1, label="Max Response Tokens", interactive=True)
                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
                temperature = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
                rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label='Repetition Penalty', interactive=True)
        # wire the buttons to their handlers; gr.State carries the retry flag
        submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(False)],
                        [chatbot], show_progress=False)
        submitBtn.click(lambda: '', [], [user_input], show_progress=False)
        retryBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(True)],
                       [chatbot], show_progress=False)
        cancelBtn.click(lambda m: m[:-1], [chatbot], [chatbot], show_progress=False)
        emptyBtn.click(lambda: [], outputs=[chatbot], show_progress=False)
    demo.queue().launch(share=False, inbrowser=True)


main()