Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from google.adk.sessions import InMemorySessionService | |
| from google.adk.runners import Runner | |
| from google.genai import types | |
| from agent import root_agent | |
| import callback | |
# Module-level logger, named after this module, used by the handlers in this file.
logger = logging.getLogger(__name__)
| # ========================= | |
| # FASTAPI APP | |
| # ========================= | |
# The FastAPI application object for this module; the three keyword arguments
# are descriptive metadata handed to the FastAPI constructor.
app = FastAPI(
    title="Leave Policy Assistant",
    description="AI agent to answer leave policy questions",
    version="1.0.0",
)
| # ========================= | |
| # ADK SESSION SETUP | |
| # ========================= | |
# Application name handed to the ADK runner and to every session call in this file.
APP_NAME = "leave_policy_app"

# One session service instance, shared by the runner and by direct session
# lookups/creation elsewhere in this module.
session_service = InMemorySessionService()

# Runner that executes ``root_agent`` against the session service above.
runner = Runner(
    agent=root_agent,
    app_name=APP_NAME,
    session_service=session_service,
)
| # Optional: persist ADK callback events to Snowflake (set ENABLE_AUDIT_SINK=1 in .env) | |
# The sink is opt-in: only "1", "true" or "yes" (case-insensitive, surrounding
# whitespace ignored) in ENABLE_AUDIT_SINK turns it on.
if os.getenv("ENABLE_AUDIT_SINK", "").strip().lower() in ("1", "true", "yes"):
    try:
        # Imported here so audit_db is only loaded when the sink is requested.
        from audit_db import SnowflakeAuditSink, ensure_table

        # NOTE(review): the sink is registered before ensure_table() runs, so if
        # ensure_table() raises, the sink stays registered even though the
        # warning below says it is not enabled - confirm this is intended.
        callback.set_audit_sink(SnowflakeAuditSink())
        ensure_table()
    except Exception as e:
        # Best-effort feature: any failure is logged through the module-level
        # logger and start-up continues without aborting.
        logger.warning("Audit sink not enabled: %s", e)
| # ========================= | |
| # REQUEST / RESPONSE MODELS | |
| # ========================= | |
class ChatRequest(BaseModel):
    # Input model for ``chat``: who is asking and what they asked.

    # Identifies the caller; ``chat`` builds the session id from this value.
    user_id: str

    # Text that ``chat`` forwards to the agent as the user's message.
    message: str
class ChatResponse(BaseModel):
    # Output model returned by ``chat``.

    # The agent's reply text, already stripped of surrounding whitespace.
    response: str
| # ========================= | |
| # CHAT ENDPOINT | |
| # ========================= | |
| def _text_from_content(content) -> str | None: | |
| """Extract response text from event content (check all parts).""" | |
| if content is None: | |
| return None | |
| try: | |
| parts = getattr(content, "parts", None) | |
| if not parts: | |
| return None | |
| for part in parts: | |
| text = getattr(part, "text", None) | |
| if text is not None and str(text).strip(): | |
| return str(text).strip() | |
| except (AttributeError, TypeError, IndexError): | |
| pass | |
| return None | |
async def chat(req: ChatRequest):
    """Send one user message to the agent and return its reply.

    The session id is derived from ``req.user_id``; the session is created
    if the lookup returns nothing. The message is then streamed through the
    runner. The most recent non-blank event text is kept, and the stream is
    abandoned as soon as a final-response event yields text.

    Args:
        req: The caller's ``user_id`` and ``message``.

    Returns:
        A ``ChatResponse`` carrying the agent's reply text.

    Raises:
        HTTPException: 400 when the runner raises ``ValueError``; 503 when
            the run produced no text at all; 500 for any other failure.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source - confirm it is registered on ``app`` somewhere.
    try:
        session_id = f"{req.user_id}_session"
        session = await session_service.get_session(
            app_name=APP_NAME,
            user_id=req.user_id,
            session_id=session_id,
        )
        if not session:
            await session_service.create_session(
                app_name=APP_NAME,
                user_id=req.user_id,
                session_id=session_id,
            )

        user_content = types.Content(
            role="user",
            parts=[types.Part(text=req.message)],
        )

        final_response = None
        try:
            async for event in runner.run_async(
                user_id=req.user_id,
                session_id=session_id,
                new_message=user_content,
            ):
                try:
                    # Events without an ``is_final_response`` method count as
                    # non-final.
                    is_final = getattr(event, "is_final_response", lambda: False)()
                    content = getattr(event, "content", None)
                    text = _text_from_content(content) if content else None
                except (AttributeError, TypeError, KeyError) as e:
                    # One malformed event must not abort the whole run.
                    logger.debug("Skipping event parse error: %s", e)
                    continue
                if text:
                    # Keep the latest text so an intermediate answer is still
                    # available if no final event ever carries text.
                    final_response = text
                    if is_final:
                        break
        except ValueError as e:
            logger.warning("Runner run_async ValueError: %s", e)
            raise HTTPException(status_code=400, detail=str(e)) from e

        if not final_response:
            logger.warning("Chat: no response text from agent for user_id=%s", req.user_id)
            raise HTTPException(
                status_code=503,
                detail="Agent did not return a response. Please try again or check server logs.",
            )
        return ChatResponse(response=final_response)
    except HTTPException:
        # Already an HTTP error with the intended status; pass it through.
        raise
    except Exception as e:
        logger.exception("Chat failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing your request. Check server logs for details.",
        ) from e
| # ========================= | |
| # HEALTH CHECK | |
| # ========================= | |
def health():
    # Liveness probe: always reports a fixed "ok" status and touches no other state.
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source - confirm it is registered on ``app`` somewhere.
    payload = dict(status="ok")
    return payload