"""FastAPI entry point for the AI email agent.

Exposes three endpoints that drive the LangGraph ``graph`` for one
conversation thread (keyed by ``thread_id`` in the graph config):

* ``POST /process-email`` -- start a run for an incoming email.
* ``POST /review-action``  -- resume a run paused for human review.
* ``POST /send_email``     -- inject a human message and resume the run.
"""
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel, EmailStr
from typing import Optional, Dict, Any, TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import add_messages
from langgraph.types import Command
import uuid
import logging
from app.graph import graph
from app.state.state import EmailAgentState
from app.database.connection import get_session
from app.database.utils import get_or_create_user
from sqlalchemy.orm import Session
from app.database.connection import SessionLocal

logger = logging.getLogger(__name__)

# Display name passed to the graph when the request does not supply one.
# NOTE(review): this was previously hard-coded inline; confirm where the real
# user name should come from (e.g. the user record).
DEFAULT_USER_NAME = "Atharva"


# NOTE(review): this definition shadows the ``get_session`` imported from
# app.database.connection above; the endpoints below use this local one.
def get_session():
    """Yield a SQLAlchemy session from ``SessionLocal`` and always close it."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


app = FastAPI(title="AI Email Agent API")


# --- Schemas ---

class EmailProcessRequest(BaseModel):
    """Incoming email to run through the agent graph."""

    thread_id: str
    user_email: EmailStr
    sender_email_id: EmailStr
    sender_subject: str
    sender_email_body: str
    # Optional display name for the user; falls back to DEFAULT_USER_NAME.
    user_name: Optional[str] = None


class ReviewActionRequest(BaseModel):
    """Human review decision used to resume a paused graph run."""

    thread_id: str
    user_id: str
    status: str  # "approved" or "rejected"; anything else is answered with HTTP 400
    feedback: Optional[str] = None


class SendEmailRequest(BaseModel):
    """Human message injected into a thread before resuming the graph."""

    thread_id: str
    user_id: str
    human_message: str


# --- Helper Functions ---

def parse_interrupt(final_state: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Extract the first pending interrupt from a graph result.

    Args:
        final_state: The dict returned by ``graph.invoke``.

    Returns:
        ``None`` when there is no (or an empty) ``"__interrupt__"`` entry.
        Otherwise ``{"action": ..., "data": {...}}`` built from the first
        interrupt's ``value``. A missing or non-dict value, and a missing or
        ``None`` ``data``, are normalized to ``{}`` so callers can always
        call ``data.get(...)``.
    """
    interrupts = final_state.get("__interrupt__")
    if not interrupts:
        return None
    value = getattr(interrupts[0], "value", None)
    if not isinstance(value, dict):
        value = {}
    return {
        "action": value.get("action"),
        "data": value.get("data") or {},
    }


def _awaiting_review(state: Dict[str, Any]) -> bool:
    """Return True when the run is paused on an interrupt and no draft exists yet."""
    return "__interrupt__" in state and not state.get("draft_id")


def _review_response(
    state: Dict[str, Any], thread_id: str, *, include_messages: bool = False
) -> Optional[Dict[str, Any]]:
    """Build the ``needs_review`` payload from a paused graph state.

    Returns ``None`` when no interrupt could be parsed from ``state``.
    ``include_messages`` adds the state's ``messages`` list (used by
    ``/process-email`` only, matching its established response shape).
    """
    parsed = parse_interrupt(state)
    if not parsed:
        return None
    data = parsed["data"]
    response: Dict[str, Any] = {"status": "needs_review", "thread_id": thread_id}
    if include_messages:
        response["messages"] = state.get("messages", [])
    response["triage_label"] = state.get("triage_label")
    response["action"] = parsed["action"]
    response["email_draft"] = {
        "to": data.get("to"),
        "subject": data.get("subject"),
        "body": data.get("body"),
    }
    return response


# --- Endpoints ---

@app.post("/process-email")
def process_email(request: EmailProcessRequest, db: Session = Depends(get_session)) -> Dict[str, Any]:
    """Process an incoming email through the graph pipeline.

    Looks up (or creates) the user, runs the graph for ``request.thread_id``
    and, if the email was triaged as ``FOLLOW_UP_REQUIRED`` and the graph
    paused for review before creating a draft, returns the proposed draft
    with ``status: "needs_review"``. Otherwise returns the thread id and
    triage label.

    Raises:
        HTTPException: 500 for any unexpected failure (the error text is
            forwarded as ``detail``).
    """
    try:
        user = get_or_create_user(db, request.user_email)
        thread_id = request.thread_id
        config = {
            "configurable": {
                "thread_id": thread_id,
                "user_id": str(user.id),
            }
        }
        input_data = {
            "user_email_id": request.user_email,
            "user_id": user.id,
            "user_name": request.user_name or DEFAULT_USER_NAME,
            "sender_email_id": request.sender_email_id,
            "sender_subject": request.sender_subject,
            "sender_email_body": request.sender_email_body,
        }

        final_state = graph.invoke(input_data, config=config)

        if final_state.get("triage_label") == "FOLLOW_UP_REQUIRED" and _awaiting_review(final_state):
            review = _review_response(final_state, thread_id, include_messages=True)
            if review is not None:
                return review

        return {
            "thread_id": thread_id,
            "triage_label": final_state.get("triage_label"),
        }
    except HTTPException:
        raise
    except Exception as e:
        # NOTE(review): str(e) is returned to the client and may leak internals.
        logger.exception("Error processing email: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/review-action")
def review_action(request: ReviewActionRequest) -> Dict[str, Any]:
    """Resume a paused graph run with the reviewer's decision.

    Returns one of:
    * ``status: "needs_review"`` with a new draft if the graph paused again
      before creating a draft;
    * the created ``draft_id`` plus reply subject/body once a draft exists;
    * a minimal ``thread_id``/``triage_label`` payload otherwise (previously
      this case returned ``null``).

    Raises:
        HTTPException: 400 when ``status`` is neither "approved" nor
            "rejected"; 500 for unexpected failures.
    """
    # Validate before the try-block so the 400 is not rewrapped as a 500.
    if request.status == "rejected":
        resume_value = {"status": "rejected", "feedback": request.feedback}
    elif request.status == "approved":
        resume_value = {"status": "approved"}
    else:
        raise HTTPException(status_code=400, detail="Invalid status")

    try:
        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "user_id": request.user_id,
            }
        }
        intermediate_state = graph.invoke(Command(resume=resume_value), config=config)

        # Still in review phase
        if _awaiting_review(intermediate_state):
            review = _review_response(intermediate_state, request.thread_id)
            if review is not None:
                return review

        # Draft created, review complete
        if intermediate_state.get("draft_id"):
            return {
                "thread_id": request.thread_id,
                "draft_id": intermediate_state["draft_id"],
                "messages": intermediate_state.get("messages", []),
                "reply_subject": intermediate_state.get("reply_subject"),
                "reply_email_body": intermediate_state.get("reply_email_body"),
            }

        # Neither paused for review nor drafted: report what we know instead
        # of returning an implicit None.
        return {
            "thread_id": request.thread_id,
            "triage_label": intermediate_state.get("triage_label"),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in review action: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/send_email")
def send_email(request: SendEmailRequest) -> Dict[str, Any]:
    """Append a human message to the thread and resume the graph.

    The message is written to the thread's state as if produced by
    ``prepare_context_node``, then the graph is resumed with no new input.

    Raises:
        HTTPException: 500 for unexpected failures.
    """
    config = {
        "configurable": {
            "thread_id": request.thread_id,
            "user_id": request.user_id,
        }
    }
    try:
        graph.update_state(
            config,
            {"messages": [HumanMessage(content=request.human_message)]},
            as_node="prepare_context_node",
        )
        final_state = graph.invoke(None, config=config)
        return {
            "thread_id": request.thread_id,
            "messages": final_state.get("messages", []),
            "sent_message_id": final_state.get("sent_message_id"),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error sending email: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)