"""FastAPI service exposing the Leave Policy Assistant ADK agent over HTTP."""

import logging
import os

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner
from google.genai import types

from agent import root_agent
import callback

logger = logging.getLogger(__name__)

# =========================
# FASTAPI APP
# =========================

app = FastAPI(
    title="Leave Policy Assistant",
    description="AI agent to answer leave policy questions",
    version="1.0.0",
)

# =========================
# ADK SESSION SETUP
# =========================

APP_NAME = "leave_policy_app"

# Sessions are kept by ADK's InMemorySessionService, i.e. in this process's
# memory; they are not persisted by this module.
session_service = InMemorySessionService()

runner = Runner(
    agent=root_agent,
    app_name=APP_NAME,
    session_service=session_service,
)

# Optional: persist ADK callback events to Snowflake (set ENABLE_AUDIT_SINK=1 in .env).
if os.getenv("ENABLE_AUDIT_SINK", "").strip().lower() in ("1", "true", "yes"):
    try:
        from audit_db import SnowflakeAuditSink, ensure_table

        _audit_sink = SnowflakeAuditSink()
        # Make sure the table exists BEFORE registering the sink, so that a
        # failure here leaves no sink installed -- consistent with the
        # "not enabled" warning logged below.
        ensure_table()
        callback.set_audit_sink(_audit_sink)
    except Exception as e:
        # Auditing is best-effort: the API still starts without it.
        logger.warning("Audit sink not enabled: %s", e)

# =========================
# REQUEST / RESPONSE MODELS
# =========================


class ChatRequest(BaseModel):
    """Incoming chat message from a user."""

    user_id: str  # also used to derive the user's session id
    message: str  # free-text question forwarded to the agent


class ChatResponse(BaseModel):
    """Agent reply returned by the /chat endpoint."""

    response: str


# =========================
# CHAT ENDPOINT
# =========================


def _text_from_content(content) -> str | None:
    """Return the first non-blank text part of *content*, stripped.

    Parts flagged as model "thoughts" (``part.thought`` truthy) are skipped so
    internal reasoning is never returned to the caller.

    Returns ``None`` when *content* is ``None``, has no parts, has no
    non-blank non-thought text, or is malformed (attribute/type errors).
    """
    if content is None:
        return None
    try:
        parts = getattr(content, "parts", None)
        if not parts:
            return None
        for part in parts:
            if getattr(part, "thought", False):
                continue
            text = getattr(part, "text", None)
            if text is not None and str(text).strip():
                return str(text).strip()
    except (AttributeError, TypeError, IndexError):
        # Malformed content: treat as "no text" rather than failing the request.
        pass
    return None


@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Send ``req.message`` to the agent and return its text reply.

    Each user maps to one session id (``"<user_id>_session"``), which is
    created on first use and reused afterwards.

    Text selection: the text of the first event flagged as the final response
    wins; if no final event carries text, the last text seen from any event is
    returned instead.

    Raises:
        HTTPException: 400 if the runner raises ``ValueError``; 503 if the
            agent produced no text; 500 for any other unexpected error.
    """
    try:
        session_id = f"{req.user_id}_session"
        session = await session_service.get_session(
            app_name=APP_NAME,
            user_id=req.user_id,
            session_id=session_id,
        )
        if not session:
            await session_service.create_session(
                app_name=APP_NAME,
                user_id=req.user_id,
                session_id=session_id,
            )

        user_content = types.Content(
            role="user",
            parts=[types.Part(text=req.message)],
        )

        final_response = None
        try:
            async for event in runner.run_async(
                user_id=req.user_id,
                session_id=session_id,
                new_message=user_content,
            ):
                try:
                    is_final = getattr(event, "is_final_response", lambda: False)()
                    text = _text_from_content(getattr(event, "content", None))
                except (AttributeError, TypeError, KeyError) as e:
                    logger.debug("Skipping event parse error: %s", e)
                    continue
                if not text:
                    continue
                # Keep the latest text as a fallback; stop once the final
                # response has produced text.
                final_response = text
                if is_final:
                    break
        except ValueError as e:
            logger.warning("Runner run_async ValueError: %s", e)
            raise HTTPException(status_code=400, detail=str(e)) from e

        if not final_response:
            logger.warning("Chat: no response text from agent for user_id=%s", req.user_id)
            raise HTTPException(
                status_code=503,
                detail="Agent did not return a response. Please try again or check server logs.",
            )

        return ChatResponse(response=final_response)

    except HTTPException:
        # Already mapped to a client-facing status; don't rewrap as 500.
        raise
    except Exception as e:
        logger.exception("Chat failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing your request. Check server logs for details.",
        ) from e


# =========================
# HEALTH CHECK
# =========================


@app.get("/health")
def health():
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}