Spaces:
Sleeping
Sleeping
Commit ·
9e67c92
1
Parent(s): 975d8af
enhanced_query
Browse files- generate.py +0 -8
- langgraph_agent.py +22 -16
- requirements.txt +3 -0
generate.py
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
from google import genai
|
| 2 |
-
from dotenv import load_dotenv
|
| 3 |
-
from os import getenv
|
| 4 |
-
|
| 5 |
-
load_dotenv()
|
| 6 |
-
|
| 7 |
-
GEMINI_API_KEY = getenv("GEMINI_API_KEY")
|
| 8 |
-
from .constants import GEMINI_API_KEY
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
langgraph_agent.py
CHANGED
|
@@ -12,21 +12,6 @@ class AgentState(TypedDict):
|
|
| 12 |
context: List[str]
|
| 13 |
response: str
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
# def query_classifier(state: AgentState) -> AgentState:
|
| 18 |
-
# """Determine if the query requires RAG retrieval based on keywords.
|
| 19 |
-
# Is not continued anymore, will be removed in future."""
|
| 20 |
-
# query_lower = state["query"].lower()
|
| 21 |
-
# rag_keywords = [
|
| 22 |
-
# "scheme", "schemes", "program", "programs", "policy", "policies",
|
| 23 |
-
# "public health engineering", "phe", "public health", "government",
|
| 24 |
-
# "benefit", "financial", "assistance", "aid", "initiative", "yojana",
|
| 25 |
-
# ]
|
| 26 |
-
|
| 27 |
-
# state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
|
| 28 |
-
# return state
|
| 29 |
-
|
| 30 |
def query_classifier(state: AgentState) -> AgentState:
|
| 31 |
"""Updated classifier to use LLM for intent classification."""
|
| 32 |
|
|
@@ -44,6 +29,25 @@ def query_classifier(state: AgentState) -> AgentState:
|
|
| 44 |
state["requires_rag"] = "yes" in result.lower()
|
| 45 |
return state
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
def retrieve_documents(state: AgentState) -> AgentState:
|
| 48 |
"""Retrieve documents from vector store if needed."""
|
| 49 |
if state["requires_rag"]:
|
|
@@ -121,17 +125,19 @@ def create_agent_workflow():
|
|
| 121 |
workflow = StateGraph(AgentState)
|
| 122 |
|
| 123 |
# Add nodes
|
|
|
|
| 124 |
workflow.add_node("classifier", query_classifier)
|
| 125 |
workflow.add_node("retriever", retrieve_documents)
|
| 126 |
workflow.add_node("responder", generate_response)
|
| 127 |
|
| 128 |
# Create edges
|
|
|
|
| 129 |
workflow.add_edge("classifier", "retriever")
|
| 130 |
workflow.add_edge("retriever", "responder")
|
| 131 |
workflow.add_edge("responder", END)
|
| 132 |
|
| 133 |
# Set the entry point
|
| 134 |
-
workflow.set_entry_point("
|
| 135 |
|
| 136 |
# Compile the graph
|
| 137 |
return workflow.compile()
|
|
|
|
| 12 |
context: List[str]
|
| 13 |
response: str
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def query_classifier(state: AgentState) -> AgentState:
|
| 16 |
"""Updated classifier to use LLM for intent classification."""
|
| 17 |
|
|
|
|
| 29 |
state["requires_rag"] = "yes" in result.lower()
|
| 30 |
return state
|
| 31 |
|
| 32 |
+
def enhance_query(state:AgentState) -> AgentState:
|
| 33 |
+
"""Enhance the query with user data and context."""
|
| 34 |
+
previous_conversation = state.get("previous_conversation", "")
|
| 35 |
+
user_data = state.get("user_data", {})
|
| 36 |
+
query = state.get("query", "")
|
| 37 |
+
|
| 38 |
+
query_enhancement_prompt = f"""
|
| 39 |
+
Enhance the following query with user data and previous conversation context so it uses the previous conversation and user data.
|
| 40 |
+
To be used for generating a more relevant and personalized response.
|
| 41 |
+
Previous Conversation: {previous_conversation}
|
| 42 |
+
User Data: {user_data}
|
| 43 |
+
Current Query: {query}
|
| 44 |
+
Only write the enhanced query. No other text."""
|
| 45 |
+
result = llm.predict(query_enhancement_prompt)
|
| 46 |
+
print("Enhanced query: ", result)
|
| 47 |
+
state["query"] = result
|
| 48 |
+
|
| 49 |
+
return state
|
| 50 |
+
|
| 51 |
def retrieve_documents(state: AgentState) -> AgentState:
|
| 52 |
"""Retrieve documents from vector store if needed."""
|
| 53 |
if state["requires_rag"]:
|
|
|
|
| 125 |
workflow = StateGraph(AgentState)
|
| 126 |
|
| 127 |
# Add nodes
|
| 128 |
+
workflow.add_node("enhance_query", enhance_query)
|
| 129 |
workflow.add_node("classifier", query_classifier)
|
| 130 |
workflow.add_node("retriever", retrieve_documents)
|
| 131 |
workflow.add_node("responder", generate_response)
|
| 132 |
|
| 133 |
# Create edges
|
| 134 |
+
workflow.add_edge("enhance_query", "classifier")
|
| 135 |
workflow.add_edge("classifier", "retriever")
|
| 136 |
workflow.add_edge("retriever", "responder")
|
| 137 |
workflow.add_edge("responder", END)
|
| 138 |
|
| 139 |
# Set the entry point
|
| 140 |
+
workflow.set_entry_point("enhance_query")
|
| 141 |
|
| 142 |
# Compile the graph
|
| 143 |
return workflow.compile()
|
requirements.txt
CHANGED
|
@@ -273,3 +273,6 @@ xxhash==3.5.0
|
|
| 273 |
yarl==1.18.3
|
| 274 |
zipp==3.21.0
|
| 275 |
zstandard==0.23.0
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
yarl==1.18.3
|
| 274 |
zipp==3.21.0
|
| 275 |
zstandard==0.23.0
|
| 276 |
+
tf-keras
|
| 277 |
+
# pathway[all]
|
| 278 |
+
# langchain-google-genai
|