Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- src/__init__.py +1 -0
- src/agent.py +57 -0
- src/main.py +99 -0
- src/prompt.py +33 -0
- src/sf_tools.py +158 -0
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import agent
|
src/agent.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from google.adk.agents.llm_agent import Agent
|
| 4 |
+
from google.adk.models.lite_llm import LiteLlm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import prompt as prmpt
|
| 8 |
+
import policy_tools as tls
|
| 9 |
+
|
| 10 |
+
# 🔑 Load .env file
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
groq_api_key = os.getenv("GROQ_API_KEY")
|
| 14 |
+
groq_model = os.getenv("MODEL")
|
| 15 |
+
|
| 16 |
+
if not groq_api_key:
|
| 17 |
+
raise ValueError("GROQ_API_KEY is not set in .env")
|
| 18 |
+
|
| 19 |
+
if not groq_model:
|
| 20 |
+
raise ValueError("MODEL is not set in .env")
|
| 21 |
+
|
| 22 |
+
# Optional (only needed if LiteLLM expects env var)
|
| 23 |
+
os.environ["GROQ_API_KEY"] = groq_api_key
|
| 24 |
+
|
| 25 |
+
# root_agent = Agent(
|
| 26 |
+
# model=LiteLlm(model=groq_model,
|
| 27 |
+
# custom_llm_provider="groq"
|
| 28 |
+
# ),
|
| 29 |
+
# name=prmpt.AGENT_CONFIG["name"],
|
| 30 |
+
# description=prmpt.AGENT_CONFIG["description"],
|
| 31 |
+
# instruction=prmpt.AGENT_CONFIG["instruction"],
|
| 32 |
+
# tools=[tls.check_leave_eligibility,tls.get_employee,tls.get_leave_policy],
|
| 33 |
+
# )
|
| 34 |
+
|
| 35 |
+
import sf_tools as tls # replace with actual Tool class in your agent
|
| 36 |
+
|
| 37 |
+
# Now create the agent
|
| 38 |
+
root_agent = Agent(
|
| 39 |
+
model=LiteLlm(
|
| 40 |
+
model=groq_model,
|
| 41 |
+
custom_llm_provider="groq"
|
| 42 |
+
),
|
| 43 |
+
name=prmpt.AGENT_CONFIG["name"],
|
| 44 |
+
description=prmpt.AGENT_CONFIG["description"],
|
| 45 |
+
instruction=prmpt.AGENT_CONFIG["instruction"],
|
| 46 |
+
tools=[tls.check_leave_eligibility, tls.get_employee, tls.get_leave_policy],
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
src/main.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from fastapi import FastAPI, HTTPException
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
from google.adk.sessions import InMemorySessionService
|
| 6 |
+
from google.adk.runners import Runner
|
| 7 |
+
from google.genai import types
|
| 8 |
+
|
| 9 |
+
from agent import root_agent # your agent file
|
| 10 |
+
|
| 11 |
+
# =========================
|
| 12 |
+
# FASTAPI APP
|
| 13 |
+
# =========================
|
| 14 |
+
app = FastAPI(
|
| 15 |
+
title="Leave Policy Assistant",
|
| 16 |
+
description="AI agent to answer leave policy questions",
|
| 17 |
+
version="1.0.0"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# =========================
|
| 21 |
+
# ADK SESSION SETUP
|
| 22 |
+
# =========================
|
| 23 |
+
APP_NAME = "leave_policy_app"
|
| 24 |
+
session_service = InMemorySessionService()
|
| 25 |
+
|
| 26 |
+
runner = Runner(
|
| 27 |
+
agent=root_agent,
|
| 28 |
+
app_name=APP_NAME,
|
| 29 |
+
session_service=session_service,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# =========================
|
| 33 |
+
# REQUEST / RESPONSE MODELS
|
| 34 |
+
# =========================
|
| 35 |
+
class ChatRequest(BaseModel):
|
| 36 |
+
user_id: str
|
| 37 |
+
message: str
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ChatResponse(BaseModel):
|
| 41 |
+
response: str
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# =========================
|
| 45 |
+
# CHAT ENDPOINT
|
| 46 |
+
# =========================
|
| 47 |
+
@app.post("/chat", response_model=ChatResponse)
|
| 48 |
+
async def chat(req: ChatRequest):
|
| 49 |
+
try:
|
| 50 |
+
# Generate session_id automatically per user
|
| 51 |
+
session_id = f"{req.user_id}_session"
|
| 52 |
+
|
| 53 |
+
# Check if session exists, create if not
|
| 54 |
+
session = await session_service.get_session(
|
| 55 |
+
app_name=APP_NAME,
|
| 56 |
+
user_id=req.user_id,
|
| 57 |
+
session_id=session_id,
|
| 58 |
+
)
|
| 59 |
+
if not session:
|
| 60 |
+
await session_service.create_session(
|
| 61 |
+
app_name=APP_NAME,
|
| 62 |
+
user_id=req.user_id,
|
| 63 |
+
session_id=session_id,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Prepare the user message
|
| 67 |
+
user_content = types.Content(
|
| 68 |
+
role="user",
|
| 69 |
+
parts=[types.Part(text=req.message)],
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
final_response = None
|
| 73 |
+
|
| 74 |
+
# Run the agent asynchronously
|
| 75 |
+
async for event in runner.run_async(
|
| 76 |
+
user_id=req.user_id,
|
| 77 |
+
session_id=session_id,
|
| 78 |
+
new_message=user_content,
|
| 79 |
+
):
|
| 80 |
+
if event.is_final_response():
|
| 81 |
+
if event.content and event.content.parts:
|
| 82 |
+
final_response = event.content.parts[0].text
|
| 83 |
+
break
|
| 84 |
+
|
| 85 |
+
if not final_response:
|
| 86 |
+
raise HTTPException(status_code=500, detail="No response from agent")
|
| 87 |
+
|
| 88 |
+
return ChatResponse(response=final_response)
|
| 89 |
+
|
| 90 |
+
except Exception as e:
|
| 91 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# =========================
|
| 95 |
+
# HEALTH CHECK
|
| 96 |
+
# =========================
|
| 97 |
+
@app.get("/health")
|
| 98 |
+
def health():
|
| 99 |
+
return {"status": "ok"}
|
src/prompt.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
AGENT_CONFIG = {
|
| 2 |
+
"name": "vishal",
|
| 3 |
+
"description": "A Leave Policy Assistant Agent built using LiteLLM",
|
| 4 |
+
"instruction": """
|
| 5 |
+
You are Vishal, an intelligent and reliable HR Leave Policy Assistant.
|
| 6 |
+
|
| 7 |
+
Your primary role is to help employees understand company leave policies and
|
| 8 |
+
check their leave eligibility accurately.
|
| 9 |
+
|
| 10 |
+
What you should do:
|
| 11 |
+
- Answer questions related to leave policies (PTO, Sick Leave, Casual Leave, etc.).
|
| 12 |
+
- Check leave eligibility using the available tools.
|
| 13 |
+
- Explain rules such as allowances, carryover limits, notice periods, and eligibility criteria.
|
| 14 |
+
- Provide clear, concise, and employee-friendly responses.
|
| 15 |
+
|
| 16 |
+
How to behave:
|
| 17 |
+
- Always use tools when employee data or leave policy data is required.
|
| 18 |
+
- Never guess or fabricate employee information.
|
| 19 |
+
- If the employee ID is missing, politely ask for it.
|
| 20 |
+
- If the leave type is unknown, suggest valid leave types for the employee’s country.
|
| 21 |
+
- If the employee is inactive, clearly state that leave is not applicable.
|
| 22 |
+
- Handle invalid inputs (wrong dates, negative days, unknown leave types) gracefully.
|
| 23 |
+
|
| 24 |
+
Conversation handling:
|
| 25 |
+
- Maintain context across multiple messages.
|
| 26 |
+
- Ask follow-up questions when required information is missing.
|
| 27 |
+
- Be professional, calm, and helpful at all times.
|
| 28 |
+
|
| 29 |
+
Important:
|
| 30 |
+
- Accuracy is more important than speed.
|
| 31 |
+
- If something cannot be determined, explain why instead of guessing.
|
| 32 |
+
"""
|
| 33 |
+
}
|
src/sf_tools.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from snowflake.connector import connect # type: ignore[import-untyped]
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
# 🔑 Load .env file
|
| 8 |
+
load_dotenv()
|
| 9 |
+
# -----------------------------
|
| 10 |
+
# Snowflake connection config
|
| 11 |
+
# -----------------------------
|
| 12 |
+
|
| 13 |
+
def get_snowflake_connection():
|
| 14 |
+
conn = connect(
|
| 15 |
+
user=os.getenv("SNOWFLAKE_USER"),
|
| 16 |
+
password=os.getenv("SNOWFLAKE_PASSWORD"),
|
| 17 |
+
account=os.getenv("SNOWFLAKE_ACCOUNT"),
|
| 18 |
+
role=os.getenv("SNOWFLAKE_ROLE"),
|
| 19 |
+
warehouse=os.getenv("SNOWFLAKE_WAREHOUSE"),
|
| 20 |
+
database=os.getenv("SNOWFLAKE_DATABASE"),
|
| 21 |
+
schema=os.getenv("SNOWFLAKE_SCHEMA")
|
| 22 |
+
)
|
| 23 |
+
return conn
|
| 24 |
+
|
| 25 |
+
# -----------------------------
|
| 26 |
+
# Fetch Employee Details
|
| 27 |
+
# -----------------------------
|
| 28 |
+
def get_employee(employee_id: str):
|
| 29 |
+
"""Fetch employee details from Snowflake."""
|
| 30 |
+
print("get_employee called")
|
| 31 |
+
conn = get_snowflake_connection()
|
| 32 |
+
cur = conn.cursor()
|
| 33 |
+
try:
|
| 34 |
+
cur.execute("""
|
| 35 |
+
SELECT employee_id, country, employment_status
|
| 36 |
+
FROM employees
|
| 37 |
+
WHERE employee_id = %s
|
| 38 |
+
""", (employee_id,))
|
| 39 |
+
row = cur.fetchone()
|
| 40 |
+
if not row:
|
| 41 |
+
return {"error": f"Employee '{employee_id}' not found"}
|
| 42 |
+
|
| 43 |
+
employee = {
|
| 44 |
+
"employee_id": row[0],
|
| 45 |
+
"country": row[1],
|
| 46 |
+
"employment_status": row[2]
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
if employee["employment_status"] != "active":
|
| 50 |
+
return {"error": f"Employee '{employee_id}' is not active"}
|
| 51 |
+
|
| 52 |
+
# Fetch leave balances
|
| 53 |
+
cur.execute("""
|
| 54 |
+
SELECT leave_type, balance
|
| 55 |
+
FROM leave_balances
|
| 56 |
+
WHERE employee_id = %s
|
| 57 |
+
""", (employee_id,))
|
| 58 |
+
balances = {r[0]: r[1] for r in cur.fetchall()}
|
| 59 |
+
employee["leave_balances"] = balances
|
| 60 |
+
|
| 61 |
+
return employee
|
| 62 |
+
finally:
|
| 63 |
+
cur.close()
|
| 64 |
+
conn.close()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# -----------------------------
|
| 68 |
+
# Fetch Leave Policy
|
| 69 |
+
# -----------------------------
|
| 70 |
+
def get_leave_policy(country: str, leave_type: str):
|
| 71 |
+
"""Fetch leave policy from Snowflake."""
|
| 72 |
+
print("get_leave_policy called")
|
| 73 |
+
conn = get_snowflake_connection()
|
| 74 |
+
cur = conn.cursor()
|
| 75 |
+
try:
|
| 76 |
+
cur.execute("""
|
| 77 |
+
SELECT policy
|
| 78 |
+
FROM leave_policies
|
| 79 |
+
WHERE country = %s AND leave_type = %s
|
| 80 |
+
""", (country, leave_type))
|
| 81 |
+
row = cur.fetchone()
|
| 82 |
+
if not row:
|
| 83 |
+
# fetch available leave types
|
| 84 |
+
cur.execute("""
|
| 85 |
+
SELECT leave_type
|
| 86 |
+
FROM leave_policies
|
| 87 |
+
WHERE country = %s
|
| 88 |
+
""", (country,))
|
| 89 |
+
leave_types = [r[0] for r in cur.fetchall()]
|
| 90 |
+
return {
|
| 91 |
+
"error": f"Leave type '{leave_type}' not found",
|
| 92 |
+
"available_leave_types": leave_types
|
| 93 |
+
}
|
| 94 |
+
# policy is VARIANT JSON in Snowflake (may be dict or JSON string)
|
| 95 |
+
raw = row[0]
|
| 96 |
+
if isinstance(raw, dict):
|
| 97 |
+
return raw
|
| 98 |
+
if isinstance(raw, str):
|
| 99 |
+
return json.loads(raw)
|
| 100 |
+
return dict(raw)
|
| 101 |
+
finally:
|
| 102 |
+
cur.close()
|
| 103 |
+
conn.close()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# -----------------------------
|
| 107 |
+
# Check Leave Eligibility
|
| 108 |
+
# -----------------------------
|
| 109 |
+
def check_leave_eligibility(employee_id: str, leave_type: str, requested_days: int):
|
| 110 |
+
"""Check if employee is eligible to take leave."""
|
| 111 |
+
print("check_leave_eligibility called")
|
| 112 |
+
if requested_days <= 0:
|
| 113 |
+
return {"eligible": False, "reason": "Requested leave days must be > 0"}
|
| 114 |
+
|
| 115 |
+
employee = get_employee(employee_id)
|
| 116 |
+
if "error" in employee:
|
| 117 |
+
return employee
|
| 118 |
+
|
| 119 |
+
country = employee["country"]
|
| 120 |
+
policy = get_leave_policy(country, leave_type)
|
| 121 |
+
if "error" in policy:
|
| 122 |
+
return policy
|
| 123 |
+
|
| 124 |
+
balance = employee["leave_balances"].get(leave_type, 0)
|
| 125 |
+
|
| 126 |
+
# Max consecutive days
|
| 127 |
+
if "max_consecutive_days" in policy:
|
| 128 |
+
if requested_days > policy["max_consecutive_days"]:
|
| 129 |
+
return {
|
| 130 |
+
"eligible": False,
|
| 131 |
+
"reason": f"Maximum consecutive days allowed for {leave_type} is {policy['max_consecutive_days']}"
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
# Annual allowance
|
| 135 |
+
if "annual_allowance" in policy:
|
| 136 |
+
if requested_days > policy["annual_allowance"]:
|
| 137 |
+
return {
|
| 138 |
+
"eligible": False,
|
| 139 |
+
"reason": "Requested days exceed annual allowance"
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
# Available balance
|
| 143 |
+
if requested_days > balance:
|
| 144 |
+
return {
|
| 145 |
+
"eligible": False,
|
| 146 |
+
"reason": "Insufficient leave balance",
|
| 147 |
+
"available_balance": balance
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
return {
|
| 151 |
+
"eligible": True,
|
| 152 |
+
"employee_id": employee_id,
|
| 153 |
+
"leave_type": leave_type,
|
| 154 |
+
"approved_days": requested_days,
|
| 155 |
+
"remaining_balance": balance - requested_days
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
|