Spaces:
Runtime error
Runtime error
| import spaces | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| import torch | |
| import gradio as gr | |
| import logging | |
| from huggingface_hub import login | |
| import os | |
| import traceback | |
| from threading import Thread | |
| from random import shuffle, choice | |
| import json | |
| import gspread | |
| from google.oauth2.service_account import Credentials | |
# Module-level setup: logging, constants, Hugging Face login, model selection
# and model loading. Everything below runs once at import time.
logging.basicConfig(level=logging.DEBUG)

# Visual separator used in log messages throughout this file.
SPACER = '\n' + '*' * 40 + '\n'
SCOPES = ['https://www.googleapis.com/auth/spreadsheets', 'https://www.googleapis.com/auth/drive']  # spread scopes

HF_TOKEN = os.environ.get("HF_TOKEN", None)
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    # Without a token there is nothing to log in with; calling login(token=None)
    # would not be a usable non-interactive login, so skip it and say so.
    logging.warning(f'{SPACER} HF_TOKEN is not set; skipping Hugging Face login.')

# Default system prompt per language offered in the UI dropdown.
system_prompts = {
    "English": "You are a helpful chatbot that answers user input in a concise and witty way.",
    "German": "Du bist ein hilfreicher Chatbot, der Usereingaben knapp und originell beantwortet.",
    "French": "Tu es un chatbot utile qui répond aux questions des utilisateurs de manière concise et originale.",
    "Spanish": "Eres un chatbot servicial que responde a las entradas de los usuarios de forma concisa y original."
}

# Intro banner rendered at the top of the page.
htmL_info = "<center><h1>⚔️ Pharia Bot Battle</h1><p><big>Let the games begin: In this arena, the <a href='https://huggingface.co/Aleph-Alpha/Pharia-1-LLM-7B-control-hf'>Pharia 1 model</a> competes against secret challengers of comparable size.</p><ul><li>Try a prompt in a language you want to explore</li><li>Set the parameters and vote for the best answers</li><li>After casting your vote, both bots reveal their identity</li><p>Please note that inputs, outputs and votes are logged anonymously. Feel free to use the bot if you’re cool with that!</p></big></center>"

# The Pharia model always takes part; one challenger is drawn at random.
model_info = [{"id": "Aleph-Alpha/Pharia-1-LLM-7B-control-hf",
               "name": "Pharia 1 LLM 7B control hf"}]
challenger_models = [{"id": "NousResearch/Meta-Llama-3.1-8B-Instruct",
                      "name": "Meta Llama 3.1 8B Instruct"},
                     {"id": "mistralai/Mistral-7B-Instruct-v0.3",
                      "name": "Mistral 7B Instruct v0.3"}]
challenger_model = choice(challenger_models)
model_info.append(challenger_model)
# Shuffle so that neither model is tied to slot A or slot B.
shuffle(model_info)
chatbot_a_name = model_info[0]['name']
chatbot_b_name = model_info[1]['name']

device = "cuda"

# NOTE: a loading failure is only logged. In that case the tokenizer_*/model_*
# names stay undefined and the first generation request will fail instead.
try:
    tokenizer_a = AutoTokenizer.from_pretrained(model_info[0]['id'])
    model_a = AutoModelForCausalLM.from_pretrained(
        model_info[0]['id'],
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer_b = AutoTokenizer.from_pretrained(model_info[1]['id'])
    model_b = AutoModelForCausalLM.from_pretrained(
        model_info[1]['id'],
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
except Exception as e:
    logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
def get_google_credentials():
    """Build service-account credentials for the remote Google Sheet.

    Every field of the service-account info is read from a ``GOOGLE_*``
    environment variable. Escaped ``\\n`` sequences in the private key are
    turned into real newlines before the key is used.

    Returns:
        The result of ``Credentials.from_service_account_info`` for ``SCOPES``.

    Raises:
        RuntimeError: If ``GOOGLE_PRIVATE_KEY`` is missing or empty.
    """
    private_key = os.environ.get("GOOGLE_PRIVATE_KEY")
    if not private_key:
        # Fail with a message that names the variable instead of an
        # AttributeError raised by calling .replace() on None.
        raise RuntimeError(
            "Environment variable GOOGLE_PRIVATE_KEY is not set; "
            "cannot build Google service account credentials."
        )

    service_account_info = {
        "type": "service_account",
        "project_id": os.environ.get("GOOGLE_PROJECT_ID"),
        "private_key_id": os.environ.get("GOOGLE_PRIVATE_KEY_ID"),
        "private_key": private_key.replace('\\n', '\n'),
        "client_email": os.environ.get("GOOGLE_CLIENT_EMAIL"),
        "client_id": os.environ.get("GOOGLE_CLIENT_ID"),
        "auth_uri": os.environ.get("GOOGLE_AUTH_URI"),
        "token_uri": os.environ.get("GOOGLE_TOKEN_URI"),
        "auth_provider_x509_cert_url": os.environ.get("GOOGLE_AUTH_PROVIDER_CERT_URL"),
        "client_x509_cert_url": os.environ.get("GOOGLE_CLIENT_CERT_URL"),
    }
    credentials = Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
    return credentials
def get_google_sheet():
    """Authorize with gspread and return the first worksheet of the log spreadsheet."""
    client = gspread.authorize(get_google_credentials())
    # The spreadsheet is looked up by its title.
    spreadsheet = client.open("pharia_bot_battle_logs")
    return spreadsheet.sheet1
def apply_pharia_template(messages, add_generation_prompt=False):
    """Render a list of chat messages into a single Pharia prompt string.

    The Pharia model configs do not define a chat template, so the prompt is
    assembled by hand here.

    Args:
        messages: Iterable of dicts with a ``"role"`` and a ``"content"`` key.
            Roles other than system/user/assistant get no header.
        add_generation_prompt: When true, an assistant header is appended so
            the model continues with the assistant's reply.

    Returns:
        The formatted prompt string.
    """
    headers = {
        "system": "<|start_header_id|>system<|end_header_id|>\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n",
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n",
    }
    pieces = ["""<|begin_of_text|>"""]
    for entry in messages:
        role = entry["role"]
        content = entry["content"]
        pieces.append(headers.get(role, "") + content + "<|eot_id|>\n")
    if add_generation_prompt:
        pieces.append("<|start_header_id|>assistant<|end_header_id|>\n")
    return "".join(pieces)
def _history_to_messages(history):
    """Flatten ``[user, assistant]`` history pairs into role/content message dicts."""
    messages = []
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    return messages


def _encode_prompt(tokenizer, model_id, messages):
    """Tokenize ``messages`` for one bot and return the input ids on ``device``."""
    if "Pharia" in model_id:
        # Pharia ships without a chat template, so the prompt is built by hand.
        prompt = apply_pharia_template(messages=messages, add_generation_prompt=True)
        input_ids = tokenizer(prompt, return_tensors="pt").to(device).input_ids
        tokenizer.eos_token = "<|endoftext|>"  # not set for Pharia
        tokenizer.pad_token = "<|padding|>"  # not set for Pharia
        return input_ids
    # NOTE(review): `dtype` is kept from the original call; it does not look
    # like a tokenization option (token ids are integers) - confirm and drop.
    return tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        dtype=torch.float16,
        return_tensors="pt"
    ).to(device)


def _run_generation(model, streamer, generation_kwargs):
    """Thread target: run ``model.generate`` and end the stream if it fails."""
    try:
        model.generate(**generation_kwargs)
    except Exception as e:
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
        # Without this the consumer would wait on the streamer forever.
        streamer.end()


def _next_chunk(streamer, eos_token):
    """Fetch the next streamed text chunk.

    Returns:
        ``(text, finished)``. ``text`` is ``None`` when the stream produced
        nothing new (it ended or raised); text after an EOS marker is cut off.
    """
    try:
        text = next(streamer)
    except StopIteration:
        return None, True
    except Exception as e:
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
        # Treat a broken stream as finished so the caller's loop terminates.
        return None, True
    if eos_token and eos_token in text:
        return text[:text.find(eos_token)], True
    return text, False


def generate_both(system_prompt, input_text,
                  chatbot_a, chatbot_b,
                  max_new_tokens=2048, temperature=0.2,
                  top_p=0.9, repetition_penalty=1.1):
    """Generate answers from both bots in parallel and stream them to the UI.

    Each bot is prompted with the optional system prompt, its own chat
    history and the new user input. Generation runs in one thread per model,
    and the partial answers are appended to the histories as they arrive.
    When both streams are done, the exchange is logged to the Google Sheet;
    a logging failure is only written to the log.

    Args:
        system_prompt: System message; skipped when empty.
        input_text: The new user prompt.
        chatbot_a: History of bot A as a list of ``[user, assistant]`` pairs;
            mutated in place.
        chatbot_b: History of bot B, same layout; mutated in place.
        max_new_tokens: Upper bound on generated tokens per bot.
        temperature: Sampling temperature; values <= 0 switch to greedy
            decoding.
        top_p: Nucleus sampling threshold (used only when sampling).
        repetition_penalty: Repetition penalty passed to ``generate``.

    Yields:
        ``(chatbot_a, chatbot_b)`` after every streamed chunk and once more
        when both streams have finished.

    Raises:
        gr.Error: If the prompts cannot be prepared or generation cannot be
            started.
    """
    try:
        text_streamer_a = TextIteratorStreamer(tokenizer_a, skip_prompt=True)
        text_streamer_b = TextIteratorStreamer(tokenizer_b, skip_prompt=True)

        system_prompt_list = [{"role": "system", "content": system_prompt}] if system_prompt else []
        input_text_list = [{"role": "user", "content": input_text}]
        new_messages_a = system_prompt_list + _history_to_messages(chatbot_a) + input_text_list
        new_messages_b = system_prompt_list + _history_to_messages(chatbot_b) + input_text_list
        logging.debug(f'{SPACER}\nNew message bot A: \n{new_messages_a}\n{SPACER}')
        logging.debug(f'{SPACER}\nNew message bot B: \n{new_messages_b}\n{SPACER}')

        # Each bot is encoded from its OWN conversation (bot B previously got
        # bot A's messages in the Pharia branch).
        input_ids_a = _encode_prompt(tokenizer_a, model_info[0]['id'], new_messages_a)
        input_ids_b = _encode_prompt(tokenizer_b, model_info[1]['id'], new_messages_b)

        # UI sliders may deliver ints where floats are expected (and vice versa).
        max_new_tokens = int(max_new_tokens)
        temperature = float(temperature)
        top_p = float(top_p)
        repetition_penalty = float(repetition_penalty)

        shared_kwargs = dict(
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
        )
        if temperature > 0:
            shared_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            # Sampling needs a strictly positive temperature; 0 means greedy.
            shared_kwargs.update(do_sample=False)

        generation_kwargs_a = dict(
            input_ids=input_ids_a,
            streamer=text_streamer_a,
            pad_token_id=tokenizer_a.eos_token_id,
            **shared_kwargs,
        )
        generation_kwargs_b = dict(
            input_ids=input_ids_b,
            streamer=text_streamer_b,
            pad_token_id=tokenizer_b.eos_token_id,
            **shared_kwargs,
        )

        thread_a = Thread(target=_run_generation, args=(model_a, text_streamer_a, generation_kwargs_a))
        thread_b = Thread(target=_run_generation, args=(model_b, text_streamer_b, generation_kwargs_b))
        thread_a.start()
        thread_b.start()

        chatbot_a.append([input_text, ""])
        chatbot_b.append([input_text, ""])
    except Exception as e:
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
        # Nothing can be streamed if setup failed; report it instead of
        # running into unbound loop variables below.
        raise gr.Error("Generation could not be started. See the logs for details.") from e

    finished_a = False
    finished_b = False
    # Alternate between both streams so the two answers grow side by side.
    while not (finished_a and finished_b):
        if not finished_a:
            text_a, finished_a = _next_chunk(text_streamer_a, tokenizer_a.eos_token)
            if text_a is not None:
                chatbot_a[-1][-1] += text_a
                yield chatbot_a, chatbot_b
        if not finished_b:
            text_b, finished_b = _next_chunk(text_streamer_b, tokenizer_b.eos_token)
            if text_b is not None:
                chatbot_b[-1][-1] += text_b
                yield chatbot_a, chatbot_b

    # Make sure the UI receives the final state even if no chunk arrived.
    yield chatbot_a, chatbot_b

    try:
        # The last history entry holds the prompt and the finished answer.
        sheet_row = [system_prompt, input_text, max_new_tokens, temperature, top_p, repetition_penalty, chatbot_a_name, chatbot_a[-1][1], chatbot_b_name, chatbot_b[-1][1], "None", "None"]
        logging.debug(f'{SPACER}\nOutput row: {sheet_row}')
        sheet = get_google_sheet()
        sheet.append_row(sheet_row, table_range="A1:L1")
    except Exception as e:
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
def clear():
    """Return two new empty histories, used to reset both chat panels."""
    history_a, history_b = [], []
    return history_a, history_b
def handle_vote(selection, chatbot_a, chatbot_b):
    """Reveal both bots' identities according to the vote and log the result.

    A reply revealing each bot's name is appended to its history, and one row
    with the vote is appended to the Google Sheet. A failure while logging is
    only written to the log.

    Args:
        selection: Label chosen in the vote radio; any value other than the
            two "winner" labels counts as a draw.
        chatbot_a: History of bot A; mutated in place.
        chatbot_b: History of bot B; mutated in place.

    Returns:
        The updated ``(chatbot_a, chatbot_b)`` histories.
    """
    # selection -> (reply for A, reply for B, vote for A, vote for B)
    outcomes = {
        "Bot A kicks ass!": (
            ["🏆", f"Thanks, man. I am {chatbot_a_name}"],
            ["💩", f"Pffff … I am {chatbot_b_name}"],
            "Winner",
            "Looser",
        ),
        "Bot B crushes it!": (
            ["🤡", f"Rigged … I am {chatbot_a_name}"],
            ["🥇", f"Well deserved! I am {chatbot_b_name}"],
            "Looser",
            "Winner",
        ),
    }
    draw = (
        ["🤝", f"Lame … I am {chatbot_a_name}"],
        ["🤝", f"Dunno. I am {chatbot_b_name}"],
        "Draw",
        "Draw",
    )
    reply_a, reply_b, chatbot_a_vote, chatbot_b_vote = outcomes.get(selection, draw)
    chatbot_a.append(reply_a)
    chatbot_b.append(reply_b)

    try:
        # Vote rows carry placeholders in the prompt/answer/parameter columns.
        sheet_row = ["None", "None", 0, 0, 0, 0, chatbot_a_name, "None", chatbot_b_name, "None", chatbot_a_vote, chatbot_b_vote]
        logging.debug(f'{SPACER}\nOutput row: {sheet_row}')
        get_google_sheet().append_row(sheet_row, table_range="A1:L1")
    except Exception as e:
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
    return chatbot_a, chatbot_b
# UI definition. NOTE(review): the component nesting below was reconstructed
# from a copy of the file that had lost its indentation - confirm the layout.
with gr.Blocks() as demo:
    try:
        with gr.Column():
            gr.HTML(htmL_info)

            # --- Parameters -------------------------------------------------
            gr.HTML("<h2>Set Parameters</h2>")
            with gr.Row(variant="compact"):
                with gr.Column(scale=0):
                    language_dropdown = gr.Dropdown(
                        choices=["English", "German", "French", "Spanish"],
                        label="Select Language for System Prompt",
                        value="English",
                    )
                with gr.Column():
                    system_prompt = gr.Textbox(
                        lines=1,
                        label="System Prompt",
                        value=system_prompts["English"],
                        show_copy_button=True,
                    )
            with gr.Row(variant="compact"):
                with gr.Column(scale=1):
                    submit_btn = gr.Button(value="Generate", variant="primary")
                    clear_btn = gr.Button(value="Clear", variant="secondary")
                input_text = gr.Textbox(
                    lines=1,
                    label="Prompt",
                    value="Write a Nike style ad headline about the shame of being second best.",
                    scale=3,
                    show_copy_button=True,
                )
            with gr.Accordion(label="Generation Configurations", open=False):
                max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=512, label="Max new tokens", step=128)
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature", step=0.01)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.97, label="Top_p", step=0.01)
                repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)

            # --- Outputs ----------------------------------------------------
            gr.HTML("<h2>Check outputs</h2>")
            with gr.Row(variant="panel"):
                with gr.Column():
                    chatbot_a = gr.Chatbot(label="Model A", show_copy_button=True, height=500)
                with gr.Column():
                    chatbot_b = gr.Chatbot(label="Model B", show_copy_button=True, height=500)

            # --- Voting -----------------------------------------------------
            gr.HTML("<h2>Vote!</h2>")
            with gr.Row(variant="panel"):
                better_bot = gr.Radio(
                    ["Bot A kicks ass!", "Bot B crushes it!", "It's a draw."],
                    label="Rate the output!",
                )

        def _prompt_for_language(lang):
            """Return the default system prompt for the selected language."""
            return system_prompts[lang]

        # Pressing Enter in the prompt box and clicking "Generate" trigger
        # the same generation with the same components.
        battle_inputs = [
            system_prompt, input_text, chatbot_a, chatbot_b,
            max_new_tokens, temperature, top_p, repetition_penalty,
        ]
        battle_outputs = [chatbot_a, chatbot_b]

        language_dropdown.change(
            _prompt_for_language,
            inputs=[language_dropdown],
            outputs=[system_prompt],
        )
        better_bot.select(handle_vote, inputs=[better_bot, chatbot_a, chatbot_b], outputs=battle_outputs)
        input_text.submit(generate_both, inputs=battle_inputs, outputs=battle_outputs)
        submit_btn.click(generate_both, inputs=battle_inputs, outputs=battle_outputs)
        clear_btn.click(clear, outputs=battle_outputs)
    except Exception as e:
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')

if __name__ == "__main__":
    demo.queue().launch()