sid22669 committed on
Commit
f026756
·
verified ·
1 Parent(s): ae6a72f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -43
app.py CHANGED
@@ -4,80 +4,75 @@ from langchain.vectorstores import Chroma
4
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
5
  from langchain.memory import ConversationBufferMemory
6
  from langchain.chains import ConversationalRetrievalChain
 
7
  from langchain_openai import ChatOpenAI
8
  from langchain.chains.combine_documents import create_stuff_documents_chain
9
  from langchain.embeddings import HuggingFaceEmbeddings
10
 
11
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
12
 
13
- persist_directory = 'vec_db'
14
 
15
- vectordb = Chroma(persist_directory=persist_directory,
16
- embedding_function=embedding_model)
17
 
18
- vectordb_retriever = vectordb.as_retriever(search_kwargs={'k':5})
19
 
20
- llm = ChatOpenAI(model="gpt-4.1-nano", temperature=0.7)
21
 
22
- with open("instructions.txt", 'r') as file:
23
  instructions = file.read()
24
 
 
 
25
  custom_prompt = ChatPromptTemplate.from_messages([
26
  ("system", instructions),
27
  MessagesPlaceholder(variable_name="chat_history"),
28
  ("user", "Question: {input}\nContext: {context}")
29
  ])
30
 
31
- question_answer_chain = create_stuff_documents_chain(llm, custom_prompt)
32
- chain = create_retrieval_chain(vectordb_retriever, question_answer_chain)
 
 
 
33
 
34
- greetings = {"hey", "hi", "hello"}
35
 
36
- def conversate_assistant(query, history, memory):
37
- # Load limited chat history from memory
38
- history_data = memory.load_memory_variables({})["chat_history"]
39
- chat_history = history_data[-6:] if len(history_data) >= 6 else history_data
40
 
 
 
41
  normalized_query = query.strip().lower()
42
 
 
 
 
 
 
 
43
  if normalized_query in greetings:
44
- response = question_answer_chain({
45
  "input": query,
46
- "context": [],
47
  "chat_history": chat_history
48
  })
49
- answer = response.get("output") or str(response)
50
  else:
51
- response = chain({
52
  "input": query,
53
  "chat_history": chat_history
54
  })
55
- answer = response.get('answer') or str(response)
56
-
57
- # Save context in this session's memory
58
  memory.save_context({"input": query}, {"output": answer})
59
 
60
- return answer, history + [(query, answer)], memory
61
-
62
- # Gradio interface with state for memory and chat history
63
- with gr.Blocks() as demo:
64
- chatbot = gr.Chatbot()
65
- msg = gr.Textbox()
66
- state = gr.State([]) # to keep chat history visible in UI
67
- memory_state = gr.State(None) # to keep ConversationBufferMemory per session
68
-
69
- def init_memory():
70
- return ConversationBufferMemory(
71
- memory_key="chat_history",
72
- return_messages=True
73
- )
74
-
75
- def respond(user_message, chat_history, memory):
76
- if memory is None:
77
- memory = init_memory()
78
- answer, chat_history, memory = conversate_assistant(user_message, chat_history, memory)
79
- return chat_history, memory, ""
80
 
81
- msg.submit(respond, inputs=[msg, state, memory_state], outputs=[chatbot, memory_state, msg])
 
 
 
82
 
83
- demo.launch()
 
4
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
5
  from langchain.memory import ConversationBufferMemory
6
  from langchain.chains import ConversationalRetrievalChain
7
+ from langchain.memory.chat_message_histories import ChatMessageHistory
8
  from langchain_openai import ChatOpenAI
9
  from langchain.chains.combine_documents import create_stuff_documents_chain
10
  from langchain.embeddings import HuggingFaceEmbeddings
11
 
12
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
13
 
14
+ persist_directory = 'vec_db'
15
 
16
+ vectordb = Chroma(persist_directory=persist_directory,
17
+ embedding_function=embedding_model)
18
 
19
+ vectordb_retriever = vectordb.as_retriever(search_kwargs={'k':5})
20
 
21
+ llm = ChatOpenAI(model="gpt-4.1-nano", temperature=0.7)
22
 
23
+ with open("instructions.txt", 'r') as file:
24
  instructions = file.read()
25
 
26
+
27
+ # Custom prompt
28
  custom_prompt = ChatPromptTemplate.from_messages([
29
  ("system", instructions),
30
  MessagesPlaceholder(variable_name="chat_history"),
31
  ("user", "Question: {input}\nContext: {context}")
32
  ])
33
 
34
+ # Memory
35
+ memory = ConversationBufferMemory(
36
+ memory_key="chat_history",
37
+ return_messages=True
38
+ )
39
 
40
+ question_answer_chain = create_stuff_documents_chain(llm, custom_prompt)
41
 
42
+ chain = create_retrieval_chain(vectordb_retriever, question_answer_chain)
 
 
 
43
 
44
+ def conversate_assistant(query, history):
45
+ greetings = {"hey", "hi", "hello"}
46
  normalized_query = query.strip().lower()
47
 
48
+ if len(memory.load_memory_variables({})["chat_history"]) >=6:
49
+ chat_history = memory.load_memory_variables({})["chat_history"][-6::]
50
+ else:
51
+ chat_history = memory.load_memory_variables({})["chat_history"]
52
+
53
+ # If greeting, skip retrieval and context
54
  if normalized_query in greetings:
55
+ response = question_answer_chain.invoke({
56
  "input": query,
57
+ "context": [], # empty context for greetings
58
  "chat_history": chat_history
59
  })
60
+ answer = response
61
  else:
62
+ response = chain.invoke({
63
  "input": query,
64
  "chat_history": chat_history
65
  })
66
+ answer = response['answer']
67
+
68
+ # Save to memory
69
  memory.save_context({"input": query}, {"output": answer})
70
 
71
+ return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ demo = gr.ChatInterface(
74
+ conversate_assistant,
75
+ type="messages"
76
+ )
77
 
78
+ demo.launch()