Gaykar's picture
changes: added gradio
972c80f
Raw
History Blame Contribute Delete
8.58 kB
import os
from app.core.config import settings
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel, EmailStr
from typing import Optional, Dict, Any, TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import add_messages
from langgraph.types import Command
import logging
from app.graph import graph
from app.state.state import EmailAgentState
from app.database.connection import get_session
from app.database.utils import get_or_create_user
from sqlalchemy.orm import Session
from app.database.connection import SessionLocal
from fastapi import Request
from app.database.models import User
from app.utils.email_encode import encode_email_for_namespace
from app.core.auth import create_access_token,get_current_user
import traceback
import gradio as gr
from app.gradio_ui import build_demo
# Module-level logger; the endpoint handlers below report failures through it.
logger = logging.getLogger(name=__name__)
def get_session():
    """Yield one database session per request and always close it afterwards.

    Used as a FastAPI ``Depends`` dependency: a fresh ``SessionLocal()`` is
    created, handed to the endpoint via ``yield``, and closed in the
    ``finally`` clause even when the endpoint raised.

    NOTE(review): this definition shadows the ``get_session`` imported from
    ``app.database.connection`` at the top of the module, so every
    ``Depends(get_session)`` in this file resolves to this local version.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Release the connection back to the pool regardless of outcome.
        session.close()
# FastAPI application serving the JSON endpoints defined below.
app = FastAPI(
    title="AI Email Agent API",
)
# --- Schemas ---
class EmailProcessRequest(BaseModel):
    """Request body for ``POST /process-email``: one incoming email to run through the graph."""

    # LangGraph thread id for this run; the follow-up /review-action and
    # /send_email requests carry the same id to resume it.
    thread_id: str
    # Sender's address; EmailStr makes pydantic validate the email format.
    sender_email_id: EmailStr
    # Subject line of the incoming email.
    sender_subject: str
    # Body of the incoming email.
    sender_email_body: str
class ReviewActionRequest(BaseModel):
    """Request body for ``POST /review-action``: the user's verdict on a drafted reply."""

    # Thread id of the interrupted graph run to resume.
    thread_id: str
    status: str # "approved" or "rejected"; the endpoint refuses any other value
    # Reviewer feedback; only forwarded to the graph when status is "rejected".
    feedback: Optional[str] = None
class SendEmailRequest(BaseModel):
    """Request body for ``POST /send_email``: a message to add to a thread before resuming it."""

    # Thread id of the graph run to resume.
    thread_id: str
    # Text appended to the thread's messages as a HumanMessage.
    human_message: str
# --- Helper Functions ---
def parse_interrupt(final_state: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Extract the first pending interrupt from a graph result.

    Args:
        final_state: The dict returned by ``graph.invoke``. When the run paused,
            it carries an ``"__interrupt__"`` entry holding a sequence of
            interrupt objects whose ``value`` attribute is the interrupt payload
            (presumably LangGraph ``Interrupt`` objects).

    Returns:
        ``None`` when there is no non-empty ``"__interrupt__"`` entry. Otherwise
        ``{"action": ..., "data": {...}}`` built from the first interrupt's
        payload. A missing, ``None`` or non-dict payload is treated as empty, and
        a missing or falsy ``"data"`` field becomes ``{}``, so callers can always
        call ``data.get(...)`` safely.
    """
    interrupts = final_state.get("__interrupt__")
    if not interrupts:
        return None

    payload = getattr(interrupts[0], "value", None)
    if not isinstance(payload, dict):
        # e.g. interrupt("some text") or no payload at all: nothing structured to read.
        payload = {}

    return {
        "action": payload.get("action"),
        # `or {}` also covers an explicit "data": None, which callers cannot .get() on.
        "data": payload.get("data") or {},
    }
# --- Endpoints ---
@app.post("/get-user-data")
def get_user_data(user_email: EmailStr, db: Session = Depends(get_session)):
    """Look up (or create) the user for ``user_email`` and issue an access token.

    ``user_email`` is a plain (non-model) parameter, so FastAPI reads it from
    the query string rather than the request body.

    Returns:
        ``{"user_id": <str id>, "email": <stored email>, "token": <access token>}``.
        The token's claims are the user's id and the email exactly as supplied
        in the request.

    NOTE(review): a token is issued for any address without checking that the
    caller owns it, so anyone can obtain a token for any user. Confirm this is
    intended (e.g. demo-only) before exposing the endpoint publicly.
    """
    user = get_or_create_user(db, user_email)
    claims = {"id": user.id, "email": user_email}
    access_token = create_access_token(claims)
    return {
        "user_id": str(user.id),
        "email": user.email,
        "token": access_token,
    }
@app.post("/process-email")
def process_email(
    request: EmailProcessRequest,
    db: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
    """Run one incoming email through the agent graph for the current user.

    The graph runs on thread ``request.thread_id``. If it triages the email as
    ``"FOLLOW_UP_REQUIRED"`` and pauses on an interrupt before a draft exists,
    the proposed reply is returned for human review. Otherwise only the thread
    id and triage label are returned.

    Args:
        request: The incoming email and the thread id to run it on.
        db: Request-scoped DB session (injected; not used directly here).
        current_user: Authenticated user resolved from the access token.

    Returns:
        Either a ``"needs_review"`` payload (action, messages, draft fields) or
        ``{"thread_id", "triage_label"}``.

    Raises:
        HTTPException: 500 with ``"Graph invocation error: ..."`` if the graph
            fails, or 500 with the error text for any other failure.
    """
    try:
        thread_id = request.thread_id
        config = {
            "configurable": {
                "thread_id": thread_id,
                "user_id": str(current_user.id),
                # Encoded sender address, presumably used to namespace per-sender memory.
                "sender_email_id": encode_email_for_namespace(request.sender_email_id),
            }
        }
        input_data = {
            "user_email_id": current_user.email,
            "user_id": current_user.id,
            # TODO(review): hard-coded display name; every user is addressed as "Atharva".
            "user_name": "Atharva",
            "sender_email_id": request.sender_email_id,
            "sender_subject": request.sender_subject,
            "sender_email_body": request.sender_email_body,
        }

        try:
            final_state = graph.invoke(input_data, config=config)
        except Exception as e:
            logger.exception("Error invoking graph: %s", e)
            raise HTTPException(status_code=500, detail=f"Graph invocation error: {e}") from e

        if final_state.get("triage_label") == "FOLLOW_UP_REQUIRED":
            # A pending interrupt with no draft yet means the graph awaits human review.
            if "__interrupt__" in final_state and not final_state.get("draft_id"):
                parsed_interrupt = parse_interrupt(final_state)
                if parsed_interrupt:
                    data = parsed_interrupt["data"]
                    return {
                        "status": "needs_review",
                        "thread_id": thread_id,
                        "messages": final_state.get("messages", []),
                        "triage_label": final_state.get("triage_label"),
                        "action": parsed_interrupt["action"],
                        "email_draft": {
                            "to": data.get("to"),
                            "subject": data.get("subject"),
                            "body": data.get("body"),
                        },
                    }

        return {
            "thread_id": thread_id,
            "triage_label": final_state.get("triage_label"),
        }
    except HTTPException:
        # Already carries the intended status and detail. Letting the generic
        # handler below re-wrap it would replace the detail with str(exc) and
        # log the same failure twice.
        raise
    except Exception as e:
        logger.exception("Error processing email: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/review-action")
def review_action(
    request: ReviewActionRequest,
    db: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
    """Resume an interrupted graph run with the user's review decision.

    ``request.status`` must be ``"approved"`` or ``"rejected"``. A rejection
    also forwards ``request.feedback``. The decision is passed to the graph as
    ``Command(resume=...)``.

    Args:
        request: Thread id, review status and optional feedback.
        db: Request-scoped DB session (injected; not used directly here).
        current_user: Authenticated user resolved from the access token.

    Returns:
        One of three shapes:
        - a ``"needs_review"`` payload if the graph paused for another review;
        - the draft details (``draft_id``, messages, reply subject/body) once
          a draft exists;
        - otherwise ``{"thread_id", "triage_label"}``.

    Raises:
        HTTPException: 400 for an unknown status, 404 if the thread has no
            saved state, 500 for any other failure.

    NOTE(review): the checkpoint is looked up by ``thread_id``. Nothing here
    verifies that the thread belongs to ``current_user``; confirm the
    checkpointer scopes threads by ``user_id``.
    """
    try:
        if request.status == "rejected":
            resume = {"status": "rejected", "feedback": request.feedback}
        elif request.status == "approved":
            resume = {"status": "approved"}
        else:
            raise HTTPException(status_code=400, detail="Invalid status")

        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "user_id": str(current_user.id),
            }
        }
        # The sender namespace isn't in the request; recover it from the saved state.
        state = graph.get_state(config)
        sender_email_id = state.values.get("sender_email_id")
        if sender_email_id is None:
            raise HTTPException(
                status_code=404,
                detail=f"No saved state for thread '{request.thread_id}'",
            )
        config["configurable"]["sender_email_id"] = encode_email_for_namespace(sender_email_id)

        intermediate_state = graph.invoke(Command(resume=resume), config=config)

        # Still in review phase: the graph paused again before producing a draft.
        if "__interrupt__" in intermediate_state and not intermediate_state.get("draft_id"):
            parsed_interrupt = parse_interrupt(intermediate_state)
            if parsed_interrupt:
                data = parsed_interrupt["data"]
                return {
                    "status": "needs_review",
                    "thread_id": request.thread_id,
                    "triage_label": intermediate_state.get("triage_label"),
                    "action": parsed_interrupt["action"],
                    "email_draft": {
                        "to": data.get("to"),
                        "subject": data.get("subject"),
                        "body": data.get("body"),
                    },
                }

        # Draft created, review complete.
        if intermediate_state.get("draft_id"):
            return {
                "thread_id": request.thread_id,
                "draft_id": intermediate_state["draft_id"],
                "messages": intermediate_state.get("messages", []),
                "reply_subject": intermediate_state.get("reply_subject"),
                "reply_email_body": intermediate_state.get("reply_email_body"),
            }

        # Neither a new review nor a draft: return a minimal status (same shape
        # as /process-email) instead of an implicit None / JSON null.
        return {
            "thread_id": request.thread_id,
            "triage_label": intermediate_state.get("triage_label"),
        }
    except HTTPException:
        # Keep the intended 400/404; the generic handler would turn it into a 500.
        raise
    except Exception as e:
        logger.exception("Error in review action: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/send_email")
def send_email(
    request: SendEmailRequest,
    db: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Dict[str, Any]:
    """Add the user's message to a paused thread and resume the graph.

    The message is written into the thread's ``messages`` as a
    ``HumanMessage``, recorded ``as_node="prepare_context_node"``. The graph is
    then resumed from its checkpoint with ``graph.invoke(None, ...)``.

    Args:
        request: Thread id and the human message to append.
        db: Request-scoped DB session (injected; not used directly here).
        current_user: Authenticated user resolved from the access token.

    Returns:
        ``{"thread_id", "messages", "sent_message_id"}`` read from the final
        graph state.

    Raises:
        HTTPException: 404 if the thread has no saved state, 500 for any other
            failure (matching the other endpoints).
    """
    try:
        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "user_id": str(current_user.id),
            }
        }
        # The sender namespace isn't in the request; recover it from the saved state.
        state = graph.get_state(config)
        sender_email_id = state.values.get("sender_email_id")
        if sender_email_id is None:
            raise HTTPException(
                status_code=404,
                detail=f"No saved state for thread '{request.thread_id}'",
            )
        config["configurable"]["sender_email_id"] = encode_email_for_namespace(sender_email_id)

        graph.update_state(
            config,
            {"messages": [HumanMessage(content=request.human_message)]},
            as_node="prepare_context_node",
        )
        # None input = continue from the checkpoint rather than start a new run.
        final_state = graph.invoke(None, config=config)

        return {
            "thread_id": request.thread_id,
            "messages": final_state.get("messages", []),
            "sent_message_id": final_state.get("sent_message_id"),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error sending email: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Build the Gradio Blocks UI from app.gradio_ui (imported at the top of the module).
demo = build_demo()
# Serve the UI under /ui of the same FastAPI app,
# e.g. https://vinit006-emailagentwithmemory.hf.space/ui
app = gr.mount_gradio_app(app=app, blocks=demo, path="/ui")
# used for local testing
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="127.0.0.1", port=8080)