Spaces:
Sleeping
Sleeping
| import os | |
| from app.core.config import settings | |
| from fastapi import FastAPI, HTTPException, Depends | |
| from pydantic import BaseModel, EmailStr | |
| from typing import Optional, Dict, Any, TypedDict, Annotated, Sequence | |
| from langchain_core.messages import BaseMessage, HumanMessage | |
| from langgraph.graph import add_messages | |
| from langgraph.types import Command | |
| import logging | |
| from app.graph import graph | |
| from app.state.state import EmailAgentState | |
| from app.database.connection import get_session | |
| from app.database.utils import get_or_create_user | |
| from sqlalchemy.orm import Session | |
| from app.database.connection import SessionLocal | |
| from fastapi import Request | |
| from app.database.models import User | |
| from app.utils.email_encode import encode_email_for_namespace | |
| from app.core.auth import create_access_token,get_current_user | |
| import traceback | |
| import gradio as gr | |
| from app.gradio_ui import build_demo | |
| # CREATE GMAIL AUTH FILES FROM HF SECRETS | |
| logger = logging.getLogger(__name__) | |
def get_session():
    """Yield a database session and close it once the consumer is done.

    Written as a generator so it can be passed to ``Depends``: the session
    is handed out at the ``yield`` and closed when the generator resumes.
    """
    # NOTE(review): this definition shadows the ``get_session`` imported from
    # app.database.connection at the top of the file.
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even if the consumer raised.
        session.close()
# FastAPI application object; the Gradio UI is mounted onto it at the end of
# this file.
app = FastAPI(title="AI Email Agent API")
| # --- Schemas --- | |
class EmailProcessRequest(BaseModel):
    """Request body describing one incoming email to run through the graph."""

    # Used as the graph's configurable ``thread_id``.
    thread_id: str
    # Validated as an email address by pydantic's ``EmailStr``.
    sender_email_id: EmailStr
    sender_subject: str
    sender_email_body: str
class ReviewActionRequest(BaseModel):
    """Request body carrying a reviewer's decision on a pending draft."""

    # Identifies the graph thread whose execution should be resumed.
    thread_id: str
    # Expected values: "approved" or "rejected".
    status: str
    # Only forwarded to the graph when the draft is rejected.
    feedback: Optional[str] = None
class SendEmailRequest(BaseModel):
    """Request body for continuing a thread with a human-written message."""

    # Identifies the graph thread to continue.
    thread_id: str
    # Added to the thread's messages as a ``HumanMessage``.
    human_message: str
| # --- Helper Functions --- | |
def parse_interrupt(final_state: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Extract the first pending interrupt from a graph result.

    Args:
        final_state: Mapping returned by the graph. A pending interrupt is
            looked up under the ``"__interrupt__"`` key and is expected to be
            a sequence of objects exposing a ``value`` attribute.

    Returns:
        ``None`` when the state carries no interrupt; otherwise a dict with
        the interrupt's ``"action"`` and its ``"data"`` payload. ``"data"``
        is never ``None`` so callers can safely call ``.get`` on it.
    """
    interrupts = final_state.get("__interrupt__")
    if not interrupts:
        # Covers both a missing key and an empty interrupt list.
        return None

    # Only the first interrupt is surfaced to the caller.
    payload = getattr(interrupts[0], "value", None) or {}
    return {
        "action": payload.get("action"),
        # ``dict.get(key, {})`` does not help when the key is present with a
        # ``None`` value; normalise that case to an empty dict as well.
        "data": payload.get("data") or {},
    }
| # --- Endpoints --- | |
def get_user_data(user_email: EmailStr, db: Session = Depends(get_session)):
    """Fetch (or create) the user for ``user_email`` and issue an access token.

    Returns a dict with the user's id as a string, the stored email and the
    freshly created token.
    """
    user = get_or_create_user(db, user_email)

    # Token claims use the raw user id and the email exactly as supplied.
    claims = {"id": user.id, "email": user_email}
    token = create_access_token(claims)

    response = {"user_id": str(user.id), "email": user.email, "token": token}
    return response
def process_email(
    request: EmailProcessRequest,
    db: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
    """Process an incoming email through the graph pipeline.

    Args:
        request: Thread id plus the sender's address, subject and body.
        db: Database session dependency (not used directly in this function).
        current_user: Authenticated user on whose behalf the graph runs.

    Returns:
        When the email is triaged as ``FOLLOW_UP_REQUIRED`` and the graph is
        interrupted before a draft exists, a ``needs_review`` payload holding
        the proposed draft. Otherwise just the thread id and triage label.

    Raises:
        HTTPException: 500 if the graph invocation or any other step fails.
    """
    try:
        thread_id = request.thread_id
        config = {
            "configurable": {
                "thread_id": thread_id,
                "user_id": str(current_user.id),
                "sender_email_id": encode_email_for_namespace(request.sender_email_id),
            }
        }
        input_data = {
            "user_email_id": current_user.email,
            "user_id": current_user.id,
            # NOTE(review): hard-coded display name; presumably this should
            # come from the authenticated user - confirm and replace.
            "user_name": "Atharva",
            "sender_email_id": request.sender_email_id,
            "sender_subject": request.sender_subject,
            "sender_email_body": request.sender_email_body,
        }

        try:
            final_state = graph.invoke(input_data, config=config)
        except Exception as e:
            # logger.exception records the traceback together with the message.
            logger.exception("Error invoking graph: %s", e)
            raise HTTPException(
                status_code=500,
                detail=f"Graph invocation error: {str(e)}",
            ) from e

        triage_label = final_state.get("triage_label")
        awaiting_review = (
            triage_label == "FOLLOW_UP_REQUIRED"
            and "__interrupt__" in final_state
            and not final_state.get("draft_id")
        )
        if awaiting_review:
            parsed_interrupt = parse_interrupt(final_state)
            if parsed_interrupt:
                data = parsed_interrupt["data"]
                return {
                    "status": "needs_review",
                    "thread_id": thread_id,
                    "messages": final_state.get("messages", []),
                    "triage_label": triage_label,
                    "action": parsed_interrupt["action"],
                    "email_draft": {
                        "to": data.get("to"),
                        "subject": data.get("subject"),
                        "body": data.get("body"),
                    },
                }

        return {
            "thread_id": thread_id,
            "triage_label": triage_label,
        }
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the
        # handler below would re-wrap it and mangle its detail message.
        raise
    except Exception as e:
        logger.exception("Error processing email: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
def review_action(
    request: ReviewActionRequest,
    db: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
    """Resume graph execution based on the user's review decision.

    Args:
        request: Thread id, the decision (``"approved"`` or ``"rejected"``)
            and optional feedback, which is forwarded only on rejection.
        db: Database session dependency (not used directly in this function).
        current_user: Authenticated user who owns the thread.

    Returns:
        A ``needs_review`` payload if the graph is interrupted again before a
        draft exists; the draft details once a ``draft_id`` is present;
        otherwise the thread id and triage label.

    Raises:
        HTTPException: 400 for an unknown status, 404 when no stored state
            with a sender address exists for the thread, 500 on any other
            failure.
    """
    # Validate the decision up front so a bad request never touches the graph.
    if request.status == "rejected":
        resume_value = {
            "status": "rejected",
            "feedback": request.feedback,
        }
    elif request.status == "approved":
        resume_value = {
            "status": "approved",
        }
    else:
        raise HTTPException(status_code=400, detail="Invalid status")

    try:
        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "user_id": str(current_user.id),
            }
        }

        state = graph.get_state(config)
        sender_email_id = state.values.get("sender_email_id")
        if not sender_email_id:
            # Previously this surfaced as a KeyError turned into a 500.
            raise HTTPException(
                status_code=404,
                detail="No stored email state found for this thread_id",
            )
        config["configurable"]["sender_email_id"] = encode_email_for_namespace(sender_email_id)

        intermediate_state = graph.invoke(Command(resume=resume_value), config=config)

        # Still in review phase
        if "__interrupt__" in intermediate_state and not intermediate_state.get("draft_id"):
            parsed_interrupt = parse_interrupt(intermediate_state)
            if parsed_interrupt:
                data = parsed_interrupt["data"]
                return {
                    "status": "needs_review",
                    "thread_id": request.thread_id,
                    "triage_label": intermediate_state.get("triage_label"),
                    "action": parsed_interrupt["action"],
                    "email_draft": {
                        "to": data.get("to"),
                        "subject": data.get("subject"),
                        "body": data.get("body"),
                    },
                }

        # Draft created, review complete
        if intermediate_state.get("draft_id"):
            return {
                "thread_id": request.thread_id,
                "draft_id": intermediate_state["draft_id"],
                "messages": intermediate_state.get("messages", []),
                "reply_subject": intermediate_state.get("reply_subject"),
                "reply_email_body": intermediate_state.get("reply_email_body"),
            }

        # Neither a pending interrupt nor a draft: return a minimal payload
        # instead of implicitly returning None.
        return {
            "thread_id": request.thread_id,
            "triage_label": intermediate_state.get("triage_label"),
        }
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) pass through intact.
        raise
    except Exception as e:
        logger.exception("Error in review action: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
def send_email(
    request: SendEmailRequest,
    db: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
    """Add a human message to an existing thread and continue the graph.

    Args:
        request: Thread id and the human-written message to add.
        db: Database session dependency (not used directly in this function).
        current_user: Authenticated user who owns the thread.

    Returns:
        The thread id, the resulting messages and the ``sent_message_id``
        reported by the graph (``None`` when absent from the final state).

    Raises:
        HTTPException: 404 when no stored state with a sender address exists
            for the thread, 500 on any other failure.
    """
    try:
        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "user_id": str(current_user.id),
            }
        }

        state = graph.get_state(config)
        sender_email_id = state.values.get("sender_email_id")
        if not sender_email_id:
            # Previously an unknown thread surfaced as an unhandled KeyError.
            raise HTTPException(
                status_code=404,
                detail="No stored email state found for this thread_id",
            )
        config["configurable"]["sender_email_id"] = encode_email_for_namespace(sender_email_id)

        # Record the human message as an update attributed to
        # prepare_context_node, then continue the run with no new input.
        graph.update_state(
            config,
            {"messages": [HumanMessage(content=request.human_message)]},
            as_node="prepare_context_node",
        )
        final_state = graph.invoke(None, config=config)

        return {
            "thread_id": request.thread_id,
            "messages": final_state.get("messages", []),
            "sent_message_id": final_state.get("sent_message_id"),
        }
    except HTTPException:
        # Keep deliberate HTTP errors (the 404 above) intact.
        raise
    except Exception as e:
        # Same error-handling shape as the other endpoints in this file.
        logger.exception("Error sending email: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Build the Gradio UI. build_demo is imported from app.gradio_ui at the top of
# this file; adjust that import if gradio_ui.py is placed elsewhere.
demo = build_demo()

# Mount the Gradio UI on the FastAPI app at /ui,
# e.g. https://vinit006-emailagentwithmemory.hf.space/ui
app = gr.mount_gradio_app(app, demo, path="/ui")
| # used for local testing | |
| # if __name__ == "__main__": | |
| # import uvicorn | |
| # uvicorn.run(app, host="127.0.0.1", port=8080) |