import streamlit as st
from io import BytesIO
import ibm_watsonx_ai
import secretsload
import genparam
import requests
import time
import re
import json

from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai import Credentials, APIClient
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams

from ibm_watsonx_ai.foundation_models import Embeddings
from ibm_watsonx_ai.foundation_models.utils.enums import EmbeddingTypes
from pymilvus import MilvusClient

from secretsload import load_stsecrets
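# Configuration lives outside this file. st.secrets (loaded from the Streamlit secrets store,
# typically .streamlit/secrets.toml via secretsload) is expected to provide: app_password, url,
# api_key, project_id, vector_index_id_1 and vector_index_id_2. genparam is expected to define
# the knobs referenced below (an illustrative sketch only; names taken from their usage in this file):
#   ACTIVE_MODEL = 0        # 0 -> SELECTED_MODEL_1 / PROMPT_TEMPLATE_1, 1 -> the *_2 variants
#   ACTIVE_INDEX = 0        # 0 -> vector_index_id_1, 1 -> vector_index_id_2
#   SELECTED_MODEL_1 = "..."  # watsonx.ai model id
#   PROMPT_TEMPLATE_1 = "llama3-instruct (llama-3, 3.1 & 3.2) - user"  # a key of model_family_syntax
#   TYPE = "chat"           # anything else drops conversation history in prepare_prompt
#   BAKE_IN_PROMPT_SYNTAX, VERIFY, TOKEN_CAPTURE_ENABLED
#   DECODING_METHOD, MAX_NEW_TOKENS, MIN_NEW_TOKENS, REPETITION_PENALTY, STOP_SEQUENCES
#   USER_AVATAR and BOT_{1,2,3}_NAME / _AVATAR / _PROMPT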
|
credentials = load_stsecrets()

st.set_page_config(
    page_title="The Solutioning Sages",
    page_icon="🪄",
    initial_sidebar_state="collapsed",
    layout="wide"
)
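# Simple password gate: compares the typed value against st.secrets["app_password"]
# and remembers the outcome in session state so the check survives reruns.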
|
def check_password():
    def password_entered():
        if st.session_state["password"] == st.secrets["app_password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    elif not st.session_state["password_correct"]:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.error("😕 Incorrect password")
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    else:
        return True
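# One chat history per bot column, plus a first-question flag, a counter and a
# token-usage log, all kept in st.session_state across reruns.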
|
def initialize_session_state():
    if 'chat_history_1' not in st.session_state:
        st.session_state.chat_history_1 = []
    if 'chat_history_2' not in st.session_state:
        st.session_state.chat_history_2 = []
    if 'chat_history_3' not in st.session_state:
        st.session_state.chat_history_3 = []
    if 'first_question' not in st.session_state:
        st.session_state.first_question = False
    if "counter" not in st.session_state:
        st.session_state["counter"] = 0
    if 'token_capture' not in st.session_state:
        st.session_state.token_capture = []
|
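# CSS injected via st.markdown to style the three side-by-side chat columns.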
|
three_column_style = """
<style>
    .stColumn {
        padding: 0.5rem;
        border-right: 1px solid #dedede;
    }
    .stColumn:last-child {
        border-right: none;
    }
    .chat-container {
        height: calc(100vh - 200px);
        overflow-y: auto;
        display: flex;
        flex-direction: column;
    }
    .chat-messages {
        display: flex;
        flex-direction: column;
        gap: 1rem;
    }
</style>
"""
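# genparam.ACTIVE_MODEL and genparam.ACTIVE_INDEX act as 0/1 switches between the two
# configured models, prompt templates and vector indexes.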
|
def get_active_model():
    return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2


def get_active_prompt_template():
    return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2


def get_active_vector_index():
    return st.secrets["vector_index_id_1"] if genparam.ACTIVE_INDEX == 0 else st.secrets["vector_index_id_2"]
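# Builds watsonx.ai credentials from st.secrets and an APIClient scoped to the project.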
|
def setup_client(project_id):
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"]
    )
    client = APIClient(credentials, project_id=project_id)
    return credentials, client
|
wml_credentials, client = setup_client(st.secrets["project_id"])
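# Reads the vector index asset from watsonx.ai, then builds the embedding model and a
# Milvus client from the connection details stored with that asset.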
|
def setup_vector_index(client, wml_credentials, vector_index_id):
    vector_index_details = client.data_assets.get_details(vector_index_id)
    vector_index_properties = vector_index_details["entity"]["vector_index"]

    emb = Embeddings(
        model_id=vector_index_properties["settings"]["embedding_model_id"],
        credentials=wml_credentials,
        project_id=st.secrets["project_id"],
        params={
            "truncate_input_tokens": 512
        }
    )

    vector_store_schema = vector_index_properties["settings"]["schema_fields"]
    connection_details = client.connections.get_details(vector_index_details["entity"]["vector_index"]["store"]["connection_id"])
    connection_properties = connection_details["entity"]["properties"]

    milvus_client = MilvusClient(
        uri=f'https://{connection_properties.get("host")}:{connection_properties.get("port")}',
        user=connection_properties.get("username"),
        password=connection_properties.get("password"),
        db_name=vector_index_properties["store"]["database"]
    )

    return milvus_client, emb, vector_index_properties, vector_store_schema
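# Embeds the question, runs a top-k similarity search against the Milvus collection and
# returns the hits as one formatted grounding string (document name, content, page number).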
|
def proximity_search(question, milvus_client, emb, vector_index_properties, vector_store_schema):
    query_vectors = emb.embed_query(question)
    milvus_response = milvus_client.search(
        collection_name=vector_index_properties["store"]["index"],
        data=[query_vectors],
        limit=vector_index_properties["settings"]["top_k"],
        metric_type="L2",
        output_fields=[
            vector_store_schema.get("text"),
            vector_store_schema.get("document_name"),
            vector_store_schema.get("page_number")
        ]
    )

    documents = []

    for hit in milvus_response[0]:
        text = hit["entity"].get(vector_store_schema.get("text"), "")
        doc_name = hit["entity"].get(vector_store_schema.get("document_name"), "Unknown Document")
        page_num = hit["entity"].get(vector_store_schema.get("page_number"), "N/A")

        formatted_result = f"Document: {doc_name}\nContent: {text}\nPage: {page_num}\n"
        documents.append(formatted_result)

    joined = "\n".join(documents)
    retrieved = f"""Number of Retrieved Documents: {len(documents)}\n\n{joined}"""

    return retrieved
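# Wraps the user input with a __grounding__ placeholder and, in chat mode, the running
# conversation history; fetch_response substitutes the retrieved passages in later.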
|
def prepare_prompt(prompt, chat_history):
    if genparam.TYPE == "chat" and chat_history:
        chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
        prompt = f"""Retrieved Contextual Information:\n__grounding__\n\nConversation History:\n{chats}\n\nNew User Input: {prompt}"""
        return prompt
    else:
        prompt = f"""Retrieved Contextual Information:\n__grounding__\n\nUser Input: {prompt}"""
        return prompt
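# Maps a template label to the raw prompt syntax of that model family and, when
# bake_in_prompt_syntax is set, formats the system and user prompts into it.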
|
def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    model_family_syntax = {
        "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
        "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
        "no syntax - system": """{system_prompt}\n\n{prompt}""",
        "no syntax - user": """{prompt}"""
    }

    if bake_in_prompt_syntax:
        template = model_family_syntax[prompt_template]
        if system_prompt:
            return template.format(system_prompt=system_prompt, prompt=prompt)
        # "- user" templates contain no {system_prompt} placeholder, so fill in the prompt alone
        return template.format(prompt=prompt)
    return prompt
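# Thin generator around the watsonx.ai streaming call so st.write_stream can consume it.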
|
def generate_response(watsonx_llm, prompt_data, params):
    generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
    for chunk in generated_response:
        yield chunk
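# Full pipeline for one bot: retrieve grounding from Milvus, assemble the prompt, stream
# the model answer inside a chat bubble and optionally log token usage to the sidebar.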
|
def fetch_response(user_input, milvus_client, emb, vector_index_properties, vector_store_schema, system_prompt, chat_history):
    grounding = proximity_search(
        question=user_input,
        milvus_client=milvus_client,
        emb=emb,
        vector_index_properties=vector_index_properties,
        vector_store_schema=vector_store_schema
    )
    prompt = prepare_prompt(user_input, chat_history)

    prompt_data = apply_prompt_syntax(
        prompt,
        system_prompt,
        get_active_prompt_template(),
        genparam.BAKE_IN_PROMPT_SYNTAX
    )

    prompt_data = prompt_data.replace("__grounding__", grounding)

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )

    params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
    }
|
    # Identity check: pick the bot persona that owns this particular history list
    bot_name = None
    bot_avatar = None
    if chat_history is st.session_state.chat_history_1:
        bot_name = genparam.BOT_1_NAME
        bot_avatar = genparam.BOT_1_AVATAR
    elif chat_history is st.session_state.chat_history_2:
        bot_name = genparam.BOT_2_NAME
        bot_avatar = genparam.BOT_2_AVATAR
    else:
        bot_name = genparam.BOT_3_NAME
        bot_avatar = genparam.BOT_3_AVATAR

    with st.chat_message(bot_name, avatar=bot_avatar):
        if genparam.TOKEN_CAPTURE_ENABLED:
            st.code(prompt_data, line_numbers=True, wrap_lines=True)
        stream = generate_response(watsonx_llm, prompt_data, params)
        response = st.write_stream(stream)

    if genparam.TOKEN_CAPTURE_ENABLED:
        chat_number = len(chat_history) // 2
        token_calculations = capture_tokens(prompt_data, response, chat_number)
        if token_calculations:
            st.sidebar.code(token_calculations)

    return response
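# Tokenizes the prompt and the response to record input/output/total token counts per exchange.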
|
def capture_tokens(prompt_data, response, chat_number):
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),  # tokenize with the currently active model
        verify=genparam.VERIFY
    )

    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]
    total_tokens = input_tokens + output_tokens

    st.session_state.token_capture.append(f"chat {chat_number}: {input_tokens} + {output_tokens} = {total_tokens}")

    token_calculations = "\n".join(st.session_state.token_capture)
    return token_calculations
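# Page flow: sidebar with the question history, password gate, then one shared chat input
# fanned out to three bot columns. Each column keeps its own history and system prompt;
# bots 1 and 2 ground on the active vector index, bot 3 is pinned to vector_index_id_2.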
|
def main():
    initialize_session_state()

    st.markdown(three_column_style, unsafe_allow_html=True)

    st.sidebar.header('The Solutioning Sages')
    st.sidebar.write('')

    st.sidebar.subheader("Your Questions")
    for i, message in enumerate(st.session_state.chat_history_1):
        if message["role"] == "user":
            st.sidebar.markdown(f"**Question {i//2 + 1}:** {message['content']}")

    st.sidebar.write('')
    st.sidebar.write('')

    if not check_password():
        st.stop()
|
    user_input = st.chat_input("Ask your question here", key="user_input")

    if user_input:
        col1, col2, col3 = st.columns(3)

        with col1:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_1_AVATAR} {genparam.BOT_1_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            for message in st.session_state.chat_history_1:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_1_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_1.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                get_active_vector_index()
            )
            system_prompt = genparam.BOT_1_PROMPT

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_1
            )
            st.session_state.chat_history_1.append({"role": genparam.BOT_1_NAME, "content": response, "avatar": genparam.BOT_1_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)
        with col2:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_2_AVATAR} {genparam.BOT_2_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            for message in st.session_state.chat_history_2:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_2_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_2.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                get_active_vector_index()
            )
            system_prompt = genparam.BOT_2_PROMPT

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_2
            )
            st.session_state.chat_history_2.append({"role": genparam.BOT_2_NAME, "content": response, "avatar": genparam.BOT_2_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)
        with col3:
            st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
            st.subheader(f"{genparam.BOT_3_AVATAR} {genparam.BOT_3_NAME}")
            st.markdown("<div class='chat-messages'>", unsafe_allow_html=True)

            for message in st.session_state.chat_history_3:
                if message["role"] != "user":
                    with st.chat_message(message["role"], avatar=genparam.BOT_3_AVATAR):
                        st.markdown(message['content'])

            st.session_state.chat_history_3.append({"role": "user", "content": user_input, "avatar": genparam.USER_AVATAR})
            milvus_client, emb, vector_index_properties, vector_store_schema = setup_vector_index(
                client,
                wml_credentials,
                st.secrets["vector_index_id_2"]
            )
            system_prompt = genparam.BOT_3_PROMPT

            response = fetch_response(
                user_input,
                milvus_client,
                emb,
                vector_index_properties,
                vector_store_schema,
                system_prompt,
                st.session_state.chat_history_3
            )
            st.session_state.chat_history_3.append({"role": genparam.BOT_3_NAME, "content": response, "avatar": genparam.BOT_3_AVATAR})
            st.markdown("</div></div>", unsafe_allow_html=True)
|
        st.sidebar.markdown("---")
        st.sidebar.markdown("**Latest Question:**")
        st.sidebar.markdown(f"_{user_input}_")

|
if __name__ == "__main__":
    main()