# AI_Assistant / app.py
# (Hugging Face Space app; commit 87d0d98)
import gradio as gr
from langchain.chains import create_retrieval_chain
from langchain.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.memory.chat_message_histories import ChatMessageHistory
from langchain_openai import ChatOpenAI
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.embeddings import HuggingFaceEmbeddings
# --- One-time setup: embeddings, vector store, LLM, prompt, memory, chains ---

# Sentence-transformer model used to embed queries for similarity search.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Reopen the persisted Chroma vector store and expose a top-5 retriever.
persist_directory = 'vec_db'
vectordb = Chroma(persist_directory=persist_directory,
                  embedding_function=embedding_model)
vectordb_retriever = vectordb.as_retriever(search_kwargs={'k': 5})

llm = ChatOpenAI(model="gpt-4.1-nano", temperature=0.7)

# System instructions live in a separate file; read with an explicit
# encoding so behavior does not depend on the host's locale default.
with open("instructions.txt", 'r', encoding='utf-8') as file:
    instructions = file.read()

# Prompt layout: system instructions, running chat history, then the
# user question together with the retrieved context documents.
custom_prompt = ChatPromptTemplate.from_messages([
    ("system", instructions),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "Question: {input}\nContext: {context}")
])

# Conversation memory shared across requests; return_messages=True yields
# message objects that MessagesPlaceholder can splice into the prompt.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)

# "Stuff" chain packs all retrieved docs into one prompt; the retrieval
# chain wires the Chroma retriever in front of it.
question_answer_chain = create_stuff_documents_chain(llm, custom_prompt)
chain = create_retrieval_chain(vectordb_retriever, question_answer_chain)
def conversate_assistant(query, history):
    """Answer a user query via RAG; skip retrieval for plain greetings.

    Args:
        query: The user's latest message.
        history: Gradio-supplied chat history (unused here; conversation
            state is tracked in the module-level ConversationBufferMemory).

    Returns:
        The assistant's answer as a string.
    """
    greetings = {"hey", "hi", "hello"}
    normalized_query = query.strip().lower()
    # Load the stored history once (the original re-queried memory up to
    # three times) and keep only the last 6 messages (3 user/assistant
    # turns) to bound the prompt size. The [-6:] slice already handles
    # shorter histories, so no length check is needed.
    chat_history = memory.load_memory_variables({})["chat_history"][-6:]
    if normalized_query in greetings:
        # Greetings need no retrieval: call the answer chain directly
        # with an empty document context. It returns the answer string.
        answer = question_answer_chain.invoke({
            "input": query,
            "context": [],  # empty context for greetings
            "chat_history": chat_history,
        })
    else:
        response = chain.invoke({
            "input": query,
            "chat_history": chat_history,
        })
        answer = response['answer']
    # Persist this turn so subsequent calls see it in chat_history.
    memory.save_context({"input": query}, {"output": answer})
    return answer
# Gradio chat UI wired to the RAG assistant. type="messages" keeps the
# history in OpenAI-style role/content dicts.
demo = gr.ChatInterface(
    fn=conversate_assistant,
    type="messages",
)

demo.launch()