Spaces:
Runtime error
Runtime error
added summarze agent
Browse files- app/agents/context_agent.py +2 -1
- app/agents/summarizer_agent.py +8 -0
- app/graph.py +26 -12
- app/nodes/check_token_count_node.py +36 -0
- app/nodes/safety_check_node.py +6 -4
- app/nodes/summarise_email_body_node.py +36 -0
- app/prompts/summarizer_agent_prompt.py +25 -0
- app/state/state.py +5 -2
app/agents/context_agent.py
CHANGED
|
@@ -4,6 +4,7 @@ from langchain_groq import ChatGroq
|
|
| 4 |
from app.prompts.context_agent_prompt import context_agent_template
|
| 5 |
from app.tools.context_agent_tools import context_agent_tools
|
| 6 |
from typing import Any
|
|
|
|
| 7 |
|
| 8 |
context_agent = create_agent(
|
| 9 |
model=ChatGroq(
|
|
@@ -11,7 +12,7 @@ context_agent = create_agent(
|
|
| 11 |
temperature=0.1,
|
| 12 |
),
|
| 13 |
tools=context_agent_tools,
|
| 14 |
-
store=
|
| 15 |
middleware=[
|
| 16 |
ToolCallLimitMiddleware[Any,None](
|
| 17 |
tool_name="search_memory",
|
|
|
|
| 4 |
from app.prompts.context_agent_prompt import context_agent_template
|
| 5 |
from app.tools.context_agent_tools import context_agent_tools
|
| 6 |
from typing import Any
|
| 7 |
+
from app.agent_memory_store import memory_store
|
| 8 |
|
| 9 |
context_agent = create_agent(
|
| 10 |
model=ChatGroq(
|
|
|
|
| 12 |
temperature=0.1,
|
| 13 |
),
|
| 14 |
tools=context_agent_tools,
|
| 15 |
+
store=memory_store,
|
| 16 |
middleware=[
|
| 17 |
ToolCallLimitMiddleware[Any,None](
|
| 18 |
tool_name="search_memory",
|
app/agents/summarizer_agent.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_groq import ChatGroq
|
| 2 |
+
|
| 3 |
+
summarizer_agent=ChatGroq(
|
| 4 |
+
|
| 5 |
+
model="llama-3.3-70b-versatile",
|
| 6 |
+
temperature=0.1
|
| 7 |
+
|
| 8 |
+
)
|
app/graph.py
CHANGED
|
@@ -13,8 +13,9 @@ from app.nodes.store_memory_data_node import store_memory_and_data_node
|
|
| 13 |
from app.nodes.unsafe_email_node import unsafe_emails_node
|
| 14 |
from app.nodes.check_email_exist_node import *
|
| 15 |
from langgraph.types import RetryPolicy
|
|
|
|
|
|
|
| 16 |
from psycopg import OperationalError # Or sqlalchemy.exc.OperationalError depending on your driver
|
| 17 |
-
# imoprt display
|
| 18 |
from IPython.display import Image, display
|
| 19 |
|
| 20 |
|
|
@@ -38,11 +39,13 @@ builder = StateGraph(EmailAgentState)
|
|
| 38 |
# Nodes
|
| 39 |
builder.add_node("safety_check_node", safety_classifier_node)
|
| 40 |
builder.add_node("check_previous_email_exist_node", check_previous_email_exist_node)
|
|
|
|
|
|
|
| 41 |
builder.add_node("triage_node", triage_node)
|
| 42 |
builder.add_node("prepare_context_node", prepare_context_node)
|
| 43 |
builder.add_node("email_writing_agent", email_writing_agent_node)
|
| 44 |
|
| 45 |
-
|
| 46 |
builder.add_node(
|
| 47 |
"store_memory_and_data_node",
|
| 48 |
store_memory_and_data_node,
|
|
@@ -61,10 +64,27 @@ builder.add_node("tools", ToolNode(email_writing_agent_tools), retry_policy=tool
|
|
| 61 |
# Edges (Same as your original logic)
|
| 62 |
builder.add_edge(START, "safety_check_node")
|
| 63 |
|
| 64 |
-
builder.add_conditional_edges(
|
| 65 |
-
"
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
builder.add_conditional_edges("triage_node", route_after_triage, {
|
| 70 |
"check_previous_email_exist_node": "check_previous_email_exist_node",
|
|
@@ -104,9 +124,3 @@ builder.add_edge("parse_node", "store_memory_and_data_node")
|
|
| 104 |
builder.add_edge("store_memory_and_data_node", END)
|
| 105 |
builder.add_edge("unsafe_emails_node", END)
|
| 106 |
builder.add_edge("archive_node", END)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
graph = builder.compile()
|
| 111 |
-
|
| 112 |
-
display(graph)
|
|
|
|
| 13 |
from app.nodes.unsafe_email_node import unsafe_emails_node
|
| 14 |
from app.nodes.check_email_exist_node import *
|
| 15 |
from langgraph.types import RetryPolicy
|
| 16 |
+
from app.nodes.summarise_email_body_node import summarise_email_body_node
|
| 17 |
+
from app.nodes.check_token_count_node import *
|
| 18 |
from psycopg import OperationalError # Or sqlalchemy.exc.OperationalError depending on your driver
|
|
|
|
| 19 |
from IPython.display import Image, display
|
| 20 |
|
| 21 |
|
|
|
|
| 39 |
# Nodes
|
| 40 |
builder.add_node("safety_check_node", safety_classifier_node)
|
| 41 |
builder.add_node("check_previous_email_exist_node", check_previous_email_exist_node)
|
| 42 |
+
builder.add_node("check_token_count_node", check_token_count_node)
|
| 43 |
+
builder.add_node("summarise_email_body_node", summarise_email_body_node)
|
| 44 |
builder.add_node("triage_node", triage_node)
|
| 45 |
builder.add_node("prepare_context_node", prepare_context_node)
|
| 46 |
builder.add_node("email_writing_agent", email_writing_agent_node)
|
| 47 |
|
| 48 |
+
# --- APPLY RETRY POLICIES HERE ---
|
| 49 |
builder.add_node(
|
| 50 |
"store_memory_and_data_node",
|
| 51 |
store_memory_and_data_node,
|
|
|
|
| 64 |
# Edges (Same as your original logic)
|
| 65 |
builder.add_edge(START, "safety_check_node")
|
| 66 |
|
| 67 |
+
builder.add_conditional_edges(
|
| 68 |
+
"safety_check_node",
|
| 69 |
+
after_safety,
|
| 70 |
+
{
|
| 71 |
+
"unsafe": "unsafe_emails_node",
|
| 72 |
+
"safe": "check_token_count_node"
|
| 73 |
+
}
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
builder.add_conditional_edges(
|
| 77 |
+
"check_token_count_node",
|
| 78 |
+
check_token_limit_router,
|
| 79 |
+
{
|
| 80 |
+
"summarize": "summarise_email_body_node",
|
| 81 |
+
"triage": "triage_node"
|
| 82 |
+
}
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
builder.add_edge("summarise_email_body_node", "triage_node")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
|
| 89 |
builder.add_conditional_edges("triage_node", route_after_triage, {
|
| 90 |
"check_previous_email_exist_node": "check_previous_email_exist_node",
|
|
|
|
| 124 |
builder.add_edge("store_memory_and_data_node", END)
|
| 125 |
builder.add_edge("unsafe_emails_node", END)
|
| 126 |
builder.add_edge("archive_node", END)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/nodes/check_token_count_node.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.state.state import EmailAgentState
|
| 2 |
+
from langchain_groq import ChatGroq
|
| 3 |
+
|
| 4 |
+
llm_for_token_count=ChatGroq(
|
| 5 |
+
model="meta-llama/llama-4-scout-17b-16e-instruct",
|
| 6 |
+
temperature=0.1,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
def count_input_tokens(subject:str ,body:str) -> int:
|
| 10 |
+
text=subject+body
|
| 11 |
+
return llm_for_token_count.get_num_tokens(text)
|
| 12 |
+
|
| 13 |
+
def check_token_count_node(state: EmailAgentState):
|
| 14 |
+
"""
|
| 15 |
+
This is a formal Node. It calculates tokens and can
|
| 16 |
+
update the state if you add a 'token_count' key to your TypedDict.
|
| 17 |
+
"""
|
| 18 |
+
subject = state.get('sender_subject', "")
|
| 19 |
+
body = state.get('sender_email_body', "")
|
| 20 |
+
|
| 21 |
+
tokens = count_input_tokens(subject, body)
|
| 22 |
+
print(f"--- NODE: Token Count calculated as {tokens} ---")
|
| 23 |
+
|
| 24 |
+
# We return the count so it's stored in the state for the next router to see
|
| 25 |
+
return {"sender_email_token_count": tokens}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def check_token_limit_router(state: EmailAgentState):
|
| 29 |
+
"""
|
| 30 |
+
Acts as a router to decide if we need summarization.
|
| 31 |
+
"""
|
| 32 |
+
if state.get("sender_email_token_count", 0) > 100000:
|
| 33 |
+
|
| 34 |
+
return "summarize"
|
| 35 |
+
|
| 36 |
+
return "triage"
|
app/nodes/safety_check_node.py
CHANGED
|
@@ -34,8 +34,10 @@ def safety_classifier_node(state: EmailAgentState) -> dict:
|
|
| 34 |
return {"is_safe": False, "safety_reason": f"API error {e.response.status_code}"}
|
| 35 |
|
| 36 |
|
| 37 |
-
|
| 38 |
def after_safety(state: EmailAgentState) -> str:
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
return {"is_safe": False, "safety_reason": f"API error {e.response.status_code}"}
|
| 35 |
|
| 36 |
|
|
|
|
| 37 |
def after_safety(state: EmailAgentState) -> str:
|
| 38 |
+
"""
|
| 39 |
+
Only handles safety routing.
|
| 40 |
+
"""
|
| 41 |
+
if not state.get("is_safe"):
|
| 42 |
+
return "unsafe"
|
| 43 |
+
return "safe"
|
app/nodes/summarise_email_body_node.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.state.state import EmailAgentState
|
| 2 |
+
from langchain_core.prompts import PromptTemplate
|
| 3 |
+
from langchain_classic.chains.summarize import load_summarize_chain
|
| 4 |
+
from langchain_classic.text_splitter import RecursiveCharacterTextSplitter
|
| 5 |
+
from langchain_core.documents import Document
|
| 6 |
+
from langchain_classic.prompts import PromptTemplate
|
| 7 |
+
from app.agents.summarizer_agent import summarizer_agent
|
| 8 |
+
from app.prompts.summarizer_agent_prompt import summarize_agent_initial_prompt_template,summarize_agent_refine_template
|
| 9 |
+
def summarise_email_body(body: str):
|
| 10 |
+
# Tip: 100 is very small for an email; 500-1000 is usually better
|
| 11 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
|
| 12 |
+
docs = [Document(page_content=t) for t in text_splitter.split_text(body)]
|
| 13 |
+
|
| 14 |
+
chain = load_summarize_chain(
|
| 15 |
+
llm=summarizer_agent,
|
| 16 |
+
chain_type="refine",
|
| 17 |
+
question_prompt=summarize_agent_initial_prompt,
|
| 18 |
+
refine_prompt=summarize_agent_refine_prompt,
|
| 19 |
+
document_variable_name="body" # <--- ADD THIS LINE
|
| 20 |
+
)
|
| 21 |
+
summary = chain.invoke(docs)
|
| 22 |
+
return summary['output_text']
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def summarise_email_body_node(state:EmailAgentState)->dict:
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
subject=state['sender_subject']
|
| 29 |
+
body=state['sender_email_body']
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
summary=summarise_email_body(body)
|
| 34 |
+
|
| 35 |
+
return {"sender_email_body":summary}
|
| 36 |
+
|
app/prompts/summarizer_agent_prompt.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
from langchain_core.prompts import PromptTemplate
|
| 4 |
+
|
| 5 |
+
summarize_agent_initial_prompt_template = """Write a summary of the following email body.
|
| 6 |
+
Include links (if any) and key info (req for reply) if given and maintain a email structure showing who sent it:
|
| 7 |
+
|
| 8 |
+
"{body}"
|
| 9 |
+
|
| 10 |
+
CONCISE SUMMARY:"""
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
summarize_agent_refine_template = """
|
| 14 |
+
We have an existing summary: {existing_answer}
|
| 15 |
+
We have more email content below:
|
| 16 |
+
------------
|
| 17 |
+
{body}
|
| 18 |
+
------------
|
| 19 |
+
Given the new context, refine the summary.
|
| 20 |
+
Ensure you keep the repository links and the sender's info.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
summarize_agent_initial_prompt = PromptTemplate(template=summarize_agent_initial_prompt_template, input_variables=["body"])
|
| 24 |
+
|
| 25 |
+
summarize_agent_refine_prompt = PromptTemplate(template=summarize_agent_refine_template, input_variables=["existing_answer", "body"])
|
app/state/state.py
CHANGED
|
@@ -21,9 +21,12 @@ class EmailAgentState(TypedDict):
|
|
| 21 |
|
| 22 |
sender_email_id: str
|
| 23 |
|
| 24 |
-
sender_subject: str
|
|
|
|
| 25 |
|
| 26 |
-
user_name: str
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# Safety node output
|
| 29 |
is_safe: Optional[bool]
|
|
|
|
| 21 |
|
| 22 |
sender_email_id: str
|
| 23 |
|
| 24 |
+
sender_subject: str
|
| 25 |
+
|
| 26 |
|
| 27 |
+
user_name: str
|
| 28 |
+
|
| 29 |
+
sender_email_token_count: Optional[int]
|
| 30 |
|
| 31 |
# Safety node output
|
| 32 |
is_safe: Optional[bool]
|