SanketAI commited on
Commit
d0840c7
·
verified ·
1 Parent(s): 95964e1

Update agents/qa_agent.py

Browse files
Files changed (1) hide show
  1. agents/qa_agent.py +61 -61
agents/qa_agent.py CHANGED
@@ -1,62 +1,62 @@
1
- import os
2
- import streamlit as st
3
- from agents import SearchAgent
4
- from langchain.vectorstores import FAISS
5
- from langchain_google_genai import GoogleGenerativeAIEmbeddings
6
- from config.config import model
7
-
8
-
9
-
10
-
11
- embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
12
-
13
- class QAAgent:
14
- def __init__(self):
15
-
16
- self.model = model
17
- self.prompt = """You are a research assistant answering questions about academic papers. Use the following context from papers and chat history to provide accurate, specific answers.
18
-
19
- Previous conversation:
20
- {chat_history}
21
-
22
- Paper context:
23
- {context}
24
-
25
- Question: {question}
26
-
27
- Guidelines:
28
- 1. Reference specific papers when making claims
29
- 2. Use direct quotes when relevant
30
- 3. Acknowledge if information isn't available in the provided context
31
- 4. Maintain academic tone and precision
32
- """
33
- self.papers = None
34
- self.search_agent_response = ""
35
-
36
- def solve(self, query):
37
- # Check if search has been performed
38
- if not os.path.exists("vector_db"):
39
- st.warning("No papers loaded. Performing search first...")
40
- search_agent = SearchAgent()
41
- self.search_agent_response , self.papers = search_agent.solve(query)
42
-
43
- # Load vector store
44
- vector_db = FAISS.load_local("vector_db", embeddings, index_name="base_and_adjacent", allow_dangerous_deserialization=True)
45
-
46
- # Get chat history
47
- chat_history = st.session_state.get("chat_history", [])
48
- chat_history_text = "".join([f"{sender}: {msg}" for sender, msg in chat_history[-5:]]) # Last 5 messages
49
-
50
- # Get relevant chunks
51
- retrieved = vector_db.as_retriever().get_relevant_documents(query)
52
- context = "".join([f"{doc.page_content}\n Source: {doc.metadata['source']}" for doc in retrieved])
53
-
54
- # Generate response
55
- full_prompt = self.prompt.format(
56
- chat_history=chat_history_text,
57
- context=context,
58
- question=query
59
- )
60
-
61
- response = self.model.generate_content(str(self.search_agent_response) + full_prompt)
62
  return response.text , self.papers
 
1
+ import os
2
+ import streamlit as st
3
+ from agents import SearchAgent
4
+ from langchain.vectorstores import FAISS
5
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
6
+ from config.config import model
7
+
8
+
9
+
10
+
11
+ embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
12
+
13
+ class QAAgent:
14
+ def __init__(self):
15
+
16
+ self.model = model
17
+ self.prompt = """You are a research assistant answering questions about academic papers. Use the following context from papers and chat history to provide accurate, specific answers.
18
+
19
+ Previous conversation:
20
+ {chat_history}
21
+
22
+ Paper context:
23
+ {context}
24
+
25
+ Question: {question}
26
+
27
+ Guidelines:
28
+ 1. Reference specific papers when making claims
29
+ 2. Use direct quotes when relevant
30
+ 3. Acknowledge if information isn't available in the provided context
31
+ 4. Maintain academic tone and precision
32
+ """
33
+ self.papers = None
34
+ self.search_agent_response = ""
35
+
36
+ def solve(self, query):
37
+ # Check if search has been performed
38
+ if not os.path.exists("vector_db"):
39
+ st.warning("No papers loaded. Performing search first...")
40
+ search_agent = SearchAgent()
41
+ self.search_agent_response , self.papers = search_agent.solve(query)
42
+
43
+ # Load vector store
44
+ vector_db = FAISS.load_local("agents/vector_db", embeddings, index_name="base_and_adjacent", allow_dangerous_deserialization=True)
45
+
46
+ # Get chat history
47
+ chat_history = st.session_state.get("chat_history", [])
48
+ chat_history_text = "".join([f"{sender}: {msg}" for sender, msg in chat_history[-5:]]) # Last 5 messages
49
+
50
+ # Get relevant chunks
51
+ retrieved = vector_db.as_retriever().get_relevant_documents(query)
52
+ context = "".join([f"{doc.page_content}\n Source: {doc.metadata['source']}" for doc in retrieved])
53
+
54
+ # Generate response
55
+ full_prompt = self.prompt.format(
56
+ chat_history=chat_history_text,
57
+ context=context,
58
+ question=query
59
+ )
60
+
61
+ response = self.model.generate_content(str(self.search_agent_response) + full_prompt)
62
  return response.text , self.papers