test / src /synthesize.py
mohamednaim's picture
es
90631ee verified
import os
from typing import Literal,Annotated,Callable
from typing_extensions import Annotated, TypedDict, Literal
from typing_extensions import TypedDict
from langchain.tools import tool
from langgraph.graph import StateGraph,MessagesState,START,END
from langgraph.types import Send
from pydantic import BaseModel,Field
from langchain.chat_models import init_chat_model
from langchain.messages import HumanMessage,SystemMessage
import operator
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent
from teacher import teacher_agent
from teacher_assistant import teacher_assistant_agent
from planner import planner_agent
from config import *
# System prompt for the routing model; classify_query sends it (with an extra
# JSON-format reminder appended) ahead of the user's question.
# It asks for a JSON object whose "classifications" list names one of the three
# agents ("teacher", "teacher_assistant", "planner") and the question to send.
# NOTE(review): "Pick ONE agent only" and the later "assign additional agents"
# sentence pull in opposite directions — confirm which behaviour is intended.
ROUTER_SYSTEM_PROMPT = """
You are a routing agent. Analyze the query and return ONLY this JSON:
{"classifications": [{"source": "teacher", "query": "the question"}]}
Rules:
- "teacher" → theory, concepts, explanations, definitions
- "teacher_assistant" → code, debugging, implementation, technical errors
- "planner" → roadmaps, study plans, schedules, learning paths
- Pick ONE agent only
- Return ONLY the JSON, nothing else
If a request requires more agents than are currently assigned, assign additional agents to it
Valid source values: "teacher", "teacher_assistant", "planner"
"""
# ================================================================
# STATE SCHEMAS — single source of truth, consistent key names
# ================================================================
class AgentInput(TypedDict):
    """Input passed to each agent node"""

    # Question text for the agent; route_to_agents fills it from the
    # matching classification's "query".
    query: str
class AgentOutput(TypedDict):
    """Output collected from each agent node"""

    # Name of the agent that produced the answer
    # ("teacher", "teacher_assistant" or "planner" in this file).
    source: str
    # The answer content taken from the agent's message history.
    result: str
class Classification(TypedDict):
    """A single routing decision"""

    # Which agent should handle the question.
    source: Literal["teacher", "teacher_assistant", "planner"]
    # The question text to forward to that agent.
    query: str
class RouterState(TypedDict):
    """State schema passed to StateGraph for the router workflow."""

    # The user's original question.
    query: str
    # Routing decisions written by classify_query.
    classifications: list[Classification]
    # Agent outputs. Annotated with operator.add — presumably so LangGraph
    # appends the lists returned by several agent nodes instead of
    # overwriting them; confirm against the LangGraph reducer docs.
    results: Annotated[list[AgentOutput], operator.add]
    # Final reply: set by classify_query for small talk, otherwise by
    # synthesize_results.
    final_answer: str
# NOTE(review): in the visible code this model is referenced only by the
# commented-out with_structured_output version of classify_query.
class ClassificationResult(BaseModel):
    """Structured output from the classifier LLM"""

    # The routing decisions, one entry per agent to invoke.
    classifications: list[Classification] = Field(
        description="List of agents to invoke with their targeted sub-questions"
    )
# Lower-cased phrases that classify_query answers with a canned greeting
# instead of calling the router model.
SMALL_TALK = {"hi", "hello", "hey", "thanks", "thank you", "ok", "okay", "bye", "good", "nice", "cool"}
# ================================================================
# CLASSIFIER NODE — fixed
# ================================================================
# def classify_query(state: RouterState) -> dict:
# query = state["query"].strip().lower()
# # bypass LLM entirely for small talk
# if query in SMALL_TALK or len(query.split()) <= 1:
# return {
# "classifications": [],
# "final_answer": "Hello! How can I help you today?"
# }
# structured_llm = router_llm.with_structured_output(ClassificationResult)
# result = structured_llm.invoke([
# {"role": "system", "content": ROUTER_SYSTEM_PROMPT},
# {"role": "user", "content": state["query"]}
# ])
# return {"classifications": result.classifications}
import json
import re
def classify_query(state: RouterState) -> dict:
    """Decide which agent(s) should answer the query.

    Small talk (a phrase in ``SMALL_TALK`` or any query of at most one word)
    is answered immediately with a canned greeting and no classifications.
    Otherwise the router model is asked for a JSON routing decision, which is
    parsed and validated entry by entry.

    Args:
        state: Router state; only ``state["query"]`` is read.

    Returns:
        For small talk: ``{"classifications": [], "final_answer": ...}``.
        Otherwise ``{"classifications": [...]}`` with at least one entry whose
        ``source`` is a known agent and whose ``query`` is a non-empty string.
        If the model call fails, its reply is not valid JSON, or no entry
        survives validation, the original query is routed to ``"teacher"``.
    """
    normalized = state["query"].strip().lower()

    # Small-talk bypass: skip the model call entirely.
    if normalized in SMALL_TALK or len(normalized.split()) <= 1:
        return {
            "classifications": [],
            "final_answer": "Hello! How can I help you today?"
        }

    # Single fallback used by every failure path: send the question to the teacher.
    fallback = [{"source": "teacher", "query": state["query"]}]

    # Plain invoke + manual JSON parsing is used instead of with_structured_output.
    messages = [
        {
            "role": "system",
            "content": ROUTER_SYSTEM_PROMPT + """
CRITICAL: Return ONLY a valid JSON object. No markdown, no explanation, no code blocks.
Exactly this format:
{"classifications": [{"source": "teacher", "query": "the question"}]}
Valid source values: "teacher", "teacher_assistant", "planner"
"""
        },
        {
            "role": "user",
            "content": state["query"]
        }
    ]

    raw = ""
    try:
        response = model2.invoke(messages)
        raw = response.content.strip()
        # Strip markdown code fences in case the model added them anyway.
        raw = re.sub(r"```json|```", "", raw).strip()
        parsed = json.loads(raw)
    except json.JSONDecodeError as e:
        print(f"⚠️ [Router] JSON parse failed: {e} — raw: {raw[:100]}")
        return {"classifications": fallback}
    except Exception as e:
        # Best-effort routing: a failed model call must not stop the graph.
        print(f"❌ [Router] Error: {e}")
        return {"classifications": fallback}

    # The reply is untrusted: check every shape before indexing into it so a
    # single malformed entry cannot discard the valid ones.
    entries = parsed.get("classifications", []) if isinstance(parsed, dict) else []
    if not isinstance(entries, list):
        entries = []

    valid_sources = {"teacher", "teacher_assistant", "planner"}
    classifications = [
        c for c in entries
        if isinstance(c, dict)
        and isinstance(c.get("source"), str)
        and c["source"] in valid_sources
        # Agent nodes slice the query as text, so only non-empty strings pass.
        and isinstance(c.get("query"), str)
        and c["query"]
    ]

    # Nothing usable came back: fall back to the teacher.
    if not classifications:
        classifications = fallback

    print(f"🧭 [Router] → {[c['source'] for c in classifications]}")
    return {"classifications": classifications}
# ================================================================
# ROUTING FUNCTION
# ================================================================
# Maps a classification "source" to the name of the graph node that handles it.
# Every entry currently maps a source to a node of the same name.
AGENT_NODE_MAP = {
    "teacher": "teacher",
    "teacher_assistant": "teacher_assistant",
    "planner": "planner",
}
def route_to_agents(state: "RouterState") -> "list[Send | str]":
    """Fan the classified questions out to their agent nodes.

    Builds one ``Send`` per entry in ``state["classifications"]``. Each targets
    the node named by ``AGENT_NODE_MAP`` (falling back to the raw source name)
    and carries ``{"query": ...}`` as that node's input.

    When there are no classifications (the small-talk bypass in
    ``classify_query``), an empty list would give the conditional edge no
    destination and the "synthesize" node would never run. In that case the
    flow is sent straight to "synthesize", which the edge already lists as an
    allowed destination.

    Args:
        state: Router state; only ``state["classifications"]`` is read.

    Returns:
        A list of ``Send`` objects, or ``["synthesize"]`` when there is
        nothing to dispatch.
    """
    sends = [
        Send(AGENT_NODE_MAP.get(c["source"], c["source"]), {"query": c["query"]})
        for c in state["classifications"]
    ]
    return sends or ["synthesize"]
def query_teacher(state: AgentInput) -> dict:
    """Run the teacher agent on the routed question and report its answer.

    The agent is invoked with the question as a single user message under the
    fixed thread id ``"teacher_session"``. The answer is the content of the
    most recent returned message that has content and no tool calls, or an
    empty string if there is none.

    NOTE(review): the thread id is a constant, so presumably every call shares
    one conversation thread — confirm that this is intended.

    Returns:
        ``{"results": [{"source": "teacher", "result": <answer>}]}``.
    """
    question = state["query"]
    print(" [Teacher] Agent started — query:", question[:50])

    outcome = teacher_agent.invoke(
        {"messages": [{"role": "user", "content": question}]},
        config={"configurable": {"thread_id": "teacher_session"}},
    )

    # Walk the history newest-first and take the first message that carries
    # content and is not a tool-call request.
    answer = next(
        (
            message.content
            for message in reversed(outcome.get("messages", []))
            if getattr(message, "content", None)
            and not getattr(message, "tool_calls", None)
        ),
        "",
    )

    print(f" [Teacher] Agent finished — response length: {len(answer)}")
    return {"results": [{"source": "teacher", "result": answer}]}
def query_teacher_assistant(state: AgentInput) -> dict:
    """Run the teacher-assistant agent on the routed question and report its answer.

    The agent is invoked with the question as a single user message under the
    fixed thread id ``"assistant_session"``. The answer is the content of the
    most recent returned message that has content and no tool calls, or an
    empty string if there is none.

    NOTE(review): the thread id is a constant, so presumably every call shares
    one conversation thread — confirm that this is intended.

    Returns:
        ``{"results": [{"source": "teacher_assistant", "result": <answer>}]}``.
    """
    question = state["query"]
    print(" [Teacher_Assistant] Agent started — query:", question[:50])

    outcome = teacher_assistant_agent.invoke(
        {"messages": [{"role": "user", "content": question}]},
        config={"configurable": {"thread_id": "assistant_session"}},
    )

    # Walk the history newest-first and take the first message that carries
    # content and is not a tool-call request.
    answer = next(
        (
            message.content
            for message in reversed(outcome.get("messages", []))
            if getattr(message, "content", None)
            and not getattr(message, "tool_calls", None)
        ),
        "",
    )

    print(f" [Teacher_Assistant] Agent finished — response length: {len(answer)}")
    return {"results": [{"source": "teacher_assistant", "result": answer}]}
def query_planner(state: AgentInput) -> dict:
    """Run the planner agent on the routed question and report its answer.

    The agent is invoked with the question as a single user message under the
    fixed thread id ``"planner_session"``. The answer is the content of the
    most recent returned message that has content and no tool calls, or an
    empty string if there is none.

    NOTE(review): the thread id is a constant, so presumably every call shares
    one conversation thread — confirm that this is intended.

    Returns:
        ``{"results": [{"source": "planner", "result": <answer>}]}``.
    """
    question = state["query"]
    print(" [Planner] Agent started — query:", question[:50])

    outcome = planner_agent.invoke(
        {"messages": [{"role": "user", "content": question}]},
        config={"configurable": {"thread_id": "planner_session"}},
    )

    # Walk the history newest-first and take the first message that carries
    # content and is not a tool-call request.
    answer = next(
        (
            message.content
            for message in reversed(outcome.get("messages", []))
            if getattr(message, "content", None)
            and not getattr(message, "tool_calls", None)
        ),
        "",
    )

    print(f"[Planner] Agent finished — response length: {len(answer)}")
    return {"results": [{"source": "planner", "result": answer}]}
# ================================================================
# SYNTHESIZE — fixed
# ================================================================
def synthesize_results(state: "RouterState") -> dict:
    """Produce the final answer from whatever the agent nodes returned.

    Resolution order:
      1. ``final_answer`` already set upstream (small talk) -> returned as is.
      2. No results -> a fixed "could not find an answer" message.
      3. Exactly one result -> that result, untouched.
      4. Several results -> merged by ``model2`` into one response.

    Args:
        state: Router state; reads ``results``, ``final_answer`` and, only for
            the multi-result case, ``query``.

    Returns:
        ``{"final_answer": <text>}``.
    """
    results = state.get("results", [])
    print(f"\n [Synthesize] received {len(results)} result(s)")
    for r in results:
        print(f" - {r['source']}: {len(r['result'])} chars")

    # Already answered upstream (small talk).
    if state.get("final_answer"):
        return {"final_answer": state["final_answer"]}

    if not results:
        return {"final_answer": "I could not find an answer. Please try rephrasing."}

    # Single result: return it directly, no synthesis needed.
    if len(results) == 1:
        return {"final_answer": results[0]["result"]}

    # Multiple results: label each with its source and let the model merge them.
    formatted = "\n\n---\n\n".join(
        f"[{r['source'].upper()}]\n{r['result']}" for r in results
    )
    messages = [
        {
            "role": "system",
            "content": (
                "You are a synthesis expert. "
                "You will receive outputs from multiple educational agents. "
                "Combine them into one clear, unified, non-redundant response. "
                # The trailing space keeps this sentence from fusing with the
                # next literal when the adjacent strings are concatenated.
                "Do not mention agent names. Write all agent outputs. "
                "Do not summarize or remove any details and do not add any new information from your data."
            )
        },
        {
            "role": "user",
            "content": f"Original question: {state['query']}\n\n{formatted}"
        }
    ]

    # The model is streamed; empty chunks are skipped and the rest joined.
    full_response = "".join(
        chunk.content for chunk in model2.stream(messages) if chunk.content
    )
    return {"final_answer": full_response}
# Router graph: classify the query, dispatch it to the selected agent node(s),
# then build the final answer in "synthesize".
workflow = (
    StateGraph(RouterState)
    # One node per step: the classifier, the three agents, and the synthesizer.
    .add_node("classify", classify_query)
    .add_node("teacher", query_teacher)
    .add_node("teacher_assistant", query_teacher_assistant)
    .add_node("planner", query_planner)
    .add_node("synthesize", synthesize_results)
    .add_edge(START, "classify")
    # route_to_agents chooses what runs after "classify"; the list names the
    # destinations this edge is allowed to reach.
    .add_conditional_edges(
        "classify",
        route_to_agents,
        ["teacher", "teacher_assistant", "planner", "synthesize"],
    )
    # Every agent node feeds into "synthesize", which is the last step.
    .add_edge("teacher", "synthesize")
    .add_edge("teacher_assistant", "synthesize")
    .add_edge("planner", "synthesize")
    .add_edge("synthesize", END)
    # Compiled with default arguments (no checkpointer is passed here).
    .compile()
)