sid22669 committed on
Commit
ae6a72f
·
verified ·
1 Parent(s): b44fd2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -32
app.py CHANGED
@@ -4,7 +4,6 @@ from langchain.vectorstores import Chroma
4
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
5
  from langchain.memory import ConversationBufferMemory
6
  from langchain.chains import ConversationalRetrievalChain
7
- from langchain.memory.chat_message_histories import ChatMessageHistory
8
  from langchain_openai import ChatOpenAI
9
  from langchain.chains.combine_documents import create_stuff_documents_chain
10
  from langchain.embeddings import HuggingFaceEmbeddings
@@ -14,7 +13,7 @@ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-Mi
14
  persist_directory = 'vec_db'
15
 
16
  vectordb = Chroma(persist_directory=persist_directory,
17
- embedding_function=embedding_model)
18
 
19
  vectordb_retriever = vectordb.as_retriever(search_kwargs={'k':5})
20
 
@@ -23,56 +22,62 @@ llm = ChatOpenAI(model="gpt-4.1-nano", temperature=0.7)
23
  with open("instructions.txt", 'r') as file:
24
  instructions = file.read()
25
 
26
-
27
- # Custom prompt
28
  custom_prompt = ChatPromptTemplate.from_messages([
29
  ("system", instructions),
30
  MessagesPlaceholder(variable_name="chat_history"),
31
  ("user", "Question: {input}\nContext: {context}")
32
  ])
33
 
34
- # Memory
35
- memory = ConversationBufferMemory(
36
- memory_key="chat_history",
37
- return_messages=True
38
- )
39
-
40
  question_answer_chain = create_stuff_documents_chain(llm, custom_prompt)
41
-
42
  chain = create_retrieval_chain(vectordb_retriever, question_answer_chain)
43
 
44
- def conversate_assistant(query, history):
45
- greetings = {"hey", "hi", "hello"}
46
- normalized_query = query.strip().lower()
47
 
48
- if len(memory.load_memory_variables({})["chat_history"]) >=6:
49
- chat_history = memory.load_memory_variables({})["chat_history"][-6::]
50
- else:
51
- chat_history = memory.load_memory_variables({})["chat_history"]
 
 
52
 
53
- # If greeting, skip retrieval and context
54
  if normalized_query in greetings:
55
- response = question_answer_chain.invoke({
56
  "input": query,
57
- "context": [], # empty context for greetings
58
  "chat_history": chat_history
59
  })
60
- answer = response
61
  else:
62
- response = chain.invoke({
63
  "input": query,
64
  "chat_history": chat_history
65
  })
66
- answer = response['answer']
67
-
68
- # Save to memory
69
  memory.save_context({"input": query}, {"output": answer})
70
 
71
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- demo = gr.ChatInterface(
74
- conversate_assistant,
75
- type="messages"
76
- )
77
 
78
- demo.launch()
 
4
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
5
  from langchain.memory import ConversationBufferMemory
6
  from langchain.chains import ConversationalRetrievalChain
 
7
  from langchain_openai import ChatOpenAI
8
  from langchain.chains.combine_documents import create_stuff_documents_chain
9
  from langchain.embeddings import HuggingFaceEmbeddings
 
13
  persist_directory = 'vec_db'
14
 
15
  vectordb = Chroma(persist_directory=persist_directory,
16
+ embedding_function=embedding_model)
17
 
18
  vectordb_retriever = vectordb.as_retriever(search_kwargs={'k':5})
19
 
 
22
  with open("instructions.txt", 'r') as file:
23
  instructions = file.read()
24
 
 
 
25
  custom_prompt = ChatPromptTemplate.from_messages([
26
  ("system", instructions),
27
  MessagesPlaceholder(variable_name="chat_history"),
28
  ("user", "Question: {input}\nContext: {context}")
29
  ])
30
 
 
 
 
 
 
 
31
  question_answer_chain = create_stuff_documents_chain(llm, custom_prompt)
 
32
  chain = create_retrieval_chain(vectordb_retriever, question_answer_chain)
33
 
34
# Queries that should be answered without document retrieval.
greetings = {"hey", "hi", "hello"}


def conversate_assistant(query, history, memory):
    """Answer one user turn with the RAG chain, using per-session memory.

    Args:
        query: The user's message.
        history: List of (user, assistant) tuples shown in the UI.
        memory: This session's ConversationBufferMemory.

    Returns:
        Tuple of (answer string, updated history list, memory).
    """
    # Keep at most the 6 most recent messages so the prompt stays small.
    # A plain [-6:] slice already handles histories shorter than 6.
    history_data = memory.load_memory_variables({})["chat_history"]
    chat_history = history_data[-6:]

    normalized_query = query.strip().lower()

    if normalized_query in greetings:
        # Greetings skip retrieval: invoke the answer chain with empty context.
        # LCEL Runnables must be called via .invoke(), not as plain callables,
        # and create_stuff_documents_chain returns the answer string directly.
        answer = question_answer_chain.invoke({
            "input": query,
            "context": [],
            "chat_history": chat_history
        })
    else:
        # Full retrieval path: the retrieval chain returns a dict whose
        # 'answer' key holds the generated response.
        response = chain.invoke({
            "input": query,
            "chat_history": chat_history
        })
        answer = response['answer']

    # Save context in this session's memory
    memory.save_context({"input": query}, {"output": answer})

    return answer, history + [(query, answer)], memory
61
+
62
# Gradio interface with per-session state for the visible chat history
# and the ConversationBufferMemory.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    state = gr.State([])           # chat history shown in the UI
    memory_state = gr.State(None)  # ConversationBufferMemory per session

    def init_memory():
        """Create a fresh conversation memory for a new session."""
        return ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True
        )

    def respond(user_message, chat_history, memory):
        """Handle one submitted message; lazily create memory on first use."""
        if memory is None:
            memory = init_memory()
        answer, chat_history, memory = conversate_assistant(
            user_message, chat_history, memory
        )
        # Write the updated history back to BOTH the chatbot display and
        # the `state` holder — otherwise `state` stays [] and each turn
        # would show only the latest exchange. Clear the textbox last.
        return chat_history, chat_history, memory, ""

    msg.submit(
        respond,
        inputs=[msg, state, memory_state],
        outputs=[chatbot, state, memory_state, msg],
    )

demo.launch()