| """ | |
| Model by @duyphung for @carperai | |
| Dumb Simple Gradio by @jon-tow | |
| """ | |

from string import Template

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("CarperAI/vicuna-13b-fine-tuned-rlhf")
model = AutoModelForCausalLM.from_pretrained(
    "CarperAI/vicuna-13b-fine-tuned-rlhf",
    torch_dtype=torch.bfloat16,
)
model.cuda()

# Budget the context window: prompts are truncated so that the prompt plus up
# to `max_new_tokens` generated tokens fit within the model's position limit.
max_context_length = model.config.max_position_embeddings
max_new_tokens = 256

prompt_template = Template("""\
### Human: $human
### Assistant: $bot\
""")

def bot(history):
    history = history or []
    # Hack to inject the prompt formatting into the history; the reply slot is
    # empty for the turn the model is about to complete. (The loop variable is
    # named `reply` so it does not shadow this function's name.)
    prompt_history = []
    for human, reply in history:
        prompt_history.append(
            prompt_template.substitute(human=human, bot=reply if reply is not None else "")
        )
    prompt = "\n\n".join(prompt_history)
    prompt = prompt.rstrip()
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    # Keep only the most recent context that fits in the maximum context
    # length, leaving room for the new tokens we are about to generate
    inputs = {k: v[:, -max_context_length + max_new_tokens:] for k, v in inputs.items()}
    inputs_length = inputs['input_ids'].shape[1]
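    # For illustration: assuming a LLaMA-style 2048-token context window, the
    # slice above keeps only the last 2048 - 256 = 1792 prompt tokens, so the
    # prompt plus up to `max_new_tokens` generated tokens always fits.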
    # Generate the response
    tokens = model.generate(
        **inputs,
        # Only allow the model to generate up to `max_new_tokens` (256) tokens
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
    )
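    # Note: temperature=1.0 with top_p=1.0 samples from the model's unmodified
    # distribution; lowering either value would make replies more conservative.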
    # Strip the initial prompt from the output
    tokens = tokens[:, inputs_length:]
    # Decode and truncate at the next turn marker, since the model may keep
    # writing a follow-up "### Human:" turn on its own
    response = tokenizer.decode(tokens[0], skip_special_tokens=True)
    response = response.split("###")[0].strip()
    # Fill the response into the pending history entry
    history[-1][1] = response
    return history
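
# A minimal smoke test of `bot` (hypothetical; not part of the app):
#   history = bot([["Hello!", None]])
#   print(history[-1][1])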

def user(user_message, history):
    # Clear the textbox and append the message with an empty reply slot
    return "", history + [[user_message, None]]

with gr.Blocks() as demo:
    gr.Markdown("""Vicuna-13B RLHF Chatbot""")
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=512)
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    state = gr.State([])
    # Chain the events: `user` echoes the message into the chat history and
    # clears the textbox, then `bot` fills in the model's reply
    msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        bot, chatbot, chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch(share=True)