| |
|
|
| import os, requests, datetime |
| import streamlit as st |
| from functools import partial |
| from tempfile import NamedTemporaryFile |
| from typing import List, Callable, Literal, Optional |
| from streamlit.runtime.uploaded_file_manager import UploadedFile |
| from langchain_openai import ChatOpenAI |
| from langchain.schema import HumanMessage, SystemMessage |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_community.utilities import BingSearchAPIWrapper |
| from langchain_community.document_loaders import PyPDFLoader |
| from langchain_community.document_loaders import Docx2txtLoader |
| from langchain_community.document_loaders import TextLoader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.vectorstores import FAISS |
| from langchain_openai import OpenAIEmbeddings |
| from langchain.tools import Tool |
| from langchain.tools.retriever import create_retriever_tool |
| from langchain.agents import create_openai_tools_agent |
| |
| from langchain.agents import AgentExecutor |
| from langchain_community.agent_toolkits.load_tools import load_tools |
| from langchain.pydantic_v1 import BaseModel, Field |
|
|
| from TTS.api import TTS |
| import tempfile |
|
|
| |
@st.cache_resource(show_spinner=False)
def _load_tts_model() -> TTS:
    """
    Load the multilingual XTTS v2 text-to-speech model (CPU only).

    Streamlit re-executes this script from the top on every interaction, so
    building the model at module level would reload it on each rerun.
    ``st.cache_resource`` keeps one shared instance per server process.
    show_spinner=False: this runs at import time, before st.set_page_config(),
    so it must not emit any Streamlit element.
    """
    return TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)


# Module-level handle used by generate_speech(); after the first run this is
# a cache lookup rather than a model load.
tts_model = _load_tts_model()
|
|
|
|
|
|
def initialize_session_state_variables() -> None:
    """
    Seed st.session_state with the app's default values.

    Only keys that are not present yet are written, so values from earlier
    reruns of the same session are preserved and the function can be called
    on every script run.
    """
    defaults = dict(
        ready=False,
        bing_subscription_validity=False,
        model="gpt-4o",
        language="English",
        topic="",
        positive="",
        negative="",
        agent_descriptions={},
        new_debate=True,
        conversations=[],
        conversations4print=[],
        simulator=None,
        tools=[],
        retriever_tool=None,
        vector_store_message="",
        conclusions="",
        comments_key=0,
        specified_topic="",
    )
    missing = {
        state_key: default
        for state_key, default in defaults.items()
        if state_key not in st.session_state
    }
    for state_key, default in missing.items():
        st.session_state[state_key] = default
|
|
|
|
def initialize_api_keys():
    """
    Normalise the API keys in the environment and report missing ones.

    OPENAI_API_KEY and BING_SUBSCRIPTION_KEY are read from the environment
    (e.g. Hugging Face Space secrets). A missing key is written back as an
    empty string. An error (OpenAI) or a warning (Bing) is shown for each
    key that is empty. No network validation happens here.
    """
    for env_name in ("OPENAI_API_KEY", "BING_SUBSCRIPTION_KEY"):
        os.environ[env_name] = os.getenv(env_name, "")

    # NOTE(review): these messages are emitted at import time, before
    # st.set_page_config() is called in multi_agent_debate(); some Streamlit
    # versions require set_page_config to be the first Streamlit command.
    if not os.environ["OPENAI_API_KEY"]:
        st.error("Missing OPENAI_API_KEY. Add it to Hugging Face Space secrets.")
    if not os.environ["BING_SUBSCRIPTION_KEY"]:
        st.warning("Missing BING_SUBSCRIPTION_KEY. Bing search may not work.")


# Executed at import time, i.e. on every run of this script.
initialize_api_keys()
|
|
|
|
def is_openai_api_key_valid(timeout: float = 10.0):
    """
    Check whether the OpenAI API key in the environment is accepted by OpenAI.

    Sends an authenticated GET request to the /v1/models endpoint.

    Args:
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        True if the key is set and the endpoint answers with status 200.
        False if the key is missing or rejected, or if the request fails
        (connection error, timeout, ...).
    """
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if not openai_api_key:
        return False
    headers = {"Authorization": f"Bearer {openai_api_key}"}
    try:
        response = requests.get(
            "https://api.openai.com/v1/models", headers=headers, timeout=timeout
        )
    except requests.RequestException:
        # A network failure means the key could not be confirmed as valid.
        return False
    return response.status_code == 200
|
|
|
|
|
|
def is_bing_subscription_key_valid():
    """
    Check the Bing Search subscription key stored in the environment.

    A one-result query ("Test Query") is run through BingSearchAPIWrapper.
    Any exception while building the wrapper or running the query counts as
    an invalid key.

    Returns:
        True if BING_SUBSCRIPTION_KEY is set and the test query succeeds,
        otherwise False.
    """
    subscription_key = os.environ.get("BING_SUBSCRIPTION_KEY")
    if not subscription_key:
        return False
    try:
        BingSearchAPIWrapper(
            bing_subscription_key=subscription_key,
            bing_search_url="https://api.bing.microsoft.com/v7.0/search",
            k=1,
        ).run("Test Query")
    except Exception:
        return False
    else:
        return True
|
|
|
|
|
|
def check_api_keys() -> None:
    """
    Set st.session_state.ready to False, marking the API keys as not yet
    checked so that they are validated again.
    """
    st.session_state["ready"] = False
|
|
|
|
def append_period(text: str) -> str:
    """
    Return *text* terminated with a period.

    Empty text, and text that already ends in '.' or '?', is returned
    unchanged. Anything else gets a '.' appended.
    """
    needs_period = bool(text) and not text.endswith((".", "?"))
    return f"{text}." if needs_period else text
|
|
|
|
def get_vector_store(uploaded_files: List[UploadedFile]) -> Optional[FAISS]:
    """
    Build a FAISS vector store from uploaded documents.

    Each upload is copied to a temporary file under ``files/`` because the
    LangChain loaders take a filesystem path. Every temporary file is removed
    before returning.

    Args:
        uploaded_files: Files from st.file_uploader (.pdf, .txt or .docx).

    Returns:
        A FAISS store of chunks (1000 characters, 200 overlap) embedded with
        text-embedding-3-large at 1536 dimensions. Returns None if no files
        were given, if a file type is unsupported, or if any step fails; in
        the last two cases an error is shown in the app.
    """
    if not uploaded_files:
        return None

    loader_map = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": Docx2txtLoader,
    }

    # Validate every extension up front, so an unsupported file is rejected
    # before any temporary file is written or any other document is loaded.
    file_exts = [
        os.path.splitext(uploaded_file.name.lower())[1]
        for uploaded_file in uploaded_files
    ]
    for file_ext in file_exts:
        if file_ext not in loader_map:
            st.error(f"Unsupported file type: {file_ext}", icon="🚨")
            return None

    documents = []
    filepaths = []
    vector_store = None
    try:
        os.makedirs("files", exist_ok=True)
        for uploaded_file, file_ext in zip(uploaded_files, file_exts):
            # delete=False: the loader reopens the file by path after this
            # handle is closed. Cleanup happens in the finally clause.
            with NamedTemporaryFile(
                dir="files/", suffix=file_ext, delete=False
            ) as file:
                # Record the path before writing, so a failed write still
                # gets cleaned up.
                filepaths.append(file.name)
                file.write(uploaded_file.getbuffer())
            documents.extend(loader_map[file_ext](file.name).load())

        with st.spinner("Vector DB in preparation..."):
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
            )
            doc = text_splitter.split_documents(documents)
            embeddings = OpenAIEmbeddings(
                model="text-embedding-3-large", dimensions=1536
            )
            vector_store = FAISS.from_documents(doc, embeddings)
    except Exception as e:
        vector_store = None
        st.error(f"An error occurred: {e}", icon="🚨")
    finally:
        for filepath in filepaths:
            if os.path.exists(filepath):
                os.remove(filepath)

    return vector_store
|
|
|
|
def get_retriever() -> None:
    """
    Render the document-upload UI used by the "Retrieval" tool.

    When "Create a vector DB" is clicked, the uploaded files are embedded by
    get_vector_store() and wrapped in a LangChain retriever tool. The tool is
    saved to st.session_state.retriever_tool. A status message is kept in
    st.session_state.vector_store_message and shown next to the button.
    """
    st.write("")
    st.write("##### Document(s) to ask about")
    uploaded_files = st.file_uploader(
        label="Upload an article",
        type=["txt", "pdf", "docx"],
        accept_multiple_files=True,
        label_visibility="collapsed",
    )

    left, right = st.columns(2)
    # Raw string: "\:", "\!" and "\," are LaTeX spacing commands for
    # Streamlit's Markdown, not Python escapes (a plain literal warns).
    if left.button(label=r"$\:\!$Create a vector DB$\,$"):
        if not uploaded_files:
            # get_vector_store() returns None without any message in this case.
            right.warning("Upload at least one document first.")
        else:
            vector_store = get_vector_store(uploaded_files)
            if vector_store is not None:
                retriever = vector_store.as_retriever()
                st.session_state.retriever_tool = create_retriever_tool(
                    retriever,
                    name="retriever",
                    description=(
                        "Search for information about the uploaded documents. "
                        "For any questions about the documents, you must use "
                        "this tool!"
                    ),
                )
                st.session_state.vector_store_message = "Vector DB is ready!"

    if st.session_state.vector_store_message:
        right.write(f":blue[{st.session_state.vector_store_message}]")
|
|
|
|
class DialogueAgent:
    """
    A single debater.

    Holds the debater's persona prompt, its own transcript of the
    conversation, and the LLM (plus optional tools) it answers with.
    """

    def __init__(
        self,
        name: str,
        system_message: SystemMessage,
        llm: ChatOpenAI,
        tools: List[str],
    ) -> None:
        # NOTE(review): despite the List[str] annotation, set_debate() passes
        # LangChain tool objects here; the annotation is kept for compatibility.
        self.name = name
        self.system_message = system_message
        self.llm = llm
        self.tools = tools
        self.prefix = f"{name}: "
        self.reset()

    def reset(self):
        """Start a fresh transcript that holds only the introductory line."""
        self.message_history = ["\nHere is the conversation so far.\n"]

    def send(self) -> str:
        """
        Produce this agent's next utterance.

        The persona prompt, the transcript so far and the speaker prefix are
        joined into one human message. If the agent has tools, an
        OpenAI-tools agent run through AgentExecutor answers it. Otherwise the
        prompt is piped straight into the LLM.
        """
        prompt_input = "\n".join(
            [self.system_message.content, *self.message_history, self.prefix]
        )
        base_messages = [
            ("system", "You are a helpful assistant."),
            ("human", "{input}"),
        ]

        if not self.tools:
            chain = ChatPromptTemplate.from_messages(base_messages) | self.llm
            return chain.invoke({"input": prompt_input}).content

        agent_prompt = ChatPromptTemplate.from_messages(
            base_messages + [MessagesPlaceholder(variable_name="agent_scratchpad")]
        )
        executor = AgentExecutor(
            agent=create_openai_tools_agent(self.llm, self.tools, agent_prompt),
            tools=self.tools,
            verbose=False,
        )
        return executor.invoke({"input": prompt_input})["output"]

    def receive(self, name: str, message: str) -> None:
        """Record in this agent's transcript that *name* said *message*."""
        self.message_history.append(f"{name}: {message}\n")
|
|
|
|
class DialogueSimulator:
    """
    Drives the debate.

    Picks the next speaker, asks it to talk, and broadcasts what it said to
    every agent.
    """

    def __init__(
        self,
        agents: List[DialogueAgent],
        selection_function: Callable[[int, List[DialogueAgent]], int],
    ) -> None:
        self.agents = agents
        self.select_next_speaker = selection_function
        # Number of completed turns; passed to selection_function.
        self._step = 0

    def reset(self):
        """Clear every agent's transcript. The turn counter is not changed."""
        for agent in self.agents:
            agent.reset()

    def inject(self, name: str, message: str):
        """Broadcast *message* from *name* (e.g. the moderator) to all agents."""
        for agent in self.agents:
            agent.receive(name, message)

    def step(self) -> tuple[str, str]:
        """
        Run one turn and return ``(speaker_name, message)``.

        If the speaker's LLM call raises, the error is shown in the app and
        the script run is halted with st.stop().
        """
        speaker = self.agents[self.select_next_speaker(self._step, self.agents)]

        try:
            with st.spinner(f"{speaker.name} is thinking..."):
                message = speaker.send()
        except Exception as e:
            st.error(f"An error occurred: {e}", icon="🚨")
            st.stop()

        # Every agent, including the speaker, records the new utterance.
        self.inject(speaker.name, message)
        self._step += 1
        return speaker.name, message
|
|
|
|
def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int:
    """
    Round-robin speaker selection.

    Returns the index (0 .. len(agents) - 1) of the agent that speaks at
    turn *step*, cycling through *agents* in order.
    """
    return step % len(agents)
| |
|
|
# Display names offered by set_debate() mapped to XTTS v2 language codes.
# XTTS expects codes such as "en" or "zh-cn"; a name like "English" is rejected.
_XTTS_LANGUAGE_CODES = {
    "english": "en",
    "hindi": "hi",
    "spanish": "es",
    "french": "fr",
    "chinese": "zh-cn",
    "korean": "ko",
    "japanese": "ja",
}

# Built-in XTTS v2 voice used when no reference WAV is supplied.
# NOTE(review): requires a TTS release that ships the XTTS v2 built-in
# speakers; confirm the name against tts_model.speakers if voices change.
_XTTS_DEFAULT_SPEAKER = "Ana Florence"


def _to_xtts_language(language):
    """Return the XTTS code for a language display name; codes pass through."""
    if not language:
        return "en"
    return _XTTS_LANGUAGE_CODES.get(str(language).strip().lower(), language)


def _speaker_wav_to_path(speaker_wav):
    """
    Turn a speaker reference into a filesystem path usable by TTS.

    Returns ``(path, is_temporary)``. A Streamlit UploadedFile's ``.name`` is
    the client-side file name, not a path on this machine, so its bytes are
    copied to a temporary .wav file. str / PathLike values are used as paths.
    """
    if not speaker_wav:
        return None, False
    if isinstance(speaker_wav, (str, os.PathLike)):
        return os.fspath(speaker_wav), False
    if hasattr(speaker_wav, "getvalue"):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(speaker_wav.getvalue())
        return tmp.name, True
    return speaker_wav.name, False


def generate_speech(
    text, speaker_wav=None, language="en", default_speaker=_XTTS_DEFAULT_SPEAKER
):
    """
    Synthesize *text* to a WAV file with the XTTS v2 model (``tts_model``).

    Args:
        text (str): Text to synthesize. Blank text is skipped.
        speaker_wav (str | os.PathLike | UploadedFile | None): Reference
            recording to clone. Uploaded files are first copied to a
            temporary path, because TTS needs a real file on disk.
        language (str): Language display name as offered in the app
            (e.g. "English", "Chinese") or an XTTS code (e.g. "en").
        default_speaker (str | None): Built-in XTTS voice used when no
            *speaker_wav* is given. None means no speaker is passed.

    Returns:
        str | None: Path of the generated temporary WAV file, which the
        caller should delete. None if *text* is blank or synthesis failed;
        a failure is also shown in the app.
    """
    if not text or not str(text).strip():
        return None

    audio_path = None
    speaker_path, speaker_is_temp = None, False
    try:
        speaker_path, speaker_is_temp = _speaker_wav_to_path(speaker_wav)

        # Create and close the output file first, so TTS writes to a path
        # that this process does not hold open.
        fd, audio_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)

        tts_kwargs = {
            "text": text,
            "file_path": audio_path,
            "language": _to_xtts_language(language),
        }
        if speaker_path:
            tts_kwargs["speaker_wav"] = speaker_path
        elif default_speaker:
            tts_kwargs["speaker"] = default_speaker
        tts_model.tts_to_file(**tts_kwargs)
        return audio_path
    except Exception as e:
        # Do not leave a half-written output file behind.
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
        st.error(f"Error generating speech: {e}")
        return None
    finally:
        if speaker_is_temp and speaker_path and os.path.exists(speaker_path):
            os.remove(speaker_path)
|
|
|
|
def run_simulator(no_of_rounds: int, simulator: DialogueSimulator) -> None:
    """
    Play *no_of_rounds* rounds of the debate (two turns per round).

    For each turn:
    - append the utterance to st.session_state.conversations (plain text)
      and to st.session_state.conversations4print (Markdown; the speaker
      name is coloured blue/red alternately);
    - write it to the page;
    - voice it with generate_speech(), using the uploaded speaker WAV for
      that side if there is one, and show an audio player and a download
      button.

    NOTE(review): multi_agent_debate() calls st.rerun() right after this
    returns and the audio is not kept in session state, so the players and
    download buttons are only visible while the rounds are being generated.
    """
    positive_speaker_wav = st.session_state.get("positive_speaker_wav", None)
    negative_speaker_wav = st.session_state.get("negative_speaker_wav", None)

    for turn in range(2 * no_of_rounds):
        name, message = simulator.step()
        color = "blue" if turn % 2 == 0 else "red"
        message4print = f"**:{color}[{name}]**: {message}"

        st.session_state.conversations.append(f"{name}: {message}")
        st.session_state.conversations4print.append(message4print)
        st.write(message4print)

        speaker_wav = (
            positive_speaker_wav
            if name == st.session_state.positive
            else negative_speaker_wav
        )
        audio_path = generate_speech(message, speaker_wav, st.session_state.language)
        if not audio_path:
            continue

        # Read the clip once and delete the temporary file: generate_speech()
        # leaves it on disk and nothing else removes it.
        try:
            with open(audio_path, "rb") as audio_file:
                audio_bytes = audio_file.read()
        finally:
            if os.path.exists(audio_path):
                os.remove(audio_path)

        st.audio(audio_bytes, format="audio/wav")
        st.download_button(
            label=f"Download {name}'s Response Audio",
            data=audio_bytes,
            file_name=f"{name}_response.wav",
            mime="audio/wav",
            # A speaker's label/file name repeats every other turn; an explicit
            # per-turn key keeps the widget IDs unique.
            key=f"audio_download_{len(st.session_state.conversations)}",
        )
|
|
|
|
def generate_agent_description(
    name: str,
    conversation_description: str,
    language: Literal['English', 'Korean'],
    word_limit: int
) -> str:
    """
    Ask the LLM for a short, creative persona description of a debater.

    Args:
        name: Debater to describe; the reply addresses them directly.
        conversation_description: Topic/participants preamble for the prompt.
        language: Language of the description. set_debate() passes any of
            its language options, not only the two in the annotation.
        word_limit: Maximum number of words requested.

    Returns:
        The text content of the reply from st.session_state.model, sampled
        at temperature 1.0.
    """
    request_text = "\n".join([
        f"{conversation_description}",
        f"Please reply with a creative description of '{name}', "
        f"in {word_limit} words or less in {language}.",
        f"Speak directly to '{name}'.",
        "Give them a point of view.",
        "Do not add anything else.",
    ])
    messages = [
        SystemMessage(
            content="You can add detail to the description of "
                    "the conversation participant."
        ),
        HumanMessage(content=request_text),
    ]
    llm = ChatOpenAI(model=st.session_state.model, temperature=1.0)
    return llm.invoke(messages).content
|
|
|
|
def generate_system_message(
    name: str,
    conversation_description: str,
    description: str,
    language: Literal['English', 'Korean'],
    word_limit: int
) -> str:
    """
    Build the persona/system prompt for one debater.

    The prompt is a sequence of paragraphs separated by blank lines:
    - the conversation description;
    - the debater's name;
    - the optional persona description (omitted when empty);
    - the debating rules (persuade, use tools, cite honestly, no
      repetition);
    - the word limit and the language.
    """
    sections = [
        f"{conversation_description}",
        f"Your name is '{name}'.",
    ]
    if description:
        sections.append(f"Your description is as follows: {description}")
    sections.extend([
        "Your goal is to persuade your conversation partner "
        "of your point of view.",
        "DO look up information with your tool "
        "to refute your partner's claims.\n"
        "DO cite your sources.",
        "DO NOT fabricate fake citations.\n"
        "DO NOT cite any source that you did not look up.",
        "DO NOT restate something that has been said in the past.\n"
        "Do not add anything else.",
        "Stop speaking the moment "
        "you finish speaking from your perspective.",
        f"Answer in {word_limit} words or less in {language}.",
    ])
    return "\n\n".join(sections)
|
|
|
|
def get_participant_names(topic: str) -> List[str]:
    """
    Ask the LLM to suggest names for the positive and negative sides.

    Args:
        topic: Debate topic; a trailing period is added in the prompt.

    Returns:
        ``[positive_name, negative_name]``. Surrounding whitespace and wrapping
        quote characters are stripped from each reply, because the prompt
        shows its example name in quotes and the names are later used as
        dict keys and speaker prefixes.
    """
    # One client serves both requests.
    name_specifier_llm = ChatOpenAI(
        model=st.session_state.model, temperature=1.0
    )
    # Example name shown in the prompt for each side.
    example_names = {
        "positive": "AI accelerationist",
        "negative": "AI alarmist",
    }
    participant_names = []

    for participant, ex in example_names.items():
        name_specifier_prompt = [
            SystemMessage(content="You are a helpful moderator for a debate."),
            HumanMessage(
                content=(
                    "Here is the topic of conversation: "
                    f"{append_period(topic)}\n"
                    f"For the {participant} perspective on the topic, "
                    "write a name in three words or less. Start the name "
                    "with a capital letter and do not use ':' .\n"
                    "For example, for the topic 'The current impact of "
                    "automation and artificial intelligence on employment', "
                    f"'{ex}' could serve as an appropriate name for "
                    f"the {participant} side.\n"
                    "Use a common noun instead of a proper noun, "
                    "as shown in the example."
                )
            ),
        ]
        raw_name = name_specifier_llm.invoke(name_specifier_prompt).content
        cleaned_name = raw_name.strip().strip("'\"").strip()
        # Fall back to the trimmed reply if stripping quotes left nothing.
        participant_names.append(cleaned_name or raw_name.strip())

    return participant_names
|
|
|
|
def continue_debate() -> None:
    """
    Mark the debate as set up by setting st.session_state.new_debate to False.
    """
    st.session_state["new_debate"] = False
|
|
|
|
def reset_debate() -> None:
    """
    Restore the debate-related session state to a blank setup.

    Used as the on_click callback of the "Reset the debate" button. It clears:
    - topic, names and descriptions;
    - questions and the transcript;
    - the simulator and the tools;
    - the retriever and conclusions.
    It also resets the language to English and returns to setup mode. Keys
    not listed here (e.g. model, word_limit) keep their values.
    """
    blank_state = {
        "topic": "",
        "language": "English",
        "positive": "",
        "negative": "",
        "agent_descriptions": {},
        "specified_topic": "",
        "new_debate": True,
        "conversations": [],
        "conversations4print": [],
        "simulator": None,
        "names": {},
        "tools": [],
        "retriever_tool": None,
        "vector_store_message": "",
        "conclusions": "",
        "comments_key": 0,
    }
    for state_key, blank_value in blank_state.items():
        st.session_state[state_key] = blank_value
|
|
|
|
def set_tools() -> None:
    """
    Render the tool selector and store the agents' tools in session state.

    Offered tools:
    - ArXiv and Wikipedia;
    - Retrieval, over documents uploaded via get_retriever();
    - "Search" (Bing web search, 5 results per query), once the Bing
      Subscription Key has been validated.
    The selected labels go to st.session_state.selected_tools and the tool
    objects to st.session_state.tools.
    """

    class MySearchToolInput(BaseModel):
        query: str = Field(description="search query to look up")

    # bing_subscription_validity starts out False. Check the key once per
    # session, gated by the "ready" flag (which check_api_keys() clears), so
    # that "Search" is actually offered when a valid key is configured.
    if not st.session_state.ready:
        with st.spinner("Checking the Bing Subscription Key..."):
            st.session_state.bing_subscription_validity = (
                is_bing_subscription_key_valid()
            )
        st.session_state.ready = True

    arxiv = load_tools(["arxiv"])[0]
    wikipedia = load_tools(["wikipedia"])[0]

    tool_options = ["ArXiv", "Wikipedia", "Retrieval"]
    tool_dictionary = {"ArXiv": arxiv, "Wikipedia": wikipedia}

    if st.session_state.bing_subscription_validity:
        # Pass the key and endpoint explicitly, as is_bing_subscription_key_valid()
        # does, instead of relying on the wrapper finding them in the environment
        # (BING_SEARCH_URL may not be set).
        search = BingSearchAPIWrapper(
            bing_subscription_key=os.environ.get("BING_SUBSCRIPTION_KEY", ""),
            bing_search_url=os.environ.get(
                "BING_SEARCH_URL", "https://api.bing.microsoft.com/v7.0/search"
            ),
        )
        bing_search = Tool(
            name="bing_search",
            description=(
                "A search engine for comprehensive, accurate, and trusted results. "
                "Useful for when you need to answer questions about current events. "
                "Input should be a search query."
            ),
            func=partial(search.results, num_results=5),
            args_schema=MySearchToolInput,
        )
        tool_options.insert(0, "Search")
        tool_dictionary["Search"] = bing_search

    st.write("**Tools**")
    st.session_state.selected_tools = st.multiselect(
        label="agent tools",
        options=tool_options,
        label_visibility="collapsed",
    )
    if "Search" not in tool_options:
        # Keys come from the environment (see initialize_api_keys); there is
        # no key input in the sidebar.
        st.write(
            "<small>To search the internet, obtain a Bing Subscription "
            "Key [here](https://portal.azure.com/) and add it as the "
            "BING_SUBSCRIPTION_KEY secret of this Space. Once it is "
            "validated, 'Search' will be displayed in the list of "
            "tools.</small>",
            unsafe_allow_html=True,
        )
    if "Retrieval" in st.session_state.selected_tools:
        get_retriever()
        if st.session_state.retriever_tool is not None:
            tool_dictionary["Retrieval"] = st.session_state.retriever_tool
        else:
            # No vector DB yet: drop Retrieval so no missing tool is looked up.
            st.session_state.selected_tools.remove("Retrieval")

    st.session_state.tools = [
        tool_dictionary[key] for key in st.session_state.selected_tools
    ]
|
|
|
|
def set_debate() -> None:
    """
    Render the debate setup form and, on "Prepare the debate", build the
    agents and the simulator.

    The form collects:
    - optional speaker WAVs;
    - the topic, the debaters' language and the model;
    - the agents' tools and the word limits;
    - the debater names and descriptions, and the moderator's questions.
    Values are mirrored into st.session_state so they survive reruns. Names,
    descriptions and questions can be drafted by the LLM through the
    "Suggest ..." buttons.
    """
    st.write("**Upload Speaker WAV Files** (Optional)")
    # Reference recordings that run_simulator() uses when voicing each side.
    st.session_state.positive_speaker_wav = st.file_uploader(
        label="Upload WAV for Positive Debater (Optional)",
        type=["wav"],
        key="positive_speaker",
    )
    st.session_state.negative_speaker_wav = st.file_uploader(
        label="Upload WAV for Negative Debater (Optional)",
        type=["wav"],
        key="negative_speaker",
    )

    st.write("**Topic of the debate**")
    topic = st.text_input(
        label="topic of the debate",
        placeholder="Enter your topic",
        value=st.session_state.topic,
        label_visibility="collapsed",
    )
    st.session_state.topic = topic.strip()

    st.write(
        "**Language** "
        "<small>used by the debaters</small>",
        unsafe_allow_html=True
    )
    language_options = (
        "English", "Hindi", "Spanish", "French", "Chinese", "Korean", "Japanese"
    )
    # Preselect the stored choice: these radios are not rendered during the
    # debate, so returning "Back to the setting" would otherwise reset them.
    st.session_state.language = st.radio(
        label="language",
        options=language_options,
        label_visibility="collapsed",
        index=(
            language_options.index(st.session_state.language)
            if st.session_state.language in language_options
            else 0
        ),
        horizontal=True
    )

    st.write("**Model**")
    model_options = ("gpt-4o-mini", "gpt-4o")
    st.session_state.model = st.radio(
        label="Model",
        options=model_options,
        label_visibility="collapsed",
        horizontal=True,
        index=(
            model_options.index(st.session_state.model)
            if st.session_state.model in model_options
            else 1
        ),
    )

    set_tools()

    left, right = st.columns(2)
    left.write("**Word limit for question suggestions** (≥ 10)")
    description_word_limit = left.number_input(
        label="description_word_limit",
        min_value=10,
        max_value=500,
        value=20,
        step=10,
        label_visibility="collapsed"
    )

    right.write("**Word limit for each debate response** (≥ 50)")
    st.session_state.word_limit = right.number_input(
        label="answer_word_limit",
        min_value=50,
        max_value=2000,
        value=100,
        step=50,
        label_visibility="collapsed"
    )

    # Suggesting names without a topic would only produce arbitrary names.
    if st.button(
        "Suggest names for the debaters", disabled=not st.session_state.topic
    ):
        st.session_state.positive, st.session_state.negative = (
            get_participant_names(topic)
        )

    left, right = st.columns(2)
    left.write("**Name for the positive**")
    positive = left.text_input(
        label="name of the positive",
        value=st.session_state.positive,
        label_visibility="collapsed"
    )
    st.session_state.positive = positive

    right.write("**Name for the negative**")
    negative = right.text_input(
        label="name of the negative",
        value=st.session_state.negative,
        label_visibility="collapsed"
    )
    st.session_state.negative = negative

    # Both debaters currently get the same tool list.
    st.session_state.names = {
        positive: st.session_state.tools,
        negative: st.session_state.tools,
    }
    conversation_description = (
        "Here is the topic of conversation: "
        f"{append_period(topic)}\nThe participants are: "
        f"{' and '.join(st.session_state.names.keys())}."
    )

    agent_descriptions, agent_system_messages = {}, {}

    if positive and negative and positive == negative:
        # Names are dict keys and transcript prefixes; identical names would
        # collapse the two debaters into a single agent.
        st.warning("Please give the two debaters different names.")

    if positive and negative and positive != negative:
        if st.button("Suggest descriptions for the debaters"):
            for name in st.session_state.names.keys():
                st.session_state.agent_descriptions[name] = (
                    generate_agent_description(
                        name,
                        conversation_description,
                        st.session_state.language,
                        description_word_limit
                    )
                )

        for name in st.session_state.names.keys():
            st.write(f"**Description for {name}**")
            agent_descriptions[name] = st.text_area(
                label=f"description for {name}",
                value=st.session_state.agent_descriptions.get(name, ""),
                label_visibility="collapsed"
            )
            st.session_state.agent_descriptions[name] = agent_descriptions[name]

        # Drop descriptions of debaters who were renamed.
        keys_to_delete = [
            key for key in st.session_state.agent_descriptions
            if key not in st.session_state.names
        ]
        for key in keys_to_delete:
            del st.session_state.agent_descriptions[key]

        for name in st.session_state.names.keys():
            agent_system_messages[name] = generate_system_message(
                name,
                conversation_description,
                agent_descriptions[name],
                st.session_state.language,
                st.session_state.word_limit
            )

        if st.button("Suggest questions for the debaters"):
            topic_specifier_prompt = [
                SystemMessage(content="You can make a topic more specific."),
                HumanMessage(
                    content=(
                        "Here is the topic of conversation: "
                        f"{append_period(topic)}\n"
                        "You are the moderator.\n"
                        "Please make the topic more specific.\n"
                        "Please reply with the specified quest in "
                        f"{description_word_limit} words or less in "
                        f"{st.session_state.language}.\n"
                        "Speak directly to the participants: "
                        f"{*st.session_state.names,}.\n"
                        "Do not add anything else."
                    )
                ),
            ]
            topic_specifier_llm = ChatOpenAI(
                model=st.session_state.model, temperature=1.0
            )
            st.session_state.specified_topic = topic_specifier_llm.invoke(
                topic_specifier_prompt
            ).content

        st.write("**Questions for the debaters**")
        specified_topic = st.text_area(
            label="questions for the debaters",
            value=st.session_state.specified_topic,
            label_visibility="collapsed",
        )
        st.session_state.specified_topic = specified_topic

        if st.session_state.specified_topic:
            if st.button("Prepare the debate"):
                agent_llm = ChatOpenAI(
                    model=st.session_state.model, temperature=0.2
                )
                # Look system messages up by name rather than relying on two
                # dicts iterating in the same order.
                agents = [
                    DialogueAgent(
                        name=name,
                        system_message=SystemMessage(
                            content=agent_system_messages[name]
                        ),
                        llm=agent_llm,
                        tools=tools,
                    )
                    for name, tools in st.session_state.names.items()
                ]
                st.session_state.simulator = DialogueSimulator(
                    agents=agents, selection_function=select_next_speaker
                )
                st.session_state.simulator.reset()
                st.session_state.simulator.inject("Moderator", specified_topic)
                st.session_state.new_debate = False
                st.rerun()
|
|
|
|
def print_topic_debaters_questions() -> str:
    """
    Display the debate setup and return it as a plain-text header.

    Shows the topic, both debater names, each debater's description and the
    moderator's questions.

    Returns:
        A header for the downloadable transcript containing the same items.
        Empty descriptions are left out.
    """
    positive = st.session_state.positive
    negative = st.session_state.negative
    agent_descriptions = st.session_state.agent_descriptions

    st.write("**Topic of the debate**")
    st.info(f"**{st.session_state.topic}**")
    # rf-string: "\:" is a LaTeX spacing command for Streamlit's Markdown,
    # not a Python escape (a plain literal warns about an invalid escape).
    st.write(rf"**Name for the positive**$\:$: $~$:blue[{positive}]")
    st.write(f"**Name for the negative**: $~$:blue[{negative}]")
    for name in st.session_state.names:
        st.write(f"**Description for {name}**")
        st.info(agent_descriptions.get(name, ""))

    st.write("**Moderator**: Here are the questions for the debaters")
    st.info(st.session_state.specified_topic)

    headers = (
        f"Topic of the debate: {st.session_state.topic}\n\n"
        f"Name for the positive: {positive}\n"
        f"Name for the negative: {negative}\n\n"
    )
    for side in (positive, negative):
        description = agent_descriptions.get(side, "")
        if description:
            headers += f"Description for {side}:\n{description}\n\n"

    headers += f"Moderator: {st.session_state.specified_topic}\n\n"
    return headers
|
|
|
|
def conclude_debate() -> None:
    """
    Have the moderator LLM summarise and conclude the debate.

    The prompt contains the full transcript from
    st.session_state.complete_conversations. It asks for at most twice the
    per-response word limit, in the debate language. The conclusion is
    stored in st.session_state.conclusions and appended to both transcript
    lists.

    If the LLM call fails, the error is shown and the run is stopped (as
    DialogueSimulator.step does). The debate then stays unconcluded, so the
    user can retry.
    """
    word_limit = 2 * st.session_state.word_limit
    moderator_prompt = [
        SystemMessage(
            content=(
                "You are the Moderator. "
                "Your goal is to provide a comprehensive summary "
                "highlighting the key points raised by each participant, "
                "and then to conclude the debate in a productive manner. "
                "If there is a clear standout in terms of being more "
                "persuasive or convincing, mention this in your conclusion."
            )
        ),
        HumanMessage(
            content=(
                f"Answer in {word_limit} words or less "
                f"in {st.session_state.language}.\n\n"
                "Here is the complete conversation.\n\n"
                f"{st.session_state.complete_conversations}\n\n"
                "Moderator: "
            )
        ),
    ]
    try:
        moderator_llm = ChatOpenAI(
            model=st.session_state.model, temperature=0.2
        )
        with st.spinner("Moderator is thinking..."):
            conclusions = moderator_llm.invoke(moderator_prompt).content
    except Exception as e:
        st.error(f"An error occurred: {e}", icon="🚨")
        st.stop()

    st.session_state.conclusions = conclusions
    st.session_state.conversations.append(
        f"Moderator: {st.session_state.conclusions}"
    )
    st.session_state.conversations4print.append(
        f"**Moderator**: {st.session_state.conclusions}"
    )
|
|
|
|
def multi_agent_debate():
    """
    Streamlit entry point for the multi-lingual multi-agent debate app.

    Two agents, optionally equipped with tools (Bing search, ArXiv,
    Wikipedia, document retrieval), debate a user-chosen topic.

    While st.session_state.new_debate is True, the setup form is shown.
    Afterwards the transcript is displayed with controls to:
    - run more rounds or add moderator comments;
    - conclude, or go back to the setup;
    - download the transcript or reset the debate.
    """
    page_title = "Multi-lingual Multi-Agent Debate"
    page_icon = "📚"

    st.set_page_config(
        page_title=page_title,
        page_icon=page_icon,
        layout="centered"
    )

    # Labels containing "\,", "\:" or "\!" use raw strings: these are LaTeX
    # spacing commands for Streamlit's Markdown, and in a plain literal they
    # are invalid Python escape sequences (SyntaxWarning on Python 3.12+).
    st.write(rf"## {page_icon} $\,${page_title}")

    st.image("./files/image-3.png", caption=" ", use_container_width=True)
    st.info(
        "**Acknowledgment**: This project is inspired by "
        "[Twy's Work](https://github.com/twy80/Multi_Agent_Debate)."
    )

    initialize_session_state_variables()

    with st.sidebar:
        st.write("### API Key Status")
        openai_key_status = (
            "✔️ Available" if os.getenv("OPENAI_API_KEY") else "❌ Missing"
        )
        bing_key_status = (
            "✔️ Available" if os.getenv("BING_SUBSCRIPTION_KEY") else "❌ Missing"
        )
        st.write(f"**OpenAI Key**: {openai_key_status}")
        st.write(f"**Bing Key**: {bing_key_status}")

        if not os.getenv("OPENAI_API_KEY"):
            st.error("OpenAI Key is required for this app to function properly.")

    if st.session_state.new_debate:
        set_debate()
        return

    # ---- Debate view -----------------------------------------------------
    selected_tools = st.session_state.get("selected_tools", [])
    with st.sidebar:
        st.write("")
        st.write(f"**Model**: :blue[{st.session_state.model}]")
        st.write(f"**Language**: :blue[{st.session_state.language}]")
        st.write(f"**Word limit**: :blue[{st.session_state.word_limit}]")
        if selected_tools:
            tool_heading = "Tool" if len(selected_tools) == 1 else "Tools"
            st.write(f"**{tool_heading}**: :blue[{', '.join(selected_tools)}]")
        else:
            st.write("**Tool**: :blue[None]")

    headers = print_topic_debaters_questions()
    st.session_state.complete_conversations = (
        headers + "\n\n".join(st.session_state.conversations)
    )

    if st.session_state.conversations:
        label_debate = r"$\,$Continue the debate$\,$"
        label_no_of_rounds = "Number of additional rounds"
        value_no_of_rounds = 1
    else:
        label_debate = r"$~~~\,$Start the debate$~~~\,$"
        label_no_of_rounds = "Number of rounds in this debate"
        value_no_of_rounds = 5

    st.write("")
    for message in st.session_state.conversations4print:
        st.write(message)

    if not st.session_state.conclusions:
        st.write(f"**{label_no_of_rounds}**")
        c1, _, _ = st.columns(3)
        no_of_rounds = c1.number_input(
            label=f"{label_no_of_rounds}",
            min_value=1,
            max_value=10,
            value=value_no_of_rounds,
            step=1,
            label_visibility="collapsed",
        )

    if st.session_state.conversations and not st.session_state.conclusions:
        st.write("**Facilitative comments by the (human) moderator** (Optional)")
        facilitative_comments = st.text_input(
            label="facilitative_comments",
            value="",
            # A fresh key after each comment clears the input on the next run.
            key="comments" + str(st.session_state.comments_key),
            label_visibility="collapsed",
        )
        if facilitative_comments:
            st.session_state.simulator.inject("Moderator", facilitative_comments)
            st.session_state.conversations.append(
                f"Moderator: {facilitative_comments}"
            )
            st.session_state.conversations4print.append(
                f"**Moderator**: {facilitative_comments}"
            )
            st.session_state.comments_key += 1

    left, right = st.columns(2)

    if not st.session_state.conclusions:
        if left.button(label_debate):
            run_simulator(no_of_rounds, st.session_state.simulator)
            st.rerun()
        if st.session_state.conversations and right.button(
            r"Conclude the debate$\,$"
        ):
            conclude_debate()
            st.rerun()
    elif right.button(r"$~\:$Back to the setting$~\:$"):
        st.session_state.new_debate = True
        st.rerun()

    left.download_button(
        label="Download the debate",
        data=st.session_state.complete_conversations,
        file_name="multi_agent_debate.txt",
        mime="text/plain"
    )
    right.button(
        label=r"$~~\,\:\!$Reset the debate$~~\,\,$",
        on_click=reset_debate
    )


if __name__ == "__main__":
    multi_agent_debate()
|
|