Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,7 +15,7 @@ from qdrant_client.http.models import Distance, VectorParams
|
|
| 15 |
from qdrant_client.models import PointIdsList
|
| 16 |
|
| 17 |
from langgraph.graph import MessagesState, StateGraph
|
| 18 |
-
from langchain_core.messages import SystemMessage, HumanMessage
|
| 19 |
from langgraph.prebuilt import ToolNode
|
| 20 |
from langgraph.graph import END
|
| 21 |
from langgraph.prebuilt import tools_condition
|
|
@@ -114,52 +114,76 @@ class QASystem:
|
|
| 114 |
|
| 115 |
graph_builder = StateGraph(MessagesState)
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
if
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
|
|
|
| 139 |
def generate(state: MessagesState):
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
system_prompt = (
|
| 145 |
"You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
|
| 146 |
-
"
|
| 147 |
-
|
|
|
|
| 148 |
)
|
| 149 |
|
| 150 |
-
messages
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
response = llm.invoke(messages)
|
| 155 |
-
return {"messages": [response]}
|
| 156 |
-
|
| 157 |
|
| 158 |
-
|
|
|
|
| 159 |
graph_builder.add_node("generate", generate)
|
| 160 |
|
| 161 |
-
|
| 162 |
-
graph_builder.
|
|
|
|
| 163 |
graph_builder.add_edge("generate", END)
|
| 164 |
|
| 165 |
self.memory = MemorySaver()
|
|
@@ -173,16 +197,25 @@ class QASystem:
|
|
| 173 |
def process_query(self, query: str) -> List[Dict[str, str]]:
|
| 174 |
try:
|
| 175 |
responses = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
for step in self.graph.stream(
|
| 177 |
{"messages": [HumanMessage(content=query)]},
|
| 178 |
stream_mode="values",
|
| 179 |
-
config={"configurable": {"thread_id":
|
| 180 |
):
|
| 181 |
if step["messages"]:
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
return responses
|
| 187 |
except Exception as e:
|
| 188 |
logger.error(f"Query processing error: {str(e)}")
|
|
@@ -197,4 +230,4 @@ else:
|
|
| 197 |
@app.post("/query")
|
| 198 |
async def query_api(query: str):
|
| 199 |
responses = qa_system.process_query(query)
|
| 200 |
-
return {"responses": responses}
|
|
|
|
| 15 |
from qdrant_client.models import PointIdsList
|
| 16 |
|
| 17 |
from langgraph.graph import MessagesState, StateGraph
|
| 18 |
+
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
| 19 |
from langgraph.prebuilt import ToolNode
|
| 20 |
from langgraph.graph import END
|
| 21 |
from langgraph.prebuilt import tools_condition
|
|
|
|
| 114 |
|
| 115 |
graph_builder = StateGraph(MessagesState)
|
| 116 |
|
| 117 |
+
# Define a retrieval node that fetches relevant docs
|
| 118 |
+
def retrieve_docs(state: MessagesState):
|
| 119 |
+
# Get the most recent human message
|
| 120 |
+
human_messages = [m for m in state["messages"] if m.type == "human"]
|
| 121 |
+
if not human_messages:
|
| 122 |
+
return {"messages": state["messages"]}
|
| 123 |
+
|
| 124 |
+
user_query = human_messages[-1].content
|
| 125 |
+
logger.info(f"Retrieving documents for query: {user_query}")
|
| 126 |
+
|
| 127 |
+
# Query the vector store
|
| 128 |
+
try:
|
| 129 |
+
retrieved_docs = self.vector_store.similarity_search(user_query, k=3)
|
| 130 |
+
|
| 131 |
+
# Create tool messages for each retrieved document
|
| 132 |
+
tool_messages = []
|
| 133 |
+
for i, doc in enumerate(retrieved_docs):
|
| 134 |
+
tool_messages.append(
|
| 135 |
+
ToolMessage(
|
| 136 |
+
content=f"Document {i+1}: {doc.page_content}",
|
| 137 |
+
tool_call_id=f"retrieval_{i}"
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
logger.info(f"Retrieved {len(tool_messages)} relevant documents")
|
| 142 |
+
return {"messages": state["messages"] + tool_messages}
|
| 143 |
+
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.error(f"Error retrieving documents: {str(e)}")
|
| 146 |
+
return {"messages": state["messages"]}
|
| 147 |
|
| 148 |
+
# Updated generate function that uses retrieved documents
|
| 149 |
def generate(state: MessagesState):
|
| 150 |
+
# Extract retrieved documents (tool messages)
|
| 151 |
+
tool_messages = [m for m in state["messages"] if m.type == "tool"]
|
| 152 |
+
|
| 153 |
+
# Collect context from retrieved documents
|
| 154 |
+
if tool_messages:
|
| 155 |
+
context = "\n".join([m.content for m in tool_messages])
|
| 156 |
+
logger.info(f"Using context from {len(tool_messages)} retrieved documents")
|
| 157 |
+
else:
|
| 158 |
+
context = "No specific mountain bicycle documentation available."
|
| 159 |
+
logger.info("No relevant documents retrieved, using default context")
|
| 160 |
|
| 161 |
system_prompt = (
|
| 162 |
"You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
|
| 163 |
+
"Always provide accurate responses with references to provided data. "
|
| 164 |
+
"If the user query is not technical-specific, still respond from a IETM perspective."
|
| 165 |
+
f"\n\nContext from mountain bicycle documentation:\n{context}"
|
| 166 |
)
|
| 167 |
|
| 168 |
+
# Get all messages excluding tool messages to avoid redundancy
|
| 169 |
+
human_and_ai_messages = [m for m in state["messages"] if m.type != "tool"]
|
| 170 |
+
|
| 171 |
+
# Create the full message history for the LLM
|
| 172 |
+
messages = [SystemMessage(content=system_prompt)] + human_and_ai_messages
|
| 173 |
+
|
| 174 |
+
logger.info(f"Sending query to LLM with {len(messages)} messages")
|
| 175 |
+
|
| 176 |
+
# Generate the response
|
| 177 |
response = llm.invoke(messages)
|
| 178 |
+
return {"messages": state["messages"] + [response]}
|
|
|
|
| 179 |
|
| 180 |
+
# Add nodes to the graph
|
| 181 |
+
graph_builder.add_node("retrieve_docs", retrieve_docs)
|
| 182 |
graph_builder.add_node("generate", generate)
|
| 183 |
|
| 184 |
+
# Set the flow of the graph
|
| 185 |
+
graph_builder.set_entry_point("retrieve_docs")
|
| 186 |
+
graph_builder.add_edge("retrieve_docs", "generate")
|
| 187 |
graph_builder.add_edge("generate", END)
|
| 188 |
|
| 189 |
self.memory = MemorySaver()
|
|
|
|
| 197 |
def process_query(self, query: str) -> List[Dict[str, str]]:
|
| 198 |
try:
|
| 199 |
responses = []
|
| 200 |
+
|
| 201 |
+
# Use a unique thread_id for each conversation
|
| 202 |
+
thread_id = "abc123" # In production, generate a unique ID for each conversation
|
| 203 |
+
|
| 204 |
+
# Stream the responses
|
| 205 |
for step in self.graph.stream(
|
| 206 |
{"messages": [HumanMessage(content=query)]},
|
| 207 |
stream_mode="values",
|
| 208 |
+
config={"configurable": {"thread_id": thread_id}}
|
| 209 |
):
|
| 210 |
if step["messages"]:
|
| 211 |
+
# Only include AI messages in the response
|
| 212 |
+
ai_messages = [m for m in step["messages"] if m.type == "ai"]
|
| 213 |
+
if ai_messages:
|
| 214 |
+
responses.append({
|
| 215 |
+
'content': ai_messages[-1].content,
|
| 216 |
+
'type': ai_messages[-1].type
|
| 217 |
+
})
|
| 218 |
+
|
| 219 |
return responses
|
| 220 |
except Exception as e:
|
| 221 |
logger.error(f"Query processing error: {str(e)}")
|
|
|
|
| 230 |
@app.post("/query")
|
| 231 |
async def query_api(query: str):
|
| 232 |
responses = qa_system.process_query(query)
|
| 233 |
+
return {"responses": responses}
|