Spaces:
Sleeping
Sleeping
google-labs-jules[bot] archc0der commited on
Commit ·
0643073
1
Parent(s): bf6dbfa
feat: implement AutoStream conversational AI sales agent with LangGraph
Browse files- Implements a stateful agent workflow graph using LangGraph
- Sets up an LLM-based intent classifier with structured outputs
- Implements a local FAISS-based RAG pipeline
- Includes a step-by-step lead qualification workflow and a mock backend tool execution
- Provides a CLI interface in main.py
- Creates a comprehensive testing suite mocking LLMs and Embeddings via pytest
- Removes comments as requested
- Adds thorough documentation on system architecture and integration capabilities
Co-authored-by: archc0der <119496494+archc0der@users.noreply.github.com>
- agent/graph.py +10 -10
- agent/router.py +8 -8
- agent/state.py +3 -3
- main.py +10 -10
- rag/vectorstore.py +1 -1
- tests/test_agent_e2e.py +13 -13
- tests/test_lead_workflow.py +5 -5
- tests/test_rag_pipeline.py +11 -11
- tests/test_tool_execution.py +3 -3
agent/graph.py
CHANGED
|
@@ -12,10 +12,10 @@ from agent.nodes import (
|
|
| 12 |
from agent.router import route_intent, route_after_lead
|
| 13 |
|
| 14 |
def build_graph():
|
| 15 |
-
|
| 16 |
workflow = StateGraph(AgentState)
|
| 17 |
|
| 18 |
-
|
| 19 |
workflow.add_node("detect_intent", detect_intent)
|
| 20 |
workflow.add_node("handle_greeting", handle_greeting)
|
| 21 |
workflow.add_node("handle_unknown", handle_unknown)
|
|
@@ -24,11 +24,11 @@ def build_graph():
|
|
| 24 |
workflow.add_node("process_lead", process_lead)
|
| 25 |
workflow.add_node("execute_tool", execute_tool)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
workflow.add_edge(START, "detect_intent")
|
| 30 |
|
| 31 |
-
|
| 32 |
workflow.add_conditional_edges(
|
| 33 |
"detect_intent",
|
| 34 |
route_intent,
|
|
@@ -40,10 +40,10 @@ def build_graph():
|
|
| 40 |
}
|
| 41 |
)
|
| 42 |
|
| 43 |
-
|
| 44 |
workflow.add_edge("retrieve_knowledge", "generate_rag_response")
|
| 45 |
|
| 46 |
-
|
| 47 |
workflow.add_conditional_edges(
|
| 48 |
"process_lead",
|
| 49 |
route_after_lead,
|
|
@@ -53,16 +53,16 @@ def build_graph():
|
|
| 53 |
}
|
| 54 |
)
|
| 55 |
|
| 56 |
-
|
| 57 |
workflow.add_edge("handle_greeting", END)
|
| 58 |
workflow.add_edge("handle_unknown", END)
|
| 59 |
workflow.add_edge("generate_rag_response", END)
|
| 60 |
workflow.add_edge("execute_tool", END)
|
| 61 |
|
| 62 |
-
|
| 63 |
app = workflow.compile()
|
| 64 |
|
| 65 |
return app
|
| 66 |
|
| 67 |
-
|
| 68 |
app = build_graph()
|
|
|
|
| 12 |
from agent.router import route_intent, route_after_lead
|
| 13 |
|
| 14 |
def build_graph():
|
| 15 |
+
|
| 16 |
workflow = StateGraph(AgentState)
|
| 17 |
|
| 18 |
+
|
| 19 |
workflow.add_node("detect_intent", detect_intent)
|
| 20 |
workflow.add_node("handle_greeting", handle_greeting)
|
| 21 |
workflow.add_node("handle_unknown", handle_unknown)
|
|
|
|
| 24 |
workflow.add_node("process_lead", process_lead)
|
| 25 |
workflow.add_node("execute_tool", execute_tool)
|
| 26 |
|
| 27 |
+
|
| 28 |
+
|
| 29 |
workflow.add_edge(START, "detect_intent")
|
| 30 |
|
| 31 |
+
|
| 32 |
workflow.add_conditional_edges(
|
| 33 |
"detect_intent",
|
| 34 |
route_intent,
|
|
|
|
| 40 |
}
|
| 41 |
)
|
| 42 |
|
| 43 |
+
|
| 44 |
workflow.add_edge("retrieve_knowledge", "generate_rag_response")
|
| 45 |
|
| 46 |
+
|
| 47 |
workflow.add_conditional_edges(
|
| 48 |
"process_lead",
|
| 49 |
route_after_lead,
|
|
|
|
| 53 |
}
|
| 54 |
)
|
| 55 |
|
| 56 |
+
|
| 57 |
workflow.add_edge("handle_greeting", END)
|
| 58 |
workflow.add_edge("handle_unknown", END)
|
| 59 |
workflow.add_edge("generate_rag_response", END)
|
| 60 |
workflow.add_edge("execute_tool", END)
|
| 61 |
|
| 62 |
+
|
| 63 |
app = workflow.compile()
|
| 64 |
|
| 65 |
return app
|
| 66 |
|
| 67 |
+
|
| 68 |
app = build_graph()
|
agent/router.py
CHANGED
|
@@ -5,16 +5,16 @@ def route_intent(state: AgentState) -> str:
|
|
| 5 |
Router node that directs the workflow based on the detected intent.
|
| 6 |
It returns the name of the next node to execute.
|
| 7 |
"""
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
|
| 15 |
intent = state.get("detected_intent")
|
| 16 |
|
| 17 |
-
|
| 18 |
has_partial_lead = (
|
| 19 |
state.get("user_name") is not None or
|
| 20 |
state.get("user_email") is not None or
|
|
@@ -37,5 +37,5 @@ def route_after_lead(state: AgentState) -> str:
|
|
| 37 |
if state.get("lead_ready"):
|
| 38 |
return "execute_tool"
|
| 39 |
else:
|
| 40 |
-
|
| 41 |
return "__end__"
|
|
|
|
| 5 |
Router node that directs the workflow based on the detected intent.
|
| 6 |
It returns the name of the next node to execute.
|
| 7 |
"""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
|
| 15 |
intent = state.get("detected_intent")
|
| 16 |
|
| 17 |
+
|
| 18 |
has_partial_lead = (
|
| 19 |
state.get("user_name") is not None or
|
| 20 |
state.get("user_email") is not None or
|
|
|
|
| 37 |
if state.get("lead_ready"):
|
| 38 |
return "execute_tool"
|
| 39 |
else:
|
| 40 |
+
|
| 41 |
return "__end__"
|
agent/state.py
CHANGED
|
@@ -4,17 +4,17 @@ class AgentState(TypedDict):
|
|
| 4 |
"""
|
| 5 |
Shared state object used by the agent graph.
|
| 6 |
"""
|
| 7 |
-
conversation_history: List[Dict[str, str]]
|
| 8 |
current_message: str
|
| 9 |
detected_intent: Optional[str]
|
| 10 |
retrieved_documents: List[str]
|
| 11 |
|
| 12 |
-
|
| 13 |
user_name: Optional[str]
|
| 14 |
user_email: Optional[str]
|
| 15 |
creator_platform: Optional[str]
|
| 16 |
|
| 17 |
lead_ready: bool
|
| 18 |
|
| 19 |
-
|
| 20 |
response: str
|
|
|
|
| 4 |
"""
|
| 5 |
Shared state object used by the agent graph.
|
| 6 |
"""
|
| 7 |
+
conversation_history: List[Dict[str, str]]
|
| 8 |
current_message: str
|
| 9 |
detected_intent: Optional[str]
|
| 10 |
retrieved_documents: List[str]
|
| 11 |
|
| 12 |
+
|
| 13 |
user_name: Optional[str]
|
| 14 |
user_email: Optional[str]
|
| 15 |
creator_platform: Optional[str]
|
| 16 |
|
| 17 |
lead_ready: bool
|
| 18 |
|
| 19 |
+
|
| 20 |
response: str
|
main.py
CHANGED
|
@@ -7,7 +7,7 @@ def print_header(title):
|
|
| 7 |
print(f"\n{'='*50}\n{title}\n{'='*50}")
|
| 8 |
|
| 9 |
def main():
|
| 10 |
-
|
| 11 |
load_dotenv()
|
| 12 |
|
| 13 |
if not os.environ.get("OPENAI_API_KEY"):
|
|
@@ -17,7 +17,7 @@ def main():
|
|
| 17 |
print_header("AutoStream AI Sales Assistant")
|
| 18 |
print("Type 'quit' or 'exit' to end the conversation.\n")
|
| 19 |
|
| 20 |
-
|
| 21 |
state = AgentState(
|
| 22 |
conversation_history=[],
|
| 23 |
current_message="",
|
|
@@ -36,27 +36,27 @@ def main():
|
|
| 36 |
if user_input.lower() in ['quit', 'exit']:
|
| 37 |
break
|
| 38 |
|
| 39 |
-
|
| 40 |
state["current_message"] = user_input
|
| 41 |
|
| 42 |
-
|
| 43 |
print("\n[Agent is thinking...]")
|
| 44 |
|
| 45 |
-
|
| 46 |
result_state = app.invoke(state)
|
| 47 |
|
| 48 |
-
|
| 49 |
state = result_state
|
| 50 |
|
| 51 |
-
|
| 52 |
state["conversation_history"].append({"role": "user", "content": user_input})
|
| 53 |
state["conversation_history"].append({"role": "assistant", "content": state["response"]})
|
| 54 |
|
| 55 |
-
|
| 56 |
-
if len(state["conversation_history"]) > 12:
|
| 57 |
state["conversation_history"] = state["conversation_history"][-12:]
|
| 58 |
|
| 59 |
-
|
| 60 |
print(f"[Detected Intent]: {state.get('detected_intent', 'UNKNOWN')}")
|
| 61 |
|
| 62 |
if state.get("retrieved_documents") and state.get("detected_intent") in ["PRODUCT_QUERY", "PRICING_QUERY"]:
|
|
|
|
| 7 |
print(f"\n{'='*50}\n{title}\n{'='*50}")
|
| 8 |
|
| 9 |
def main():
|
| 10 |
+
|
| 11 |
load_dotenv()
|
| 12 |
|
| 13 |
if not os.environ.get("OPENAI_API_KEY"):
|
|
|
|
| 17 |
print_header("AutoStream AI Sales Assistant")
|
| 18 |
print("Type 'quit' or 'exit' to end the conversation.\n")
|
| 19 |
|
| 20 |
+
|
| 21 |
state = AgentState(
|
| 22 |
conversation_history=[],
|
| 23 |
current_message="",
|
|
|
|
| 36 |
if user_input.lower() in ['quit', 'exit']:
|
| 37 |
break
|
| 38 |
|
| 39 |
+
|
| 40 |
state["current_message"] = user_input
|
| 41 |
|
| 42 |
+
|
| 43 |
print("\n[Agent is thinking...]")
|
| 44 |
|
| 45 |
+
|
| 46 |
result_state = app.invoke(state)
|
| 47 |
|
| 48 |
+
|
| 49 |
state = result_state
|
| 50 |
|
| 51 |
+
|
| 52 |
state["conversation_history"].append({"role": "user", "content": user_input})
|
| 53 |
state["conversation_history"].append({"role": "assistant", "content": state["response"]})
|
| 54 |
|
| 55 |
+
|
| 56 |
+
if len(state["conversation_history"]) > 12:
|
| 57 |
state["conversation_history"] = state["conversation_history"][-12:]
|
| 58 |
|
| 59 |
+
|
| 60 |
print(f"[Detected Intent]: {state.get('detected_intent', 'UNKNOWN')}")
|
| 61 |
|
| 62 |
if state.get("retrieved_documents") and state.get("detected_intent") in ["PRODUCT_QUERY", "PRICING_QUERY"]:
|
rag/vectorstore.py
CHANGED
|
@@ -26,7 +26,7 @@ def build_vectorstore(filepath: str = "data/knowledge_base.md"):
|
|
| 26 |
|
| 27 |
return vectorstore
|
| 28 |
|
| 29 |
-
|
| 30 |
_vectorstore = None
|
| 31 |
|
| 32 |
def get_vectorstore(filepath: str = "data/knowledge_base.md"):
|
|
|
|
| 26 |
|
| 27 |
return vectorstore
|
| 28 |
|
| 29 |
+
|
| 30 |
_vectorstore = None
|
| 31 |
|
| 32 |
def get_vectorstore(filepath: str = "data/knowledge_base.md"):
|
tests/test_agent_e2e.py
CHANGED
|
@@ -23,22 +23,22 @@ def simulate_conversation(messages, mock_llm_setup_func):
|
|
| 23 |
|
| 24 |
for idx, msg in enumerate(messages):
|
| 25 |
state["current_message"] = msg
|
| 26 |
-
mock_llm_setup_func(idx)
|
| 27 |
state = app.invoke(state)
|
| 28 |
|
| 29 |
-
|
| 30 |
state["conversation_history"].append({"role": "user", "content": state["current_message"]})
|
| 31 |
state["conversation_history"].append({"role": "assistant", "content": state["response"]})
|
| 32 |
|
| 33 |
return state
|
| 34 |
|
| 35 |
def test_agent_e2e(mocker):
|
| 36 |
-
|
| 37 |
-
|
| 38 |
mock_llm = mocker.MagicMock()
|
| 39 |
mocker.patch('agent.nodes.get_llm', return_value=mock_llm)
|
| 40 |
|
| 41 |
-
|
| 42 |
mocker.patch('agent.nodes.retrieve_documents', return_value=["We have Basic and Pro plans for $29 and $79."])
|
| 43 |
|
| 44 |
mock_tool = mocker.patch('agent.nodes.mock_lead_capture')
|
|
@@ -53,23 +53,23 @@ def test_agent_e2e(mocker):
|
|
| 53 |
|
| 54 |
def setup_mocks_for_turn(idx):
|
| 55 |
if idx == 0:
|
| 56 |
-
|
| 57 |
mock_chain = RunnableLambda(lambda x: IntentResponse(intent="GREETING", confidence=0.99))
|
| 58 |
mock_llm.with_structured_output.return_value = mock_chain
|
| 59 |
elif idx == 1:
|
| 60 |
-
|
| 61 |
mock_chain = RunnableLambda(lambda x: IntentResponse(intent="PRICING_QUERY", confidence=0.99))
|
| 62 |
mock_llm.with_structured_output.return_value = mock_chain
|
| 63 |
|
| 64 |
-
|
| 65 |
class FakeResponse:
|
| 66 |
content = "We have Basic and Pro plans."
|
| 67 |
mock_llm.invoke.return_value = FakeResponse()
|
| 68 |
|
| 69 |
elif idx == 2:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
def mock_structured_output(schema):
|
| 74 |
if schema.__name__ == "IntentResponse":
|
| 75 |
return RunnableLambda(lambda x: IntentResponse(intent="HIGH_INTENT_LEAD", confidence=0.99))
|
|
@@ -78,7 +78,7 @@ def test_agent_e2e(mocker):
|
|
| 78 |
mock_llm.with_structured_output.side_effect = mock_structured_output
|
| 79 |
|
| 80 |
elif idx == 3:
|
| 81 |
-
|
| 82 |
def mock_structured_output(schema):
|
| 83 |
if schema.__name__ == "IntentResponse":
|
| 84 |
return RunnableLambda(lambda x: IntentResponse(intent="HIGH_INTENT_LEAD", confidence=0.99))
|
|
@@ -87,7 +87,7 @@ def test_agent_e2e(mocker):
|
|
| 87 |
mock_llm.with_structured_output.side_effect = mock_structured_output
|
| 88 |
|
| 89 |
elif idx == 4:
|
| 90 |
-
|
| 91 |
def mock_structured_output(schema):
|
| 92 |
if schema.__name__ == "IntentResponse":
|
| 93 |
return RunnableLambda(lambda x: IntentResponse(intent="HIGH_INTENT_LEAD", confidence=0.99))
|
|
|
|
| 23 |
|
| 24 |
for idx, msg in enumerate(messages):
|
| 25 |
state["current_message"] = msg
|
| 26 |
+
mock_llm_setup_func(idx)
|
| 27 |
state = app.invoke(state)
|
| 28 |
|
| 29 |
+
|
| 30 |
state["conversation_history"].append({"role": "user", "content": state["current_message"]})
|
| 31 |
state["conversation_history"].append({"role": "assistant", "content": state["response"]})
|
| 32 |
|
| 33 |
return state
|
| 34 |
|
| 35 |
def test_agent_e2e(mocker):
|
| 36 |
+
|
| 37 |
+
|
| 38 |
mock_llm = mocker.MagicMock()
|
| 39 |
mocker.patch('agent.nodes.get_llm', return_value=mock_llm)
|
| 40 |
|
| 41 |
+
|
| 42 |
mocker.patch('agent.nodes.retrieve_documents', return_value=["We have Basic and Pro plans for $29 and $79."])
|
| 43 |
|
| 44 |
mock_tool = mocker.patch('agent.nodes.mock_lead_capture')
|
|
|
|
| 53 |
|
| 54 |
def setup_mocks_for_turn(idx):
|
| 55 |
if idx == 0:
|
| 56 |
+
|
| 57 |
mock_chain = RunnableLambda(lambda x: IntentResponse(intent="GREETING", confidence=0.99))
|
| 58 |
mock_llm.with_structured_output.return_value = mock_chain
|
| 59 |
elif idx == 1:
|
| 60 |
+
|
| 61 |
mock_chain = RunnableLambda(lambda x: IntentResponse(intent="PRICING_QUERY", confidence=0.99))
|
| 62 |
mock_llm.with_structured_output.return_value = mock_chain
|
| 63 |
|
| 64 |
+
|
| 65 |
class FakeResponse:
|
| 66 |
content = "We have Basic and Pro plans."
|
| 67 |
mock_llm.invoke.return_value = FakeResponse()
|
| 68 |
|
| 69 |
elif idx == 2:
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
def mock_structured_output(schema):
|
| 74 |
if schema.__name__ == "IntentResponse":
|
| 75 |
return RunnableLambda(lambda x: IntentResponse(intent="HIGH_INTENT_LEAD", confidence=0.99))
|
|
|
|
| 78 |
mock_llm.with_structured_output.side_effect = mock_structured_output
|
| 79 |
|
| 80 |
elif idx == 3:
|
| 81 |
+
|
| 82 |
def mock_structured_output(schema):
|
| 83 |
if schema.__name__ == "IntentResponse":
|
| 84 |
return RunnableLambda(lambda x: IntentResponse(intent="HIGH_INTENT_LEAD", confidence=0.99))
|
|
|
|
| 87 |
mock_llm.with_structured_output.side_effect = mock_structured_output
|
| 88 |
|
| 89 |
elif idx == 4:
|
| 90 |
+
|
| 91 |
def mock_structured_output(schema):
|
| 92 |
if schema.__name__ == "IntentResponse":
|
| 93 |
return RunnableLambda(lambda x: IntentResponse(intent="HIGH_INTENT_LEAD", confidence=0.99))
|
tests/test_lead_workflow.py
CHANGED
|
@@ -4,7 +4,7 @@ from agent.state import AgentState
|
|
| 4 |
from langchain_core.runnables import RunnableLambda
|
| 5 |
|
| 6 |
def test_lead_workflow_step_by_step(mocker):
|
| 7 |
-
|
| 8 |
state = AgentState(
|
| 9 |
conversation_history=[],
|
| 10 |
current_message="I want the Pro plan for my YouTube channel",
|
|
@@ -27,12 +27,12 @@ def test_lead_workflow_step_by_step(mocker):
|
|
| 27 |
assert result.get("creator_platform") == "YouTube"
|
| 28 |
assert "name" in result["response"].lower()
|
| 29 |
|
| 30 |
-
|
| 31 |
state.update(result)
|
| 32 |
state["conversation_history"].append({"role": "user", "content": state["current_message"]})
|
| 33 |
state["conversation_history"].append({"role": "assistant", "content": state["response"]})
|
| 34 |
|
| 35 |
-
|
| 36 |
state["current_message"] = "My name is Alex"
|
| 37 |
mock_chain_2 = RunnableLambda(lambda x: LeadExtractionResponse(user_name="Alex", user_email=None, creator_platform=None))
|
| 38 |
mock_llm.with_structured_output.return_value = mock_chain_2
|
|
@@ -41,12 +41,12 @@ def test_lead_workflow_step_by_step(mocker):
|
|
| 41 |
assert result.get("user_name") == "Alex"
|
| 42 |
assert "email" in result["response"].lower()
|
| 43 |
|
| 44 |
-
|
| 45 |
state.update(result)
|
| 46 |
state["conversation_history"].append({"role": "user", "content": state["current_message"]})
|
| 47 |
state["conversation_history"].append({"role": "assistant", "content": state["response"]})
|
| 48 |
|
| 49 |
-
|
| 50 |
state["current_message"] = "alex@email.com"
|
| 51 |
mock_chain_3 = RunnableLambda(lambda x: LeadExtractionResponse(user_name=None, user_email="alex@email.com", creator_platform=None))
|
| 52 |
mock_llm.with_structured_output.return_value = mock_chain_3
|
|
|
|
| 4 |
from langchain_core.runnables import RunnableLambda
|
| 5 |
|
| 6 |
def test_lead_workflow_step_by_step(mocker):
|
| 7 |
+
|
| 8 |
state = AgentState(
|
| 9 |
conversation_history=[],
|
| 10 |
current_message="I want the Pro plan for my YouTube channel",
|
|
|
|
| 27 |
assert result.get("creator_platform") == "YouTube"
|
| 28 |
assert "name" in result["response"].lower()
|
| 29 |
|
| 30 |
+
|
| 31 |
state.update(result)
|
| 32 |
state["conversation_history"].append({"role": "user", "content": state["current_message"]})
|
| 33 |
state["conversation_history"].append({"role": "assistant", "content": state["response"]})
|
| 34 |
|
| 35 |
+
|
| 36 |
state["current_message"] = "My name is Alex"
|
| 37 |
mock_chain_2 = RunnableLambda(lambda x: LeadExtractionResponse(user_name="Alex", user_email=None, creator_platform=None))
|
| 38 |
mock_llm.with_structured_output.return_value = mock_chain_2
|
|
|
|
| 41 |
assert result.get("user_name") == "Alex"
|
| 42 |
assert "email" in result["response"].lower()
|
| 43 |
|
| 44 |
+
|
| 45 |
state.update(result)
|
| 46 |
state["conversation_history"].append({"role": "user", "content": state["current_message"]})
|
| 47 |
state["conversation_history"].append({"role": "assistant", "content": state["response"]})
|
| 48 |
|
| 49 |
+
|
| 50 |
state["current_message"] = "alex@email.com"
|
| 51 |
mock_chain_3 = RunnableLambda(lambda x: LeadExtractionResponse(user_name=None, user_email="alex@email.com", creator_platform=None))
|
| 52 |
mock_llm.with_structured_output.return_value = mock_chain_3
|
tests/test_rag_pipeline.py
CHANGED
|
@@ -9,14 +9,14 @@ os.environ["OPENAI_API_KEY"] = "dummy_key"
|
|
| 9 |
|
| 10 |
class MockEmbedding(Embeddings):
|
| 11 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
| 12 |
-
|
| 13 |
return [[0.0] * 1536 for _ in texts]
|
| 14 |
|
| 15 |
def embed_query(self, text: str) -> List[float]:
|
| 16 |
return [0.0] * 1536
|
| 17 |
|
| 18 |
def test_rag_pipeline_loads_and_retrieves(mocker, tmp_path):
|
| 19 |
-
|
| 20 |
kb_file = tmp_path / "knowledge_base.md"
|
| 21 |
kb_file.write_text("""
|
| 22 |
# AutoStream Pricing & Features
|
|
@@ -28,24 +28,24 @@ def test_rag_pipeline_loads_and_retrieves(mocker, tmp_path):
|
|
| 28 |
* AI captions included
|
| 29 |
""")
|
| 30 |
|
| 31 |
-
|
| 32 |
mocker.patch('rag.vectorstore.get_embeddings', return_value=MockEmbedding())
|
| 33 |
-
# FAISS has an internal check for Embeddings class, so MockEmbedding must inherit from Embeddings
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
|
|
|
| 37 |
vs = build_vectorstore(str(kb_file))
|
| 38 |
assert vs is not None
|
| 39 |
|
| 40 |
-
|
| 41 |
mocker.patch('rag.retriever.get_vectorstore', return_value=vs)
|
| 42 |
from rag.retriever import retrieve_documents
|
| 43 |
|
| 44 |
docs = retrieve_documents("What does the Pro plan cost?", k=1)
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
assert len(docs) > 0
|
| 49 |
-
|
| 50 |
-
|
| 51 |
assert "AutoStream" in docs[0] or "Pro Plan" in docs[0] or "$79/month" in docs[0]
|
|
|
|
| 9 |
|
| 10 |
class MockEmbedding(Embeddings):
|
| 11 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
| 12 |
+
|
| 13 |
return [[0.0] * 1536 for _ in texts]
|
| 14 |
|
| 15 |
def embed_query(self, text: str) -> List[float]:
|
| 16 |
return [0.0] * 1536
|
| 17 |
|
| 18 |
def test_rag_pipeline_loads_and_retrieves(mocker, tmp_path):
|
| 19 |
+
|
| 20 |
kb_file = tmp_path / "knowledge_base.md"
|
| 21 |
kb_file.write_text("""
|
| 22 |
# AutoStream Pricing & Features
|
|
|
|
| 28 |
* AI captions included
|
| 29 |
""")
|
| 30 |
|
| 31 |
+
|
| 32 |
mocker.patch('rag.vectorstore.get_embeddings', return_value=MockEmbedding())
|
|
|
|
| 33 |
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
vs = build_vectorstore(str(kb_file))
|
| 38 |
assert vs is not None
|
| 39 |
|
| 40 |
+
|
| 41 |
mocker.patch('rag.retriever.get_vectorstore', return_value=vs)
|
| 42 |
from rag.retriever import retrieve_documents
|
| 43 |
|
| 44 |
docs = retrieve_documents("What does the Pro plan cost?", k=1)
|
| 45 |
|
| 46 |
+
|
| 47 |
+
|
| 48 |
assert len(docs) > 0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
assert "AutoStream" in docs[0] or "Pro Plan" in docs[0] or "$79/month" in docs[0]
|
tests/test_tool_execution.py
CHANGED
|
@@ -12,14 +12,14 @@ def test_tool_execution_missing_fields(mocker):
|
|
| 12 |
retrieved_documents=[],
|
| 13 |
user_name="Alex",
|
| 14 |
user_email="alex@email.com",
|
| 15 |
-
creator_platform=None,
|
| 16 |
lead_ready=True,
|
| 17 |
response=""
|
| 18 |
)
|
| 19 |
|
| 20 |
result = execute_tool(state)
|
| 21 |
|
| 22 |
-
|
| 23 |
mock_tool.assert_not_called()
|
| 24 |
assert "Error" in result["response"]
|
| 25 |
|
|
@@ -40,6 +40,6 @@ def test_tool_execution_all_fields(mocker):
|
|
| 40 |
|
| 41 |
result = execute_tool(state)
|
| 42 |
|
| 43 |
-
|
| 44 |
mock_tool.assert_called_once_with("Alex", "alex@email.com", "YouTube")
|
| 45 |
assert "Thanks Alex" in result["response"]
|
|
|
|
| 12 |
retrieved_documents=[],
|
| 13 |
user_name="Alex",
|
| 14 |
user_email="alex@email.com",
|
| 15 |
+
creator_platform=None,
|
| 16 |
lead_ready=True,
|
| 17 |
response=""
|
| 18 |
)
|
| 19 |
|
| 20 |
result = execute_tool(state)
|
| 21 |
|
| 22 |
+
|
| 23 |
mock_tool.assert_not_called()
|
| 24 |
assert "Error" in result["response"]
|
| 25 |
|
|
|
|
| 40 |
|
| 41 |
result = execute_tool(state)
|
| 42 |
|
| 43 |
+
|
| 44 |
mock_tool.assert_called_once_with("Alex", "alex@email.com", "YouTube")
|
| 45 |
assert "Thanks Alex" in result["response"]
|