Spaces:
Runtime error
Runtime error
'''
CONFIG AND IMPORTS
'''
from config import default_config
from types import SimpleNamespace
import gradio as gr
import os, random
from pathlib import Path
import tiktoken
from getpass import getpass
from openai import OpenAI
from langchain.text_splitter import MarkdownHeaderTextSplitter
import numpy as np
from langchain_openai import OpenAIEmbeddings
from typing import Iterable
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
import time

# Ensure the API key exists BEFORE constructing any OpenAI clients.
# (Previously the OpenAI client was built from a possibly-empty
# environment variable and only afterwards was the user prompted for
# a key -- the client would then carry an empty key.)
if os.getenv("OPENAI_API_KEY") is None:
    if any(['VSCODE' in x for x in os.environ.keys()]):
        print('Please enter password in the VS Code prompt at the top of your VS Code window!')
    os.environ["OPENAI_API_KEY"] = getpass("Paste your OpenAI key from: https://platform.openai.com/account/api-keys\n")
assert os.getenv("OPENAI_API_KEY", "").startswith("sk-"), "This doesn't look like a valid OpenAI API key"
print("OpenAI API key configured")

# Both clients are built only after the key check above has passed.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", ""))
embeddings_model = OpenAIEmbeddings()
# Build one Markdown document from every .md file in the safety-docs
# folder, then split it on level 1-3 headers for embedding lookup.
md = ""
directory_path = "safety_docs"
# sorted() makes the concatenation order deterministic across platforms
# (os.listdir order is filesystem-dependent).
for filename in sorted(os.listdir(directory_path)):
    if filename.endswith(".md"):
        with open(os.path.join(directory_path, filename), 'r', encoding='utf-8') as file:
            content = file.read()
        # Join with a newline so the last line of one file cannot fuse
        # with the leading '#' header of the next file (the original
        # concatenated the files back-to-back with no separator).
        md = md + content + "\n"
markdown_document = md

headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    ("###", "Header 3"),
]
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
md_header_splits = markdown_splitter.split_text(markdown_document)
# Cache for the corpus embedding matrix: embedding every chunk is one
# paid API round-trip per chunk, so it must not be redone per query.
_embedding_matrix_cache = None

def find_nearest_neighbor(argument="", max_args_in_output=2):
    '''
    Return the "Header 1" titles of the document chunks most similar to
    `argument` under cosine similarity of their embeddings.

    INPUT:
        argument (string): text to classify against the corpus
        max_args_in_output (int): number of nearest titles to return
    RETURN:
        string: one title per line, most similar first
    '''
    global _embedding_matrix_cache
    if _embedding_matrix_cache is None:
        # Computed once on first call (the original re-embedded the
        # whole corpus on every invocation).
        _embedding_matrix_cache = np.array(
            [embeddings_model.embed_query(text.page_content) for text in md_header_splits]
        )
    embedding_matrix = _embedding_matrix_cache
    argument_embedding = np.array(embeddings_model.embed_query(argument))
    # Cosine similarity = dot product / (product of norms).
    dot_products = np.dot(embedding_matrix, argument_embedding)
    norms = np.linalg.norm(embedding_matrix, axis=1) * np.linalg.norm(argument_embedding)
    cosine_similarities = dot_products / norms
    # Indices of the top-k similarities, highest first.
    nearest_indices = np.argsort(cosine_similarities)[-max_args_in_output:][::-1]
    # .get guards against chunks split below a level-1 header, which
    # have no 'Header 1' key in their metadata.
    output = ""
    for index in nearest_indices:
        output = output + md_header_splits[index].metadata.get('Header 1', '') + "\n"
    return output
def get_gpt_response(user_prompt, system_prompt=default_config.system_prompt, model=default_config.model_name, n=1, max_tokens=200):
    '''
    Call the chat-completions API and return one generated reply.

    INPUT:
        user_prompt (string): user-role message content
        system_prompt (string): system-role instructions (from config)
        model (string): chat model name (from config)
        n (int): completions to sample; only one string is ever returned
        max_tokens (int): cap on completion length
    RETURN:
        string: content of the last returned choice
    '''
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        n=n,
        max_tokens=max_tokens,
    )
    # The original looped over response.choices, overwriting the result
    # each iteration -- i.e. it kept only the LAST choice. Made explicit
    # here; with the default n=1 there is exactly one choice anyway.
    return response.choices[-1].message.content
| # return the gpt generated response | |
# Return the GPT-generated rebuttal for the user's argument.
def greet1(argument):
    """Wrap `argument` in the configured prompt template and ask GPT."""
    full_prompt = default_config.user_prompt_1 + argument + default_config.user_prompt_2
    return get_gpt_response(user_prompt=full_prompt)
| # return the nearest neighbor arguments | |
# Return the nearest-neighbor matches from the argument taxonomy.
def greet2(argument):
    """Classify `argument` against the taxonomy of common objections."""
    matches = find_nearest_neighbor(argument)
    header = "Your argument may fall under the common arguments against AI safety. \nIs it one of these? \n"
    footer = "\nSee the taxonomy of arguments below"
    return header + matches + footer
# --- Theming ----------------------------------------------------------
# NOTE(review): the custom Monochrome palette below is built and tuned,
# but `theme` is reassigned to gr.themes.Default(...) at the end of this
# section before gr.Blocks ever reads it -- the custom palette is
# currently inert. Kept as-is to preserve behavior; confirm which theme
# is actually intended.
theme = gr.themes.Monochrome(
    neutral_hue=gr.themes.Color("red", "#636363", "#636363", "lightgrey", "lightgrey", "lightgrey", "lightgrey", "grey", "red", "black", "red"),
    primary_hue=gr.themes.Color("#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010", "#8c0010"),
    secondary_hue=gr.themes.Color("white", "white", "white", "white", "white", "white", "white", "white", "white", "white", "white"),
)
theme = theme.set(
    body_background_fill="black",
    block_title_background_fill="black",
    block_background_fill="black",
    body_text_color="white",
    link_text_color='*primary_50',
    link_text_color_dark='*primary_50',
    link_text_color_active='*primary_50',
    link_text_color_active_dark='*primary_50',
    link_text_color_hover='*primary_50',
    link_text_color_hover_dark='*primary_50',
    link_text_color_visited='*primary_50',
    link_text_color_visited_dark='*primary_50'
)

# NOTE(review): css_string is defined but never passed to gr.Blocks
# (no css= argument below) -- presumably intended for the
# force_black_bg / themed_question_box elem_ids; verify.
css_string = """
@import url('https://fonts.googleapis.com/css2?family=Gabarito&family=Gothic+A1:wght@100;200;300;400;500;600;700;800;900&display=swap');
force_black_bg {
color: white !important;
font-family: 'Gabarito', cursive !important;
}
force_black_bg *{
font-family: 'Gabarito', cursive !important;
}
footer{
display:none !important
}
"""
css_string2 = ""

# This assignment wins: the app launches with the Default theme.
theme = gr.themes.Default(text_size="lg")
| # with gr.Blocks(theme=theme) as demo: | |
| # with gr.Row(elem_id="force_black_bg"): | |
| # with gr.Column(elem_id="force_black_bg"): | |
| # seed = gr.Textbox( label="AI Safety Skepticism: What's Your Take?", placeholder="Enter an argument or something you'd like to say!") | |
| # btn = gr.Button("Generate >") | |
| # with gr.Column(): | |
| # german = gr.Chatbot() | |
| # # german = gr.Text(label="Safetybot Response") | |
| # english = gr.Text(elem_id="themed_question_box", label="Common Argument Classifier") | |
| # btn.click(greet2, inputs=[seed],outputs=english) | |
| # btn.click(greet1, inputs=[seed],outputs=german) | |
| # # def respond(message, chat_history): | |
| # # bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"]) | |
| # # chat_history.append((message, bot_message)) | |
| # # return "", chat_history | |
| # # btn.submit(respond, [seed, chatbot], [msg, chatbot]) | |
| # gr.Examples(["AGI is far away, I'm not worried", "AI is confined to a computer and cannot interact with the physical world", "AI isn't conscious", "If we don't develop AGI, China will!", "If we don't develop AGI, the Americans will!"], inputs=[seed]) | |
# --- UI ---------------------------------------------------------------
# Left column: input textbox + button + canned example arguments.
# Right column: chatbot transcript + nearest-neighbor classifier output.
with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        with gr.Column(elem_id="force_black_bg"):
            msg = gr.Textbox(label="AI Safety Skepticism: What's Your Take?", placeholder="Enter an argument or something you'd like to say!")
            btn = gr.Button("Generate >")
            gr.Examples(["AGI is far away, I'm not worried", "AI is confined to a computer", "AI isn't conscious", "If we don't develop AGI, China will!", "If we don't develop AGI, the Americans will!"], inputs=[msg])
        with gr.Column(elem_id="force_black_bg"):
            chatbot = gr.Chatbot()
            english = gr.Text(elem_id="themed_question_box", label="Common Argument Classifier")

    def respond(message, chat_history):
        """Generate the bot reply, append the exchange, clear the textbox."""
        bot_message = get_gpt_response(user_prompt=message)
        chat_history.append((message, bot_message))
        return "", chat_history

    # Pressing Enter in the textbox and clicking the button both run the
    # chatbot reply AND the nearest-neighbor classifier.
    for trigger in (msg.submit, btn.click):
        trigger(respond, [msg, chatbot], [msg, chatbot])
        trigger(greet2, inputs=[msg], outputs=english)

demo.queue()
demo.launch()