Spaces:
Runtime error
Runtime error
Update ragchain.py
Browse files- ragchain.py +55 -3
ragchain.py
CHANGED
|
@@ -1,13 +1,65 @@
|
|
| 1 |
-
class
|
| 2 |
|
| 3 |
def __init__(self, llm, vector_store):
|
| 4 |
"""
|
| 5 |
-
Initialize the RAGChain with an LLM instance and a
|
| 6 |
"""
|
| 7 |
self.llm = llm
|
| 8 |
self.vector_store = vector_store
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def rewrite_query(self, query):
|
| 12 |
"""
|
| 13 |
Rewrite the user's query to align with the language and structure of the library's methods and documentation.
|
|
|
|
| 1 |
+
class RAGBot:
|
| 2 |
|
| 3 |
def __init__(self, llm, vector_store):
|
| 4 |
"""
|
| 5 |
+
Initialize the RAGChain with an LLM instance, a vector store, and a conversation history.
|
| 6 |
"""
|
| 7 |
self.llm = llm
|
| 8 |
self.vector_store = vector_store
|
| 9 |
+
self.conversation = []
|
| 10 |
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def rag_chain(self, query):
|
| 14 |
+
"""
|
| 15 |
+
Process a user query, handle history, retrieve contexts, and generate a response.
|
| 16 |
+
"""
|
| 17 |
+
# Add the user query to the conversation history
|
| 18 |
+
self.add_to_conversation(user_query=query)
|
| 19 |
+
|
| 20 |
+
# Rewrite query
|
| 21 |
+
rewritten_query = self.rewrite_query(query)
|
| 22 |
+
|
| 23 |
+
# Predict library usage
|
| 24 |
+
code_library_usage_prediction = self.predict_library_usage(query)
|
| 25 |
+
|
| 26 |
+
# Retrieve contexts
|
| 27 |
+
doc_contexts = self.retrieve_contexts(query, k=5, filter={"usage": "doc"})
|
| 28 |
+
code_contexts = self.retrieve_contexts(rewritten_query, k=3, filter={"usage": code_library_usage_prediction})
|
| 29 |
+
|
| 30 |
+
# Format contexts
|
| 31 |
+
formatted_doc_contexts = self.format_documents(doc_contexts)
|
| 32 |
+
formatted_code_contexts = self.format_documents(code_contexts)
|
| 33 |
+
|
| 34 |
+
# Generate response
|
| 35 |
+
response = self.generate_response(query, formatted_doc_contexts, formatted_code_contexts)
|
| 36 |
+
|
| 37 |
+
# Add the response to the existing query in the conversation history
|
| 38 |
+
self.add_to_conversation(llm_response=response)
|
| 39 |
+
|
| 40 |
+
return response
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def add_to_conversation(self, user_query=None, llm_response=None):
|
| 44 |
+
"""
|
| 45 |
+
Add either the user's query, the LLM's response, or both to the conversation history.
|
| 46 |
+
"""
|
| 47 |
+
if user_query and llm_response:
|
| 48 |
+
# Add a full query-response pair
|
| 49 |
+
self.conversation.append({"query": user_query, "response": llm_response})
|
| 50 |
+
elif user_query:
|
| 51 |
+
# Add a query with no response yet
|
| 52 |
+
self.conversation.append({"query": user_query, "response": None})
|
| 53 |
+
elif llm_response and self.conversation:
|
| 54 |
+
# Add a response to the most recent query
|
| 55 |
+
self.conversation[-1]["response"] = llm_response
|
| 56 |
+
|
| 57 |
+
def get_history(self):
|
| 58 |
+
"""
|
| 59 |
+
Retrieve the entire conversation history.
|
| 60 |
+
"""
|
| 61 |
+
return self.conversation
|
| 62 |
+
|
| 63 |
def rewrite_query(self, query):
|
| 64 |
"""
|
| 65 |
Rewrite the user's query to align with the language and structure of the library's methods and documentation.
|