|
|
|
|
|
import streamlit as st |
|
|
import config |
|
|
from openai import OpenAI |
|
|
from typing import Dict, Any, List |
|
|
|
|
|
|
|
|
def build_complete_system_prompt(selected_language: str, prompt_body: str) -> str:
    """
    Combine the language instruction and the editable prompt body into one system prompt.

    Args:
        selected_language (str): The language chosen from the dropdown.
        prompt_body (str): The editable system prompt body.

    Returns:
        str: The complete system prompt, language instruction first.
    """
    # The language directive is prepended so it always takes precedence
    # over anything written in the editable body.
    parts = [
        f"All your responses should be in {selected_language}.\n",
        prompt_body,
    ]
    return "".join(parts)
|
|
|
|
|
|
|
|
def reset_chat_history(initial_context: Dict[str, str]) -> None:
    """
    Discard the chat transcript, keeping only the initial system context, then rerun the app.

    Args:
        initial_context (Dict[str, str]): The initial chat context (system message).
    """
    fresh_history: List[Dict[str, str]] = [initial_context]
    st.session_state["display_messages"] = fresh_history
    # Rerun immediately so the cleared transcript is rendered on this interaction.
    st.rerun()
|
|
|
|
|
|
|
|
def stream_assistant_response(client: OpenAI, model: str, messages: List[Dict[str, str]],
                              temperature: float, max_tokens: int,
                              frequency_penalty: float, presence_penalty: float) -> str:
    """
    Stream the assistant's response from the OpenAI API into a placeholder and return it.

    Args:
        client (OpenAI): The initialized OpenAI client.
        model (str): The model to use.
        messages (List[Dict[str, str]]): The list of chat messages.
        temperature (float): The temperature setting.
        max_tokens (int): The maximum number of tokens.
        frequency_penalty (float): The frequency penalty.
        presence_penalty (float): The presence penalty.

    Returns:
        str: The complete assistant response (partial if an error interrupted the stream).
    """
    placeholder = st.empty()
    accumulated = ""
    try:
        response_stream = client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            temperature=temperature,
            max_tokens=max_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        for part in response_stream:
            delta_text = part.choices[0].delta.content
            if delta_text is not None:
                accumulated += delta_text
                # Trailing block cursor gives a "typing" effect while streaming.
                placeholder.markdown(accumulated + "▌")
        # Final render without the cursor once the stream completes.
        placeholder.markdown(accumulated)
    except Exception as exc:
        # Surface both the traceback and a short error banner in the UI.
        st.exception(exc)
        st.error(f"An error occurred: {str(exc)}")
    return accumulated
|
|
|
|
|
|
|
|
def main() -> None:
    """Main function to run the Streamlit Chatbot Prompt and Parameter Tester application.

    Renders the sidebar configuration (API key, model, sampling parameters,
    preferred language, and editable system prompt), displays the chat
    transcript, and streams new assistant responses from the OpenAI API.
    """
    st.set_page_config(
        layout="wide",
        page_title="OpenAI Chatbot Tester",
        page_icon=":lightbulb:",
        initial_sidebar_state="expanded"
    )
    st.title("OpenAI Chatbot Tester")

    with st.sidebar:
        st.header("Configuration")

        # The key is stored in session state so it survives reruns; it is
        # excluded from the debug dump at the bottom of this function.
        api_key = st.text_input(
            "OpenAI API Key",
            type="password",
            key="api_key",
            help="This is the API key for your OpenAI account. You can find it [here](https://platform.openai.com/api-keys)."
        )
        if api_key:
            st.session_state["OPENAI_API_KEY"] = api_key
        else:
            st.warning("Please provide your OpenAI API key to enable chat functionality.")

        model_options = ["gpt-4.1", "gpt-4o", "gpt-4o-mini"]
        # Fall back to the first option when the configured model is not offered.
        default_model = config.ai_model if config.ai_model in model_options else model_options[0]
        selected_model = st.selectbox(
            "Model",
            options=model_options,
            index=model_options.index(default_model),
            # Derive the help text from model_options so it cannot drift out of
            # sync when models are added or removed (it previously omitted gpt-4.1).
            help=f"This controls the model used for the chatbot. You can choose from the following models: {', '.join(model_options)}"
        )
        st.session_state["openai_model"] = selected_model

        temperature = st.slider(
            "Temperature",
            0.0, 1.0,
            value=config.temperature,
            step=0.05,
            help="This controls the randomness/creativity of the responses. A higher temperature results in more creative responses."
        )
        st.session_state["temperature"] = temperature

        max_tokens = st.number_input(
            "Max Tokens",
            min_value=1,
            max_value=2048,
            value=config.max_tokens,
            step=1,
            help="This controls the maximum number of tokens the AI can generate. 1000 tokens equals about 750 words."
        )
        st.session_state["max_tokens"] = max_tokens

        frequency_penalty = st.slider(
            "Frequency Penalty",
            0.0, 1.0,
            value=config.frequency_penalty,
            step=0.05,
            help="This controls the frequency penalty for the responses. Higher values produce more diverse responses."
        )
        st.session_state["frequency_penalty"] = frequency_penalty

        presence_penalty = st.slider(
            "Presence Penalty",
            0.0, 1.0,
            value=config.presence_penalty,
            step=0.05,
            help="This controls the presence penalty for the responses. Higher values reduce repetition."
        )
        st.session_state["presence_penalty"] = presence_penalty

        language_options = [
            "English", "Albanian", "Amharic", "Arabic", "Armenian", "Bengali", "Bosnian", "Bulgarian", "Burmese",
            "Catalan", "Chinese", "Croatian", "Czech", "Danish", "Dutch", "Estonian", "Finnish", "French",
            "Georgian", "German", "Greek", "Gujarati", "Hindi", "Hungarian", "Icelandic", "Indonesian",
            "Italian", "Japanese", "Kannada", "Kazakh", "Korean", "Latvian", "Lithuanian", "Macedonian",
            "Malay", "Malayalam", "Marathi", "Mongolian", "Norwegian", "Persian", "Polish", "Portuguese",
            "Punjabi", "Romanian", "Russian", "Serbian", "Slovak", "Slovenian", "Somali", "Spanish",
            "Swahili", "Swedish", "Tagalog", "Tamil", "Telugu", "Thai", "Turkish", "Ukrainian", "Urdu",
            "Vietnamese"
        ]
        default_language = "English"
        # Defensive: guarantees the default is always selectable even if the
        # list above is edited.
        if default_language not in language_options:
            language_options.insert(0, default_language)
        selected_language = st.selectbox(
            "Preferred Language",
            options=language_options,
            index=language_options.index(default_language),
            help="Select your preferred language for the chatbot's responses."
        )

        st.markdown("Click the 'Update Configuration' button each time you change an option.")

        st.markdown("---")

        # NOTE(review): passing value= together with key= emits a Streamlit
        # warning once the key exists in session state; harmless, but kept
        # for backward-compatible behavior.
        default_body = st.session_state.get("system_prompt_body", config.prompt)
        system_prompt_body = st.text_area(
            "System Instructions",
            value=default_body,
            key="system_prompt_body",
            help="These instructions determine the chatbot's behavior. Click 'Update Configuration' to apply changes.",
            height=400
        )
        complete_system_prompt = build_complete_system_prompt(selected_language, system_prompt_body)
        st.session_state["system_prompt"] = complete_system_prompt

        # The system message that seeds every fresh conversation.
        initial_context: Dict[str, str] = {
            "role": "system",
            "content": complete_system_prompt
        }

        # Applying the configuration clears the transcript and reruns the app.
        if st.button("Update Configuration"):
            reset_chat_history(initial_context)

        # Halt the script here when no key is available; nothing below can work.
        if not st.session_state.get("OPENAI_API_KEY"):
            st.error("No API key provided. Please enter your OpenAI API key above.")
            st.stop()

    client = OpenAI(api_key=st.session_state["OPENAI_API_KEY"])

    # NOTE(review): dead in practice — the selectbox above always sets
    # "openai_model" before this runs; kept as a defensive fallback.
    if "openai_model" not in st.session_state:
        st.session_state["openai_model"] = config.ai_model

    if "display_messages" not in st.session_state:
        st.session_state["display_messages"] = [initial_context]

    # Append the new user turn before rendering so it appears in this pass.
    prompt = st.chat_input("Type your message here...")
    if prompt:
        st.session_state["display_messages"].append({"role": "user", "content": prompt})

    with st.container():
        # Skip index 0: the system message is configuration, not conversation.
        for message in st.session_state["display_messages"][1:]:
            if message["role"] == "user":
                with st.chat_message("user"):
                    st.markdown(message["content"])
            else:
                with st.chat_message("assistant"):
                    st.markdown(message["content"])

    if prompt:
        with st.chat_message("assistant"):
            full_response = stream_assistant_response(
                client=client,
                model=selected_model,
                messages=[{"role": m["role"], "content": m["content"]}
                          for m in st.session_state["display_messages"]],
                temperature=temperature,
                max_tokens=max_tokens,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty
            )
            st.session_state["display_messages"].append(
                {"role": "assistant", "content": full_response}
            )

    with st.sidebar:
        st.markdown("Application created by [Keefe Reuther](https://reutherlab.netlify.app/), Assistant Teaching Professor in the UC San Diego School of Biological Sciences. "
                    "Code for this app is available [here](https://huggingface.co/spaces/keefereuther/ST_basebot) and is distributed under the [GNU GPL-3 License](https://www.gnu.org/licenses/gpl-3.0.en.html).")

    # Debug view excludes the transcript and any key material.
    with st.sidebar.expander("Debug: Session State"):
        session_state_copy = {k: v for k, v in st.session_state.items() if k not in ["display_messages", "api_key", "OPENAI_API_KEY"]}
        st.json(session_state_copy)
|
|
|
|
|
|
|
|
# Standard script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()
|
|
|