vishalkatheriya commited on
Commit
68759d0
·
verified ·
1 Parent(s): 67e9ef8

Upload 5 files

Browse files
Files changed (5) hide show
  1. src/__init__.py +1 -0
  2. src/agent.py +57 -0
  3. src/main.py +99 -0
  4. src/prompt.py +33 -0
  5. src/sf_tools.py +158 -0
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import agent
src/agent.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from google.adk.agents.llm_agent import Agent
4
+ from google.adk.models.lite_llm import LiteLlm
5
+
6
+
7
+ import prompt as prmpt
8
+ import policy_tools as tls
9
+
10
+ # 🔑 Load .env file
11
+ load_dotenv()
12
+
13
+ groq_api_key = os.getenv("GROQ_API_KEY")
14
+ groq_model = os.getenv("MODEL")
15
+
16
+ if not groq_api_key:
17
+ raise ValueError("GROQ_API_KEY is not set in .env")
18
+
19
+ if not groq_model:
20
+ raise ValueError("MODEL is not set in .env")
21
+
22
+ # Optional (only needed if LiteLLM expects env var)
23
+ os.environ["GROQ_API_KEY"] = groq_api_key
24
+
25
+ # root_agent = Agent(
26
+ # model=LiteLlm(model=groq_model,
27
+ # custom_llm_provider="groq"
28
+ # ),
29
+ # name=prmpt.AGENT_CONFIG["name"],
30
+ # description=prmpt.AGENT_CONFIG["description"],
31
+ # instruction=prmpt.AGENT_CONFIG["instruction"],
32
+ # tools=[tls.check_leave_eligibility,tls.get_employee,tls.get_leave_policy],
33
+ # )
34
+
35
+ import sf_tools as tls # replace with actual Tool class in your agent
36
+
37
+ # Now create the agent
38
+ root_agent = Agent(
39
+ model=LiteLlm(
40
+ model=groq_model,
41
+ custom_llm_provider="groq"
42
+ ),
43
+ name=prmpt.AGENT_CONFIG["name"],
44
+ description=prmpt.AGENT_CONFIG["description"],
45
+ instruction=prmpt.AGENT_CONFIG["instruction"],
46
+ tools=[tls.check_leave_eligibility, tls.get_employee, tls.get_leave_policy],
47
+ )
48
+
49
+
50
+
51
+
52
+
53
+
54
+
55
+
56
+
57
+
src/main.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+
5
+ from google.adk.sessions import InMemorySessionService
6
+ from google.adk.runners import Runner
7
+ from google.genai import types
8
+
9
+ from agent import root_agent # your agent file
10
+
11
+ # =========================
12
+ # FASTAPI APP
13
+ # =========================
14
+ app = FastAPI(
15
+ title="Leave Policy Assistant",
16
+ description="AI agent to answer leave policy questions",
17
+ version="1.0.0"
18
+ )
19
+
20
+ # =========================
21
+ # ADK SESSION SETUP
22
+ # =========================
23
+ APP_NAME = "leave_policy_app"
24
+ session_service = InMemorySessionService()
25
+
26
+ runner = Runner(
27
+ agent=root_agent,
28
+ app_name=APP_NAME,
29
+ session_service=session_service,
30
+ )
31
+
32
+ # =========================
33
+ # REQUEST / RESPONSE MODELS
34
+ # =========================
35
+ class ChatRequest(BaseModel):
36
+ user_id: str
37
+ message: str
38
+
39
+
40
+ class ChatResponse(BaseModel):
41
+ response: str
42
+
43
+
44
+ # =========================
45
+ # CHAT ENDPOINT
46
+ # =========================
47
+ @app.post("/chat", response_model=ChatResponse)
48
+ async def chat(req: ChatRequest):
49
+ try:
50
+ # Generate session_id automatically per user
51
+ session_id = f"{req.user_id}_session"
52
+
53
+ # Check if session exists, create if not
54
+ session = await session_service.get_session(
55
+ app_name=APP_NAME,
56
+ user_id=req.user_id,
57
+ session_id=session_id,
58
+ )
59
+ if not session:
60
+ await session_service.create_session(
61
+ app_name=APP_NAME,
62
+ user_id=req.user_id,
63
+ session_id=session_id,
64
+ )
65
+
66
+ # Prepare the user message
67
+ user_content = types.Content(
68
+ role="user",
69
+ parts=[types.Part(text=req.message)],
70
+ )
71
+
72
+ final_response = None
73
+
74
+ # Run the agent asynchronously
75
+ async for event in runner.run_async(
76
+ user_id=req.user_id,
77
+ session_id=session_id,
78
+ new_message=user_content,
79
+ ):
80
+ if event.is_final_response():
81
+ if event.content and event.content.parts:
82
+ final_response = event.content.parts[0].text
83
+ break
84
+
85
+ if not final_response:
86
+ raise HTTPException(status_code=500, detail="No response from agent")
87
+
88
+ return ChatResponse(response=final_response)
89
+
90
+ except Exception as e:
91
+ raise HTTPException(status_code=500, detail=str(e))
92
+
93
+
94
+ # =========================
95
+ # HEALTH CHECK
96
+ # =========================
97
+ @app.get("/health")
98
+ def health():
99
+ return {"status": "ok"}
src/prompt.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ AGENT_CONFIG = {
2
+ "name": "vishal",
3
+ "description": "A Leave Policy Assistant Agent built using LiteLLM",
4
+ "instruction": """
5
+ You are Vishal, an intelligent and reliable HR Leave Policy Assistant.
6
+
7
+ Your primary role is to help employees understand company leave policies and
8
+ check their leave eligibility accurately.
9
+
10
+ What you should do:
11
+ - Answer questions related to leave policies (PTO, Sick Leave, Casual Leave, etc.).
12
+ - Check leave eligibility using the available tools.
13
+ - Explain rules such as allowances, carryover limits, notice periods, and eligibility criteria.
14
+ - Provide clear, concise, and employee-friendly responses.
15
+
16
+ How to behave:
17
+ - Always use tools when employee data or leave policy data is required.
18
+ - Never guess or fabricate employee information.
19
+ - If the employee ID is missing, politely ask for it.
20
+ - If the leave type is unknown, suggest valid leave types for the employee’s country.
21
+ - If the employee is inactive, clearly state that leave is not applicable.
22
+ - Handle invalid inputs (wrong dates, negative days, unknown leave types) gracefully.
23
+
24
+ Conversation handling:
25
+ - Maintain context across multiple messages.
26
+ - Ask follow-up questions when required information is missing.
27
+ - Be professional, calm, and helpful at all times.
28
+
29
+ Important:
30
+ - Accuracy is more important than speed.
31
+ - If something cannot be determined, explain why instead of guessing.
32
+ """
33
+ }
src/sf_tools.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from snowflake.connector import connect # type: ignore[import-untyped]
2
+ from datetime import datetime
3
+ import os
4
+ import json
5
+ from dotenv import load_dotenv
6
+
7
+ # 🔑 Load .env file
8
+ load_dotenv()
9
+ # -----------------------------
10
+ # Snowflake connection config
11
+ # -----------------------------
12
+
13
+ def get_snowflake_connection():
14
+ conn = connect(
15
+ user=os.getenv("SNOWFLAKE_USER"),
16
+ password=os.getenv("SNOWFLAKE_PASSWORD"),
17
+ account=os.getenv("SNOWFLAKE_ACCOUNT"),
18
+ role=os.getenv("SNOWFLAKE_ROLE"),
19
+ warehouse=os.getenv("SNOWFLAKE_WAREHOUSE"),
20
+ database=os.getenv("SNOWFLAKE_DATABASE"),
21
+ schema=os.getenv("SNOWFLAKE_SCHEMA")
22
+ )
23
+ return conn
24
+
25
+ # -----------------------------
26
+ # Fetch Employee Details
27
+ # -----------------------------
28
+ def get_employee(employee_id: str):
29
+ """Fetch employee details from Snowflake."""
30
+ print("get_employee called")
31
+ conn = get_snowflake_connection()
32
+ cur = conn.cursor()
33
+ try:
34
+ cur.execute("""
35
+ SELECT employee_id, country, employment_status
36
+ FROM employees
37
+ WHERE employee_id = %s
38
+ """, (employee_id,))
39
+ row = cur.fetchone()
40
+ if not row:
41
+ return {"error": f"Employee '{employee_id}' not found"}
42
+
43
+ employee = {
44
+ "employee_id": row[0],
45
+ "country": row[1],
46
+ "employment_status": row[2]
47
+ }
48
+
49
+ if employee["employment_status"] != "active":
50
+ return {"error": f"Employee '{employee_id}' is not active"}
51
+
52
+ # Fetch leave balances
53
+ cur.execute("""
54
+ SELECT leave_type, balance
55
+ FROM leave_balances
56
+ WHERE employee_id = %s
57
+ """, (employee_id,))
58
+ balances = {r[0]: r[1] for r in cur.fetchall()}
59
+ employee["leave_balances"] = balances
60
+
61
+ return employee
62
+ finally:
63
+ cur.close()
64
+ conn.close()
65
+
66
+
67
+ # -----------------------------
68
+ # Fetch Leave Policy
69
+ # -----------------------------
70
+ def get_leave_policy(country: str, leave_type: str):
71
+ """Fetch leave policy from Snowflake."""
72
+ print("get_leave_policy called")
73
+ conn = get_snowflake_connection()
74
+ cur = conn.cursor()
75
+ try:
76
+ cur.execute("""
77
+ SELECT policy
78
+ FROM leave_policies
79
+ WHERE country = %s AND leave_type = %s
80
+ """, (country, leave_type))
81
+ row = cur.fetchone()
82
+ if not row:
83
+ # fetch available leave types
84
+ cur.execute("""
85
+ SELECT leave_type
86
+ FROM leave_policies
87
+ WHERE country = %s
88
+ """, (country,))
89
+ leave_types = [r[0] for r in cur.fetchall()]
90
+ return {
91
+ "error": f"Leave type '{leave_type}' not found",
92
+ "available_leave_types": leave_types
93
+ }
94
+ # policy is VARIANT JSON in Snowflake (may be dict or JSON string)
95
+ raw = row[0]
96
+ if isinstance(raw, dict):
97
+ return raw
98
+ if isinstance(raw, str):
99
+ return json.loads(raw)
100
+ return dict(raw)
101
+ finally:
102
+ cur.close()
103
+ conn.close()
104
+
105
+
106
+ # -----------------------------
107
+ # Check Leave Eligibility
108
+ # -----------------------------
109
+ def check_leave_eligibility(employee_id: str, leave_type: str, requested_days: int):
110
+ """Check if employee is eligible to take leave."""
111
+ print("check_leave_eligibility called")
112
+ if requested_days <= 0:
113
+ return {"eligible": False, "reason": "Requested leave days must be > 0"}
114
+
115
+ employee = get_employee(employee_id)
116
+ if "error" in employee:
117
+ return employee
118
+
119
+ country = employee["country"]
120
+ policy = get_leave_policy(country, leave_type)
121
+ if "error" in policy:
122
+ return policy
123
+
124
+ balance = employee["leave_balances"].get(leave_type, 0)
125
+
126
+ # Max consecutive days
127
+ if "max_consecutive_days" in policy:
128
+ if requested_days > policy["max_consecutive_days"]:
129
+ return {
130
+ "eligible": False,
131
+ "reason": f"Maximum consecutive days allowed for {leave_type} is {policy['max_consecutive_days']}"
132
+ }
133
+
134
+ # Annual allowance
135
+ if "annual_allowance" in policy:
136
+ if requested_days > policy["annual_allowance"]:
137
+ return {
138
+ "eligible": False,
139
+ "reason": "Requested days exceed annual allowance"
140
+ }
141
+
142
+ # Available balance
143
+ if requested_days > balance:
144
+ return {
145
+ "eligible": False,
146
+ "reason": "Insufficient leave balance",
147
+ "available_balance": balance
148
+ }
149
+
150
+ return {
151
+ "eligible": True,
152
+ "employee_id": employee_id,
153
+ "leave_type": leave_type,
154
+ "approved_days": requested_days,
155
+ "remaining_balance": balance - requested_days
156
+ }
157
+
158
+