Policy / src /main.py
vishalkatheriya's picture
Upload 6 files
adf61d6 verified
import logging
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner
from google.genai import types
from agent import root_agent
import callback
logger = logging.getLogger(__name__)

# =========================
# FASTAPI APP
# =========================
# title / version / description are passed straight to FastAPI, which exposes
# them in the generated OpenAPI schema and interactive docs.
app = FastAPI(
    title="Leave Policy Assistant",
    version="1.0.0",
    description="AI agent to answer leave policy questions",
)
# =========================
# ADK SESSION SETUP
# =========================
APP_NAME = "leave_policy_app"

# In-process session store: sessions are held in this process's memory only.
session_service = InMemorySessionService()
runner = Runner(
    agent=root_agent,
    app_name=APP_NAME,
    session_service=session_service,
)

# Optional: persist ADK callback events to Snowflake (set ENABLE_AUDIT_SINK=1 in .env)
if os.getenv("ENABLE_AUDIT_SINK", "").strip().lower() in ("1", "true", "yes"):
    try:
        from audit_db import SnowflakeAuditSink, ensure_table

        # Build the sink and ensure its table exists *before* registering it
        # with `callback`. If either step fails, no sink is installed, which
        # keeps the state consistent with the "not enabled" warning below.
        _audit_sink = SnowflakeAuditSink()
        ensure_table()
        callback.set_audit_sink(_audit_sink)
    except Exception as e:
        # Auditing is optional: log and keep serving without it.
        logger.warning("Audit sink not enabled: %s", e)
# =========================
# REQUEST / RESPONSE MODELS
# =========================
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    # Caller identifier. The chat endpoint derives the ADK session id from it
    # (f"{user_id}_session"), so repeated requests with the same user_id share
    # one session.
    user_id: str
    # The user's message, sent to the agent as a single text part.
    message: str
class ChatResponse(BaseModel):
    """Response body for ``POST /chat``."""

    # Text of the agent's reply.
    response: str
# =========================
# CHAT ENDPOINT
# =========================
def _text_from_content(content) -> str | None:
"""Extract response text from event content (check all parts)."""
if content is None:
return None
try:
parts = getattr(content, "parts", None)
if not parts:
return None
for part in parts:
text = getattr(part, "text", None)
if text is not None and str(text).strip():
return str(text).strip()
except (AttributeError, TypeError, IndexError):
pass
return None
async def _ensure_session(user_id: str, session_id: str) -> None:
    """Create ADK session ``session_id`` for ``user_id`` unless it already exists."""
    session = await session_service.get_session(
        app_name=APP_NAME,
        user_id=user_id,
        session_id=session_id,
    )
    if not session:
        await session_service.create_session(
            app_name=APP_NAME,
            user_id=user_id,
            session_id=session_id,
        )


async def _collect_reply(
    user_id: str, session_id: str, user_content: "types.Content"
) -> str | None:
    """Run the agent on one user message and return its reply text, or ``None``.

    The event stream from ``runner.run_async`` is always consumed to the end
    instead of being abandoned at the first final response. Breaking out
    early leaves the async generator suspended, so anything the runner would
    do after yielding that event never runs. NOTE(review): this presumably
    includes after-agent callbacks such as the audit hooks in ``callback``;
    confirm against ADK.

    Returns:
        The text of the last final-response event that carried text. If no
        final-response event had text, the last text seen on any event.
        ``None`` if no event carried text.
    """
    final_text: str | None = None
    last_text: str | None = None
    async for event in runner.run_async(
        user_id=user_id,
        session_id=session_id,
        new_message=user_content,
    ):
        try:
            content = getattr(event, "content", None)
            if not content:
                continue
            text = _text_from_content(content)
            if not text:
                continue
            last_text = text
            if getattr(event, "is_final_response", lambda: False)():
                final_text = text
        except (AttributeError, TypeError, KeyError) as e:
            # Tolerate a malformed event; one bad event should not fail the request.
            logger.debug("Skipping event parse error: %s", e)
    return final_text or last_text


@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Send ``req.message`` to the agent and return its reply.

    Each user_id maps to one ADK session (``f"{user_id}_session"``), created
    on first use, so consecutive requests from a user share that session.

    Raises:
        HTTPException: 400 if the runner raised ``ValueError``; 503 if the
            agent produced no reply text; 500 for any other failure (the
            details are logged, not returned to the client).
    """
    try:
        session_id = f"{req.user_id}_session"
        await _ensure_session(req.user_id, session_id)

        user_content = types.Content(
            role="user",
            parts=[types.Part(text=req.message)],
        )
        try:
            final_response = await _collect_reply(req.user_id, session_id, user_content)
        except ValueError as e:
            logger.warning("Runner run_async ValueError: %s", e)
            raise HTTPException(status_code=400, detail=str(e)) from e

        if not final_response:
            logger.warning("Chat: no response text from agent for user_id=%s", req.user_id)
            raise HTTPException(
                status_code=503,
                detail="Agent did not return a response. Please try again or check server logs.",
            )
        return ChatResponse(response=final_response)
    except HTTPException:
        # Already mapped to a status code above; let FastAPI render it.
        raise
    except Exception as e:
        logger.exception("Chat failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing your request. Check server logs for details.",
        ) from e
# =========================
# HEALTH CHECK
# =========================
@app.get("/health")
def health():
    """Report that the service is up; the body is always ``{"status": "ok"}``."""
    payload = dict(status="ok")
    return payload