# AMAbot / app.py
import os
import gradio as gr
from huggingface_hub import InferenceClient
from cryptography.fernet import Fernet
# --- LangChain / RAG Imports ---
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
# --- Core Functions ---
def load_decrypted_preprompt(file_path="pre_prompt.enc"):
    """
    Load and decrypt the pre-prompt from the encrypted file using the key
    stored in the environment variable 'KEY'.
    """
    try:
        key_str = os.getenv("KEY", "")
        if not key_str:
            print("Warning: KEY environment variable not set, using default preprompt")
            return "You are AMAbot, a helpful assistant that answers questions about Christopher."
        key = key_str.encode()
        fernet = Fernet(key)
        with open(file_path, "rb") as file:
            encrypted_text = file.read()
        decrypted_text = fernet.decrypt(encrypted_text)
        return decrypted_text.decode("utf-8")
    except Exception as e:
        print(f"Error loading preprompt: {e}, using default")
        return "You are AMAbot, a helpful assistant that answers questions about Christopher."
PRE_PROMPT = load_decrypted_preprompt()
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 512
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.95
# Zephyr-7B chat model, served via the Hugging Face Inference API
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
def load_vector_db(index_path="faiss_index", model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Load the FAISS vector database from disk."""
    try:
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        vector_db = FAISS.load_local(
            index_path,
            embeddings,
            allow_dangerous_deserialization=True
        )
        print(f"Successfully loaded vector database from {index_path}")
        return vector_db
    except Exception as e:
        print(f"Failed to load vector database: {e}")
        return None
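
# A minimal sketch (run offline -- not part of this app) of how the
# "faiss_index" directory could be built with the same embedding model;
# the source texts here are placeholders:
#
#     from langchain_community.vectorstores import FAISS
#     from langchain_huggingface import HuggingFaceEmbeddings
#     texts = ["Christopher is ...", "He studied ..."]  # your source documents
#     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#     FAISS.from_texts(texts, embeddings).save_local("faiss_index")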
def create_qa_prompt():
    """
    Create a prompt template for QA, formatted for Zephyr/Mistral models.
    This is the specific prompt format Zephyr was trained on.
    """
    template = """<|system|>
You are a helpful assistant that answers questions using the context provided.
If you don't know the answer based on the context, just say that you don't know. Don't try to make up an answer.</s>
<|user|>
Context:
{context}
Question: {question}</s>
<|assistant|>
Helpful Answer:"""
    return PromptTemplate(template=template, input_variables=["context", "question"])
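
# For example, create_qa_prompt().format(context="...", question="Who is Christopher?")
# renders the complete "<|system|> ... <|user|> ... <|assistant|>" string in the
# chat format Zephyr was fine-tuned on.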
def update_chat(message, history):
    """Append the user's message to the chat history and clear the input box."""
    if history is None:
        history = []
    history = history.copy()
    history.append({"role": "user", "content": message})
    return history, message, ""
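
# For example, update_chat("Who?", []) returns
# ([{"role": "user", "content": "Who?"}], "Who?", "") -- the message-dict
# format gr.Chatbot expects when type="messages".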
def get_assistant_response(message, history, max_tokens, temperature, top_p, qa_chain_state_dict):
    """
    Generate the assistant response by manually running the RAG pipeline
    (retrieve, format the prompt, call the chat_completion endpoint).
    """
    vector_db = qa_chain_state_dict.get("vector_db")
    answer = "I apologize, but I'm having trouble accessing my knowledge base right now."
    if not vector_db:
        print("Error: Vector DB is not available.")
        history.append({"role": "assistant", "content": answer})
        return history, qa_chain_state_dict
    try:
        # 1. Retrieve relevant documents from the vector store
        retriever = vector_db.as_retriever(search_kwargs={"k": 3})
        retrieved_docs = retriever.invoke(message)
        # 2. Format the context for the prompt
        context = "\n\n".join([doc.page_content for doc in retrieved_docs])
        # 3. Create the prompt using the Zephyr chat template
        qa_prompt_template = create_qa_prompt()
        formatted_prompt = qa_prompt_template.format(context=context, question=message)
        # 4. Prepare the message payload for the conversational API
        messages = [
            {
                "role": "user",
                "content": formatted_prompt,
            }
        ]
        # 5. Call the chat_completion endpoint
        print("Attempting to call chat_completion API...")
        client = InferenceClient(MODEL_NAME, token=os.getenv("HF_TOKEN", ""))
        response = client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature if temperature > 0 else 0.1,  # temperature must be > 0 for chat
            top_p=top_p,
            stream=False
        )
        # 6. Extract the answer from the first choice
        if response.choices and response.choices[0].message:
            answer = response.choices[0].message.content.strip()
            print(f"API call successful, answer length: {len(answer)}")
        else:
            print("API returned an empty response.")
    except Exception as e:
        print(f"An error occurred in get_assistant_response: {type(e).__name__} - {repr(e)}")
        answer = f"I'm experiencing technical difficulties. Please try again. (Error: {str(e)[:100]})"
    history.append({"role": "assistant", "content": answer})
    return history, qa_chain_state_dict
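
# A hedged variant (not used here): huggingface_hub also supports streaming the
# same call with stream=True, yielding incremental delta chunks, e.g.
#
#     for chunk in client.chat_completion(messages=messages, max_tokens=512, stream=True):
#         delta = chunk.choices[0].delta.content or ""
#         # append delta to a partial answer and update the UI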
# --- Initialize Components ---
HF_TOKEN = os.getenv("HF_TOKEN", "")
if not HF_TOKEN:
    print("Warning: HF_TOKEN not set in environment variables!")
# Load vector database
vector_db = load_vector_db("faiss_index")
# Prepare the initial state dictionary with the vector_db
qa_chain_state_initial = {"vector_db": vector_db}
# Test the vector DB setup
if vector_db:
    print("Testing vector database...")
    try:
        test_retriever = vector_db.as_retriever(search_kwargs={"k": 1})
        test_docs = test_retriever.invoke("test query")
        print("Vector DB test successful, can retrieve documents")
    except Exception as e:
        print(f"Vector DB test failed: {e}")
# ------------------------------------------------------------------
# Gradio Interface Layout
# ------------------------------------------------------------------
with gr.Blocks(fill_width=True, theme=gr.themes.Default(primary_hue="sky")) as demo:
    # This HTML block contains all the CSS and JS for the desired layout
    gr.HTML("""
    <script>
        window.addEventListener("load", () => {
            document.documentElement.setAttribute("data-theme", "light");
        });
    </script>
    <style>
        :root {
            --primary-200: transparent !important;
            color-scheme: light !important;
            background-color: #fff !important;
            color: #333 !important;
        }
        #chatbot .message.user {
            background-color: #ccc !important;
            color: #222 !important;
        }
        .gradio-container footer {
            display: none !important;
        }
        .gradio-container {
            width: 100% !important;
            max-width: none !important;
            margin: 0;
        }
        .gradio-container .fillable {
            width: 100% !important;
            max-width: unset !important;
            margin: 0;
        }
        .hf-chat-input textarea:focus {
            outline: none !important;
            box-shadow: none !important;
            border-color: #c2c2c2 !important;
        }
        .hf-chat-input:focus {
            outline: none !important;
            box-shadow: none !important;
            border-color: #c2c2c2 !important;
        }
        .block-container {
            width: 100% !important;
            max-width: none !important;
        }
        body .gradio-container .chatbot .hf-chat-input button .textbox textarea {
            background-color: #fff !important;
            color: #333 !important;
        }
        .example-row {
            flex-grow: 1 !important;
            width: 100% !important;
            display: flex;
            flex-direction: row;
            flex-wrap: wrap;
            justify-content: center;
            gap: 10px;
        }
        .input-container {
            position: relative;
            width: 100%;
        }
        .hf-chat-input {
            background-color: #f9f9f9;
            border: 1px solid #e0e0e0;
            border-radius: 20px;
            padding: 10px 50px 10px 20px;
            font-size: 16px;
            width: 100%;
            box-sizing: border-box;
            transition: border-color 0.2s ease;
        }
        .hf-chat-input:focus {
            outline: none;
            border-color: #c2c2c2;
        }
        .send-button {
            position: absolute;
            right: 10px;
            top: 50%;
            transform: translateY(-50%);
            width: 15px !important;
            height: 30px !important;
            padding: 0;
            background: #fff;
            border: none;
            border-radius: 50%;
            cursor: pointer;
            transition: background-color 0.2s ease;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 16px;
            line-height: 1;
        }
        .send-button:hover,
        .send-button:focus,
        .send-button:active {
            background-color: #f0f0f0;
            outline: none;
            top: 50% !important;
            transform: translateY(-50%) !important;
        }
        .input-row {
            display: flex;
            align-items: center;
            width: 100%;
            gap: 10px;
        }
    </style>
    """)
    # State: the loaded vector DB and the pending user message
    qa_chain_state = gr.State(value=qa_chain_state_initial)
    user_message_state = gr.State()
    chatbot = gr.Chatbot(label="AMAbot", show_label=True, elem_id="chatbot", height=250, type="messages", visible=False)
    with gr.Row(elem_classes="example-row", visible=True) as examples_container:
        ex1 = gr.Button("Who?")
        ex2 = gr.Button("Where?")
        ex3 = gr.Button("What?")
    # Reveal the (initially hidden) chatbot when an example button is clicked
    ex1.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    ex2.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    ex3.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    with gr.Row(elem_classes="input-row"):
        with gr.Column(elem_classes="input-container"):
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Ask AMAbot anything about Christopher",
                container=False,
                elem_classes="hf-chat-input"
            )
            send_btn = gr.Button("❯❯", elem_classes="send-button")
    # Hidden inputs for model parameters
    max_tokens_input = gr.Number(value=DEFAULT_MAX_TOKENS, visible=False)
    temperature_input = gr.Number(value=DEFAULT_TEMPERATURE, visible=False)
    top_p_input = gr.Number(value=DEFAULT_TOP_P, visible=False)
    # --- Event Handlers ---
    # Reveal the chatbot as soon as the user sends a message
    user_input.submit(lambda: gr.update(visible=True), None, chatbot, queue=False)
    send_btn.click(lambda: gr.update(visible=True), None, chatbot, queue=False)
    # Submit action for text input
    user_input.submit(
        update_chat,
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_message_state, user_input]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )
    # Click action for send button
    send_btn.click(
        update_chat,
        inputs=[user_input, chatbot],
        outputs=[chatbot, user_message_state, user_input]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )
    # Click actions for example buttons (update_chat returns 3 values; the [:2]
    # slice drops the cleared-textbox value, which these handlers don't need)
    ex1.click(
        lambda history: update_chat("Who is Christopher?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )
    ex2.click(
        lambda history: update_chat("Where is Christopher from?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )
    ex3.click(
        lambda history: update_chat("What degrees does Christopher have, and what technical experience does he have?", history)[:2],
        inputs=[chatbot],
        outputs=[chatbot, user_message_state]
    ).then(
        get_assistant_response,
        inputs=[user_message_state, chatbot, max_tokens_input, temperature_input, top_p_input, qa_chain_state],
        outputs=[chatbot, qa_chain_state]
    )
if __name__ == "__main__":
    demo.queue().launch(show_api=False, share=True)