File size: 8,346 Bytes
f07851c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd8ddb1
092ccce
cd8ddb1
c0f1401
cd8ddb1
f07851c
 
 
 
 
 
 
 
 
 
97b4f48
 
cd8ddb1
 
97b4f48
 
f07851c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd8ddb1
 
f07851c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217

from time import sleep
import streamlit as st
from langchain_chroma import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# CONSTANTS

EMBEDDING_MODEL = 'models/text-embedding-004'  # Google Generative AI embedding model used to embed queries
DB_DIR = './vector_db_alta_v2'                 # Chroma persist_directory holding the handbook index
SEARCH_TYPE = 'similarity'                     # retriever search strategy passed to as_retriever
N_DOCS_RETRIEVED = 2                           # k: number of chunks fetched per query
REPHRASE_MODEL = 'llama-3.3-70b-versatile'     # Groq model used by the query-rephrasing chain

USER_AVATAR = './images/Comp_RGB_inv.jpg'      # avatar shown next to user messages
ASSISTANT_AVATAR = './images/Comp_RGB.jpg'     # avatar shown next to assistant messages
NATO_FAVICON = './images/NATO_favicon.png'     # browser tab icon and sidebar logo
SPEED = 10                                     # typewriter effect: words revealed per second

# RAG RELATED FUNCTIONS

@st.cache_resource
def build_vector_database():
    """Open the persisted Chroma collection and wrap it as a retriever.

    Decorated with ``st.cache_resource`` so the embedding client and the
    vector store are constructed only once per server process and shared
    across reruns and sessions.

    Returns:
        A retriever over the 'alta_handbook' collection that fetches the
        top ``N_DOCS_RETRIEVED`` chunks per query using ``SEARCH_TYPE``.
    """
    embedding_fn = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL)
    store = Chroma(
        collection_name='alta_handbook',
        persist_directory=DB_DIR,
        embedding_function=embedding_fn,
    )
    return store.as_retriever(search_type=SEARCH_TYPE, search_kwargs={"k": N_DOCS_RETRIEVED})

def format_docs(docs):
    """Flatten retrieved documents into one context string.

    Args:
        docs: Iterable of documents exposing a ``page_content`` attribute.

    Returns:
        The page contents joined by blank lines, ready for prompt injection.
    """
    chunks = [doc.page_content for doc in docs]
    return "\n\n".join(chunks)

def create_rag_chain(selected_model, retriever):
    """Assemble the retrieval-augmented QA chain for the chosen Groq model.

    The retriever output is flattened into a single context string, injected
    together with the question into the answering prompt, sent to the chat
    model, and parsed down to plain text. Not cached: ``selected_model`` can
    change between reruns via the sidebar selector.

    Args:
        selected_model: Groq model identifier to answer with.
        retriever: Vector-store retriever supplying context documents.

    Returns:
        A runnable chain mapping a question string to an answer string.
    """
    answer_template = """You are a NATO assistant for question-answering tasks.
    If you don't know the answer, respond with: "Based on my knowledge, I can't provide an answer." Keep your responses concise.
    Use the following context to answer the question:

    Context:
    {context}

    Question:
    {question}
    """
    answer_prompt = ChatPromptTemplate.from_template(answer_template)
    answer_llm = ChatGroq(model=selected_model)
    # The dict fans the incoming question out: one branch retrieves and
    # flattens context, the other passes the question through untouched.
    inputs = {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    return inputs | answer_prompt | answer_llm | StrOutputParser()

@st.cache_resource
def create_rephrase_chain():
    """Build the cached history-aware query-rewriting chain.

    The chain takes ``{"chat_history": ..., "input": ...}`` and returns the
    user query rewritten to be self-contained when it implicitly refers to
    the conversation, or unchanged otherwise. Cached because it always uses
    the fixed ``REPHRASE_MODEL``.

    Returns:
        A runnable chain producing the (possibly rephrased) query string.
    """
    rephrase_template = """You are a query rephraser. Given a chat history and the latest user query, your task is to rephrase the query if it implicitly references topics in the chat history.
    If the query does not reference the chat history, return it as is. Do not provide explanations, just return the rephrased or original query.

    Chat history:
    {chat_history}

    Latest user query:
    {input}
    """
    prompt = ChatPromptTemplate.from_template(rephrase_template)
    llm = ChatGroq(model=REPHRASE_MODEL)
    return prompt | llm | StrOutputParser()

# STREAMLIT RELATED FUNCTIONS

def typewriter_effect(text, speed, allow_html):
    """Stream *text* into a single placeholder word by word.

    Based on https://discuss.streamlit.io/t/st-write-typewritter/43111/3.

    Fix: the original split with ``text.split()``, which collapses ALL
    whitespace — including newlines — so multi-line markdown (tables, lists)
    was flattened onto one line and rendered broken. Splitting on single
    spaces keeps newlines attached to their neighboring tokens, preserving
    the markdown structure while still revealing roughly one word at a time.

    Args:
        text: Markdown string to display.
        speed: Words revealed per second (must be > 0).
        allow_html: Forwarded to ``st.markdown(unsafe_allow_html=...)``.
    """
    tokens = text.split(" ")  # NOT .split(): keep "\n" inside tokens so markdown line breaks survive
    container = st.empty()
    delay = 1 / speed  # hoisted: constant for the whole animation
    # Start at 1: rendering the empty prefix first was a wasted frame.
    for index in range(1, len(tokens) + 1):
        container.markdown(" ".join(tokens[:index]), unsafe_allow_html=allow_html)
        sleep(delay)

def restart_chat():
    """Forget the stored conversation so the next rerun starts a fresh chat.

    Safe to call when no conversation exists yet (no-op in that case).
    Used as the callback for both the model selector and the clear button.
    """
    if "messages" in st.session_state:
        del st.session_state["messages"]

# -----------------------------

# STREAMLIT

# Shown under the "About" menu entry configured in st.set_page_config below.
disclaimer = """This app and its author are not affiliated with NATO and do not represent the organization in any official capacity.
The content provided is based on The NATO Alternative Analysis Handbook
but is intended for general informational purposes only and does not reflect NATO's views or policies.
"""
st.set_page_config(
    page_title="NATO Chatbot",
    page_icon=NATO_FAVICON, # 'https://cdn3.emoji.gg/emojis/5667-nato.png'
    menu_items={'About': disclaimer, 'Report a Bug': 'mailto:diego.her.jimenez@gmail.com'}
)

# Custom CSS for font change
# https://fonts.google.com
st.markdown(
    """
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Solway:wght@300;400;500;700;800&display=swap');

    body * {
        font-family: 'Solway', sans-serif !important;
        font-weight: 400;
        font-style: normal;
    }
    </style>
    """,
    unsafe_allow_html=True
)

# Sidebar with description and model selector
# Keys are the Groq model ids sent to the API; values are the labels shown in the UI.
suggested_models = {
    'meta-llama/llama-4-scout-17b-16e-instruct': 'llama 4 scout',
    'llama-3.3-70b-versatile': 'llama 3.3',
    'gemma2-9b-it': 'gemma 2'
}
with st.sidebar:
    # Three equal columns are used only to center the logo in the middle one.
    left_col, cent_col, right_col = st.columns(3)
    with cent_col:
        st.image(NATO_FAVICON, use_container_width=True)
    st.title("Nate: Your NATO QA Assistant for Alternative Analysis")
    st.markdown("""
    This AI assistant is designed to help answer questions related to Alternative Analysis, a framework for intelligent decision-making.
    Developed by NATO, this framework offers a set of techniques that can be applied across various domains.
    However, this assistant bases its responses exclusively on [The NATO Alternative Analysis Handbook](https://www.act.nato.int/wp-content/uploads/2023/05/alta-handbook.pdf)
    """)
    # Changing the model wipes the history (restart_chat) so previous answers
    # aren't attributed to the newly selected model.
    selected_model = st.selectbox(
        'Chat model', 
        suggested_models.keys(), 
        format_func=lambda option: suggested_models[option],
        on_change=restart_chat
    )
    st.markdown('#')
    st.markdown('#')
    st.button("Clear conversation", on_click=restart_chat, type='tertiary', icon='🗑️')


st.markdown('#')
st.markdown('#')

# ---------------------------------------------
# build vector database
retriever = build_vector_database()

# instantiate llms
rag_chain = create_rag_chain(selected_model, retriever)
rephrase_chain = create_rephrase_chain()
# ---------------------------------------------

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Suggested questions (only at the beginning)
# We display the questions with the emojis, but only the question itself is saved as input
# (see format_func below)
suggested_questions = {
    'What is Alternative Analysis?': 'What is Alternative Analysis? 🤔',
    'How can I apply SWOT step by step?': 'How can I apply SWOT step by step? 📝',
    'How can I manage dysfunctional behavior in a AltA session?': 'How can I manage dysfunctional behavior in a AltA session? ⚠️',
    'Point me to some resources to learn about Six Thinking Hats': 'Point me to some resources to learn about Six Thinking Hats 🎩'
}

if len(st.session_state.messages) == 0:
    selected_question = st.pills(
        "Suggested questions:", 
        suggested_questions.keys(),
        selection_mode='single',
        format_func=lambda option: suggested_questions[option]
    )
    if selected_question:
        # Store the pick as a regular user message and rerun so it flows
        # through the same answer-generation path as typed input below.
        st.session_state.messages.append({"role": "user", "content": selected_question, "avatar": USER_AVATAR})
        st.rerun()

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"], avatar=message.get("avatar")):
        st.markdown(message["content"])

# Ensure the chat input box is always visible
# The rerun makes the new message appear via the history loop above before the
# (slow) answer generation starts.
user_input = st.chat_input("Ask anything about AltA", max_chars=500)
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input, "avatar": USER_AVATAR})
    st.rerun()

# Generate assistant response if user input exists
# (true on the rerun right after a user message was appended)
if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
    user_input = st.session_state.messages[-1]["content"]
    # Generate assistant response
    # Very first message: no history to rephrase against, use it verbatim.
    if len(st.session_state.messages) == 1:
        final_input = user_input
    # when there are previous interactions (history)...
    else:
        # ... keep only the last two interactions in "memory"
        trimmed_conversation = "\n".join([f"{msg['role']}: {msg['content']}" for msg in st.session_state.messages[-2:]])
        # based on the context given by the memorized conversation, rephrase the current query (it might implicitly refer to previous responses or questions)
        final_input = rephrase_chain.invoke({"chat_history": trimmed_conversation, "input": user_input})
    response = rag_chain.invoke(final_input)
    # Display assistant response in chat message container
    with st.chat_message('assistant', avatar=ASSISTANT_AVATAR):
        typewriter_effect(response, speed=SPEED, allow_html=True) # unsafe_allow is True because the model sometimes outputs tables, but maybe I should remove it
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response, "avatar": ASSISTANT_AVATAR})