Gaykar commited on
Commit
4d860a2
·
1 Parent(s): 814aa17
app/database/connection.py CHANGED
@@ -26,7 +26,11 @@ engine = create_engine(
26
 
27
  SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
28
 
29
- def get_session() -> Session:
30
- return SessionLocal()
 
 
 
 
31
 
32
 
 
26
 
27
  SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
28
 
29
+ def get_session():
30
+ db = SessionLocal()
31
+ try:
32
+ yield db
33
+ finally:
34
+ db.close()
35
 
36
 
app/graph.py CHANGED
@@ -41,6 +41,9 @@ tool_node_retry_policy = RetryPolicy(
41
 
42
  builder = StateGraph(EmailAgentState)
43
 
 
 
 
44
  # Nodes
45
  builder.add_node("safety_check_node", safety_classifier_node)
46
  builder.add_node("check_previous_email_exist_node", check_previous_email_exist_node)
@@ -63,9 +66,10 @@ builder.add_node(
63
  )
64
 
65
  builder.add_node("archive_node", archive_node,retry=db_retry_policy)
66
- builder.add_node("parse_node", parse_response_node)
67
- builder.add_node("tools", ToolNode(email_writing_agent_tools), retry_policy=tool_node_retry_policy)
68
 
 
 
 
69
  builder.add_edge(START, "safety_check_node")
70
 
71
  builder.add_conditional_edges(
@@ -119,24 +123,16 @@ builder.add_conditional_edges(
119
  "tools",
120
  route_after_tools,
121
  {
122
- "parse_node": "parse_node",
123
  "email_writing_agent": "email_writing_agent"
124
  }
125
  )
126
 
127
- builder.add_edge("parse_node", "store_memory_and_data_node")
128
  builder.add_edge("store_memory_and_data_node", END)
129
  builder.add_edge("unsafe_emails_node", END)
130
  builder.add_edge("archive_node", END)
131
 
132
-
133
-
134
-
135
-
136
-
137
-
138
-
139
-
140
  toolkit = GmailToolkit()
141
 
142
 
 
41
 
42
  builder = StateGraph(EmailAgentState)
43
 
44
+ # Nodes
45
+ builder = StateGraph(EmailAgentState)
46
+
47
  # Nodes
48
  builder.add_node("safety_check_node", safety_classifier_node)
49
  builder.add_node("check_previous_email_exist_node", check_previous_email_exist_node)
 
66
  )
67
 
68
  builder.add_node("archive_node", archive_node,retry=db_retry_policy)
 
 
69
 
70
+ builder.add_node("tools", ToolNode(tools), retry=tool_node_retry_policy)
71
+
72
+ # Edges (Same as your original logic)
73
  builder.add_edge(START, "safety_check_node")
74
 
75
  builder.add_conditional_edges(
 
123
  "tools",
124
  route_after_tools,
125
  {
126
+ "store_memory_and_data_node": "store_memory_and_data_node",
127
  "email_writing_agent": "email_writing_agent"
128
  }
129
  )
130
 
131
+
132
  builder.add_edge("store_memory_and_data_node", END)
133
  builder.add_edge("unsafe_emails_node", END)
134
  builder.add_edge("archive_node", END)
135
 
 
 
 
 
 
 
 
 
136
  toolkit = GmailToolkit()
137
 
138
 
app/main.py CHANGED
@@ -1,48 +1,171 @@
1
- import time
2
- from psycopg import OperationalError
 
 
 
 
 
 
3
  from app.graph import graph
4
  from app.state.state import EmailAgentState
5
- from app.database.connection import pool
6
- import uuid
 
7
 
 
8
 
9
- config = {"configurable": {"thread_id": str(uuid.uuid4()), "user_id": "1"}}
10
 
11
- input_data: EmailAgentState = {
12
- "user_email_id": "gaykaratharva7@gmail.com",
13
- "user_id": 1,
14
- "user_name": "Atharva",
15
- "sender_email_id": "atharvagaykar36@gmail.com",
16
- "sender_subject": "URGENT: Validation of Hybrid Phishing Detection Model & XGBoost Integration",
17
- "sender_email_body": """Dear Atharva,\r\n\r\nI have completed the integration of the AI-Driven Email Threat Detection pipeline... [truncated for brevity]"""
18
- }
19
 
20
- if __name__ == "__main__":
21
- result = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  try:
23
- # 3. Retry loop for Neon wake-up/stability
24
- for i in range(3):
25
- try:
26
- print(f"Attempt {i+1}: Invoking graph...")
27
- result = graph.invoke(input_data, config=config)
28
- break # Success! Exit the loop.
29
- except OperationalError as e:
30
- if i < 2:
31
- print("Waiting for Neon database to wake up...")
32
- time.sleep(10) # Increased sleep slightly for Neon cold starts
33
- else:
34
- print("Max retries reached. Database connection failed.")
35
- raise e
36
-
37
- # 4. Output the result
38
- if result:
39
- print("\n--- Graph Execution Result ---")
40
- print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  except Exception as e:
43
- print(f"An error occurred during execution: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- finally:
46
- # 5. CRITICAL: Close the pool to prevent "cannot join current thread" error
47
- print("Closing connection pool...")
48
- pool.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends
2
+ from pydantic import BaseModel, EmailStr
3
+ from typing import Optional, Dict, Any, TypedDict, Annotated, Sequence
4
+ from langchain_core.messages import BaseMessage
5
+ from langgraph.graph import add_messages
6
+ from langgraph.types import Command
7
+ import uuid
8
+ import logging
9
  from app.graph import graph
10
  from app.state.state import EmailAgentState
11
+ from app.database.connection import get_session
12
+ from app.database.utils import get_or_create_user
13
+ from sqlalchemy.orm import Session
14
 
15
+ logger = logging.getLogger(__name__)
16
 
17
+ app = FastAPI(title="AI Email Agent API")
18
 
19
+ # --- Schemas ---
 
 
 
 
 
 
 
20
 
21
+ class EmailProcessRequest(BaseModel):
22
+ thread_id: str
23
+ user_email: EmailStr
24
+ sender_email_id: EmailStr
25
+ sender_subject: str
26
+ sender_email_body: str
27
+
28
+
29
+ class ReviewActionRequest(BaseModel):
30
+ thread_id: str
31
+ user_id: str
32
+ status: str # "approved" or "rejected"
33
+ feedback: Optional[str] = None
34
+
35
+
36
+ # --- Helper Functions ---
37
+
38
+ def parse_interrupt(final_state: Dict[str, Any]) -> Optional[Dict[str, Any]]:
39
+ """Parse interrupt from graph state."""
40
+ if "__interrupt__" not in final_state:
41
+ return None
42
+
43
+ interrupt_state = final_state.get("__interrupt__")
44
+ if not interrupt_state:
45
+ return None
46
+
47
+ interrupt = interrupt_state[0]
48
+ value = getattr(interrupt, "value", {}) or {}
49
+
50
+ return {
51
+ "action": value.get("action"),
52
+ "data": value.get("data", {})
53
+ }
54
+
55
+
56
+ # --- Endpoints ---
57
+
58
+ @app.post("/process-email")
59
+ def process_email(request: EmailProcessRequest, db: Session = Depends(get_session)) -> Dict[str, Any]:
60
+ """Process email through the graph pipeline."""
61
+
62
  try:
63
+ user = get_or_create_user(db, request.user_email)
64
+
65
+ thread_id = request.thread_id
66
+ config = {
67
+ "configurable": {
68
+ "thread_id": thread_id,
69
+ "user_id": str(user.id)
70
+ }
71
+ }
72
+
73
+ input_data = {
74
+ "user_email_id": request.user_email,
75
+ "user_id": user.id,
76
+ "user_name": "Atharva",
77
+ "sender_email_id": request.sender_email_id,
78
+ "sender_subject": request.sender_subject,
79
+ "sender_email_body": request.sender_email_body,
80
+ }
81
+
82
+ final_state = graph.invoke(input_data, config=config)
83
+
84
+ if final_state.get('triage_label') == "FOLLOW_UP_REQUIRED":
85
+ if "__interrupt__" in final_state and not final_state.get("draft_id"):
86
+ parsed_interrupt = parse_interrupt(final_state)
87
+ if parsed_interrupt:
88
+ data = parsed_interrupt["data"]
89
+
90
+ return {
91
+ "status": "needs_review",
92
+ "thread_id": thread_id,
93
+ "triage_label": final_state.get("triage_label"),
94
+ "action": parsed_interrupt["action"],
95
+ "email_draft": {
96
+ "to": data.get("to"),
97
+ "subject": data.get("subject"),
98
+ "body": data.get("body"),
99
+ }
100
+ }
101
+
102
+ return {
103
+ "thread_id": thread_id,
104
+ "triage_label": final_state.get("triage_label"),
105
+ }
106
 
107
  except Exception as e:
108
+ logger.error(f"Error processing email: {str(e)}")
109
+ raise HTTPException(status_code=500, detail=str(e))
110
+
111
+
112
+ @app.post("/review-action")
113
+ def review_action(request: ReviewActionRequest) -> Dict[str, Any]:
114
+ """Resume graph execution based on user review."""
115
+
116
+ try:
117
+ config = {
118
+ "configurable": {
119
+ "thread_id": request.thread_id,
120
+ "user_id": request.user_id
121
+ }
122
+ }
123
 
124
+ if request.status == "rejected":
125
+ payload = Command(resume={
126
+ "status": "rejected",
127
+ "feedback": request.feedback
128
+ })
129
+ elif request.status == "approved":
130
+ payload = Command(resume={
131
+ "status": "approved"
132
+ })
133
+ else:
134
+ raise HTTPException(status_code=400, detail="Invalid status")
135
+
136
+ final_state = graph.invoke(payload, config=config)
137
+
138
+ # Still in review phase
139
+ if "__interrupt__" in final_state and not final_state.get("draft_id"):
140
+ parsed_interrupt = parse_interrupt(final_state)
141
+ if parsed_interrupt:
142
+ data = parsed_interrupt["data"]
143
+ return {
144
+ "status": "needs_review",
145
+ "thread_id": request.thread_id,
146
+ "triage_label": final_state.get("triage_label"),
147
+ "action": parsed_interrupt["action"],
148
+ "email_draft": {
149
+ "to": data.get("to"),
150
+ "subject": data.get("subject"),
151
+ "body": data.get("body"),
152
+ }
153
+ }
154
+
155
+ # Draft created, review complete
156
+ if final_state.get("draft_id"):
157
+ return {
158
+ "thread_id": request.thread_id,
159
+ "draft_id": final_state["draft_id"],
160
+ "reply_subject": final_state.get("reply_subject"),
161
+ "reply_email_body": final_state.get("reply_email_body"),
162
+ }
163
+
164
+ except Exception as e:
165
+ logger.error(f"Error in review action: {str(e)}")
166
+ raise HTTPException(status_code=500, detail=str(e))
167
+
168
+
169
+ if __name__ == "__main__":
170
+ import uvicorn
171
+ uvicorn.run(app, host="0.0.0.0", port=8000)
app/tools/email_writing_agent_tools.py CHANGED
@@ -3,19 +3,27 @@ from googleapiclient.errors import HttpError
3
  from app.schemas.email_writing_agent_tools_schema import CreateDraftSchema, SendDraftSchema
4
  from langchain.tools import tool
5
  from langchain_google_community import GmailToolkit
6
-
 
 
 
7
 
8
  @tool(args_schema=CreateDraftSchema)
9
- def create_gmail_draft(to: str, subject: str, body: str):
 
 
 
 
 
10
  """Creates a new Gmail draft after human approval."""
11
-
12
  if isinstance(to, list):
13
- if len(to) > 0:
14
- to = str(to[0])
15
- else:
16
- return "ERROR: 'to' parameter is an empty list. Please provide a valid email string."
17
 
18
- # 1. Pause and ask for review
 
 
 
19
  response = interrupt({
20
  "action": "review_draft",
21
  "data": {"to": to, "subject": subject, "body": body}
@@ -24,40 +32,59 @@ def create_gmail_draft(to: str, subject: str, body: str):
24
  toolkit = GmailToolkit()
25
  draft_tool = [t for t in toolkit.get_tools() if t.name == "create_gmail_draft"][0]
26
 
27
- # 2. Handle the response
28
  if response.get("status") == "approved":
29
- reply = draft_tool.invoke({
30
- "message": body,
31
- "to": [to],
32
- "subject": subject
33
- })
34
-
35
- draft_id=reply.split(":")[1].strip()
36
- return f"Successfully created draft : <id>{draft_id}</id> <subject>{subject}</subject> <body>{body}</body> take user permission before submitting"
37
-
 
 
 
 
 
 
 
38
  else:
39
- # Get the feedback from the user response
40
- feedback = response.get("feedback", "User rejected without specific notes.")
41
-
42
- # We return this to the AGENT so it can read it and rewrite the draft
43
- return f"DRAFT REJECTED BY USER. Feedback: {feedback}. Please rewrite the draft based on this feedback and try again."
44
 
45
 
46
-
47
 
48
  @tool(args_schema=SendDraftSchema)
49
- def send_draft_by_id(draft_id: str):
 
 
 
50
  """Sends a finalized Gmail draft by its ID."""
51
  try:
52
  toolkit = GmailToolkit()
53
  result = toolkit.api_resource.users().drafts().send(
54
  userId="me", body={"id": draft_id}
55
  ).execute()
56
- return f"SUCCESS: Sent! a Gmail with ID: <id>{result['id']}</id>"
 
 
 
 
 
 
 
 
 
 
57
  except HttpError as error:
58
- if error.resp.status == 404:
59
- return f"ERROR: Draft ID {draft_id} was not found. Please verify the ID or check if it was already sent."
60
- return f"ERROR: An unexpected error occurred: {error}"
 
 
61
 
62
 
63
  email_writing_agent_tools = [
 
3
  from app.schemas.email_writing_agent_tools_schema import CreateDraftSchema, SendDraftSchema
4
  from langchain.tools import tool
5
  from langchain_google_community import GmailToolkit
6
+ from typing import Annotated, Union
7
+ from langchain_core.tools import InjectedToolCallId, tool
8
+ from langgraph.types import Command
9
+ from langchain_core.messages import SystemMessage, HumanMessage,ToolMessage,AIMessage,BaseMessage
10
 
11
  @tool(args_schema=CreateDraftSchema)
12
+ def create_gmail_draft(
13
+ to: Union[str, list],
14
+ subject: str,
15
+ body: str,
16
+ tool_call_id: Annotated[str, InjectedToolCallId] # Injected ID
17
+ ):
18
  """Creates a new Gmail draft after human approval."""
19
+
20
  if isinstance(to, list):
21
+ to = str(to[0]) if len(to) > 0 else "ERROR"
 
 
 
22
 
23
+ if to == "ERROR":
24
+ return "ERROR: 'to' parameter is empty."
25
+
26
+ # 1. Human-in-the-loop Interrupt
27
  response = interrupt({
28
  "action": "review_draft",
29
  "data": {"to": to, "subject": subject, "body": body}
 
32
  toolkit = GmailToolkit()
33
  draft_tool = [t for t in toolkit.get_tools() if t.name == "create_gmail_draft"][0]
34
 
35
+ # 2. Handle Logic
36
  if response.get("status") == "approved":
37
+ reply = draft_tool.invoke({"message": body, "to": [to], "subject": subject})
38
+ try:
39
+ draft_id = reply.split(":")[1].strip()
40
+ content = f"Successfully created draft: <id>{draft_id}</id> <subject>{subject}</subject> <body>{body}</body>"
41
+
42
+ # UPDATE STATE: Save draft_id directly
43
+ return Command(
44
+ update={
45
+ "draft_id": draft_id,
46
+ "reply_subject": subject,
47
+ "reply_email_body": body,
48
+ "messages": [ToolMessage(content, tool_call_id=tool_call_id)]
49
+ }
50
+ )
51
+ except IndexError:
52
+ return f"Draft created, but response parsing failed: {reply}"
53
  else:
54
+ feedback = response.get("feedback", "User rejected.")
55
+ return f"DRAFT REJECTED BY USER. Feedback: {feedback}. Please rewrite."
 
 
 
56
 
57
 
58
+ #---------------------------------------------------------------------------
59
 
60
  @tool(args_schema=SendDraftSchema)
61
+ def send_draft_by_id(
62
+ draft_id: str,
63
+ tool_call_id: Annotated[str, InjectedToolCallId] # Injected ID
64
+ ):
65
  """Sends a finalized Gmail draft by its ID."""
66
  try:
67
  toolkit = GmailToolkit()
68
  result = toolkit.api_resource.users().drafts().send(
69
  userId="me", body={"id": draft_id}
70
  ).execute()
71
+
72
+ sent_id = result['id']
73
+ content = f"SUCCESS: Sent! a Gmail with ID: <id>{sent_id}</id>"
74
+
75
+ # UPDATE STATE: Save sent_message_id directly
76
+ return Command(
77
+ update={
78
+ "sent_message_id": sent_id,
79
+ "messages": [ToolMessage(content, tool_call_id=tool_call_id)]
80
+ }
81
+ )
82
  except HttpError as error:
83
+ error_msg = f"ERROR: {error}"
84
+ return Command(
85
+ update={"messages": [ToolMessage(error_msg, tool_call_id=tool_call_id)]}
86
+ )
87
+
88
 
89
 
90
  email_writing_agent_tools = [
app/utils/interrupt_utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ def parse_interrupt(final_state: dict):
5
+ if "__interrupt__" not in final_state:
6
+ return None
7
+
8
+ interrupt_state = final_state["__interrupt__"]
9
+ if not interrupt_state:
10
+ return None
11
+
12
+ interrupt = interrupt_state[0]
13
+ value = getattr(interrupt, "value", {}) or {}
14
+
15
+ return {
16
+ "action": value.get("action"),
17
+ "data": value.get("data", {})
18
+ }