Update app.py
Browse files
app.py
CHANGED
|
@@ -14,12 +14,51 @@ def icon(emoji: str):
|
|
| 14 |
unsafe_allow_html=True,
|
| 15 |
)
|
| 16 |
|
|
|
|
|
|
|
| 17 |
st.subheader("Groq Chat with LLaMA3 App", divider="rainbow", anchor=False)
|
| 18 |
|
|
|
|
| 19 |
api_keys = os.environ['GROQ_API_KEYS'].split(',')
|
| 20 |
-
clients = [Groq(api_key=key) for key in api_keys]
|
| 21 |
|
| 22 |
-
with
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
model_option = st.selectbox(
|
| 24 |
"Choose a model:",
|
| 25 |
options=list(models.keys()),
|
|
@@ -39,13 +78,6 @@ with st.sidebar:
|
|
| 39 |
if system_prompt := st.text_input("System Prompt"):
|
| 40 |
system_message = {"role": "system", "content": system_prompt}
|
| 41 |
|
| 42 |
-
# Initialize chat history and selected model
|
| 43 |
-
if "messages" not in st.session_state:
|
| 44 |
-
st.session_state.messages = []
|
| 45 |
-
|
| 46 |
-
if "selected_model" not in st.session_state:
|
| 47 |
-
st.session_state.selected_model = None
|
| 48 |
-
|
| 49 |
# Detect model change and clear chat history if model has changed
|
| 50 |
if st.session_state.selected_model != model_option:
|
| 51 |
st.session_state.messages = []
|
|
@@ -70,7 +102,7 @@ def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
|
|
| 70 |
if prompt := st.chat_input("Enter your prompt here..."):
|
| 71 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 72 |
|
| 73 |
-
with st.chat_message("user", avatar="🧑💻"):
|
| 74 |
st.markdown(prompt)
|
| 75 |
|
| 76 |
messages=[
|
|
@@ -104,4 +136,4 @@ if prompt := st.chat_input("Enter your prompt here..."):
|
|
| 104 |
combined_response = "\n".join(str(item) for item in full_response)
|
| 105 |
st.session_state.messages.append(
|
| 106 |
{"role": "assistant", "content": combined_response}
|
| 107 |
-
)
|
|
|
|
| 14 |
unsafe_allow_html=True,
|
| 15 |
)
|
| 16 |
|
| 17 |
+
# icon("⚡️")
|
| 18 |
+
|
| 19 |
st.subheader("Groq Chat with LLaMA3 App", divider="rainbow", anchor=False)
|
| 20 |
|
| 21 |
+
# Get the API keys from the environment variable
|
| 22 |
api_keys = os.environ['GROQ_API_KEYS'].split(',')
|
|
|
|
| 23 |
|
| 24 |
+
# Initialize the Groq client with the first API key
|
| 25 |
+
client = None
|
| 26 |
+
for api_key in api_keys:
|
| 27 |
+
try:
|
| 28 |
+
client = Groq(api_key=api_key)
|
| 29 |
+
break
|
| 30 |
+
except Exception as e:
|
| 31 |
+
st.error(f"Failed to initialize client with API key {api_key}: {e}")
|
| 32 |
+
continue
|
| 33 |
+
|
| 34 |
+
if client is None:
|
| 35 |
+
st.error("Failed to initialize client with any API key")
|
| 36 |
+
st.stop()
|
| 37 |
+
|
| 38 |
+
# Initialize chat history and selected model
|
| 39 |
+
if "messages" not in st.session_state:
|
| 40 |
+
st.session_state.messages = []
|
| 41 |
+
|
| 42 |
+
if "selected_model" not in st.session_state:
|
| 43 |
+
st.session_state.selected_model = None
|
| 44 |
+
|
| 45 |
+
# Define model details
|
| 46 |
+
models = {
|
| 47 |
+
"llama3-70b-8192": {"name": "LLaMA3-70b", "tokens": 8192, "developer": "Meta"},
|
| 48 |
+
"llama3-8b-8192": {"name": "LLaMA3-8b", "tokens": 8192, "developer": "Meta"},
|
| 49 |
+
"llama2-70b-4096": {"name": "LLaMA2-70b-chat", "tokens": 4096, "developer": "Meta"},
|
| 50 |
+
"gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
|
| 51 |
+
"mixtral-8x7b-32768": {
|
| 52 |
+
"name": "Mixtral-8x7b-Instruct-v0.1",
|
| 53 |
+
"tokens": 32768,
|
| 54 |
+
"developer": "Mistral",
|
| 55 |
+
},
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
# Layout for model selection and max_tokens slider
|
| 59 |
+
col1, col2 = st.columns([1, 3]) # Adjust the ratio to make the first column smaller
|
| 60 |
+
|
| 61 |
+
with col1:
|
| 62 |
model_option = st.selectbox(
|
| 63 |
"Choose a model:",
|
| 64 |
options=list(models.keys()),
|
|
|
|
| 78 |
if system_prompt := st.text_input("System Prompt"):
|
| 79 |
system_message = {"role": "system", "content": system_prompt}
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# Detect model change and clear chat history if model has changed
|
| 82 |
if st.session_state.selected_model != model_option:
|
| 83 |
st.session_state.messages = []
|
|
|
|
| 102 |
if prompt := st.chat_input("Enter your prompt here..."):
|
| 103 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
| 104 |
|
| 105 |
+
with st.chat_message("user", avatar="🧑💻"):
|
| 106 |
st.markdown(prompt)
|
| 107 |
|
| 108 |
messages=[
|
|
|
|
| 136 |
combined_response = "\n".join(str(item) for item in full_response)
|
| 137 |
st.session_state.messages.append(
|
| 138 |
{"role": "assistant", "content": combined_response}
|
| 139 |
+
)
|